22
33#ifdef __MPI
44#include " diago_scalapack.h"
5+ #include " module_base/scalapack_connector.h"
56#else
67#include " diago_lapack.h"
78#endif
2425#include " module_elecstate/elecstate_lcao.h"
2526#endif
2627
27- #include " diago_cg.h"
2828#include " module_base/global_variable.h"
2929#include " module_base/memory.h"
30- #include " module_base/scalapack_connector.h"
3130#include " module_base/timer.h"
32- #include " module_hsolver/diago_iter_assist.h"
33- #include " module_hsolver/kernels/math_kernel_op.h"
3431#include " module_hsolver/parallel_k2d.h"
35- #include " module_io/write_HS.h"
3632#include " module_parameter/parameter.h"
3733
38- #include < ATen/core/tensor.h>
39- #include < ATen/core/tensor_map.h>
40- #include < ATen/core/tensor_types.h>
41- #include < unistd.h>
42-
4334namespace hsolver
4435{
4536
@@ -74,51 +65,39 @@ void HSolverLCAO<T, Device>::solve(hamilt::Hamilt<T>* pHamilt,
7465 }
7566#endif
7667
77- if (this ->method == " cg_in_lcao" )
78- {
79- this ->precondition_lcao .resize (psi.get_nbasis ());
80-
81- using Real = typename GetTypeReal<T>::type;
82- // set precondition
83- for (size_t i = 0 ; i < precondition_lcao.size (); i++)
84- {
85- precondition_lcao[i] = 1.0 ;
86- }
87- }
88-
89- #ifdef __MPI
9068 if (GlobalV::KPAR_LCAO > 1
9169 && (this ->method == " genelpa" || this ->method == " elpa" || this ->method == " scalapack_gvx" ))
9270 {
71+ #ifdef __MPI
9372 this ->parakSolve (pHamilt, psi, pes, GlobalV::KPAR_LCAO);
94- }
95- else
9673#endif
74+ }
75+ else if (GlobalV::KPAR_LCAO == 1 )
9776 {
98- // / Loop over k points for solve Hamiltonian to charge density
77+ // / Loop over k points for solve Hamiltonian to eigenpairs(eigenvalues and eigenvectors).
9978 for (int ik = 0 ; ik < psi.get_nk (); ++ik)
10079 {
10180 // / update H(k) for each k point
10281 pHamilt->updateHk (ik);
10382
83+ // / find psi pointer for each k point
10484 psi.fix_k (ik);
10585
106- // solve eigenvector and eigenvalue for H(k)
86+ // / solve eigenvector and eigenvalue for H(k)
10787 this ->hamiltSolvePsiK (pHamilt, psi, &(pes->ekb (ik, 0 )));
10888 }
10989 }
11090
111- if (skip_charge) // used in nscf calculation
91+ if (! skip_charge) // used in scf calculation
11292 {
113- ModuleBase::timer::tick (" HSolverLCAO" , " solve" );
93+ // calculate charge by eigenpairs(eigenvalues and eigenvectors)
94+ pes->psiToRho (psi);
11495 }
115- else // used in scf calculation
96+ else // used in nscf calculation
11697 {
117- // calculate charge by psi
118- pes->psiToRho (psi);
119- ModuleBase::timer::tick (" HSolverLCAO" , " solve" );
12098 }
12199
100+ ModuleBase::timer::tick (" HSolverLCAO" , " solve" );
122101 return ;
123102}
124103
@@ -135,6 +114,7 @@ void HSolverLCAO<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T>* hm, psi::Psi<T>&
135114 sa.diag (hm, psi, eigenvalue);
136115#endif
137116 }
117+
138118#ifdef __ELPA
139119 else if (this ->method == " genelpa" )
140120 {
@@ -147,151 +127,33 @@ void HSolverLCAO<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T>* hm, psi::Psi<T>&
147127 el.diag (hm, psi, eigenvalue);
148128 }
149129#endif
130+
150131#ifdef __CUDA
151132 else if (this ->method == " cusolver" )
152133 {
153134 DiagoCusolver<T> cs (this ->ParaV );
154135 cs.diag (hm, psi, eigenvalue);
155136 }
137+ #ifdef __CUSOLVERMP
156138 else if (this ->method == " cusolvermp" )
157139 {
158- #ifdef __CUSOLVERMP
159140 DiagoCusolverMP<T> cm;
160141 cm.diag (hm, psi, eigenvalue);
161- #else
162- ModuleBase::WARNING_QUIT (" HSolverLCAO" , " CUSOLVERMP did not compiled!" );
163- #endif
164142 }
165143#endif
166- else if ( this -> method == " lapack " )
167- {
144+ # endif
145+
168146#ifndef __MPI
147+ else if (this ->method == " lapack" ) // only for single core
148+ {
169149 DiagoLapack<T> la;
170150 la.diag (hm, psi, eigenvalue);
171- #else
172- ModuleBase::WARNING_QUIT (" HSolverLCAO::solve" , " This type of eigensolver is not supported!" );
173- #endif
174151 }
152+ #endif
153+
175154 else
176155 {
177-
178- using ct_Device = typename ct::PsiToContainer<base_device::DEVICE_CPU>::type;
179-
180- auto subspace_func = [](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
181- // psi_in should be a 2D tensor:
182- // psi_in.shape() = [nbands, nbasis]
183- const auto ndim = psi_in.shape ().ndim ();
184- REQUIRES_OK (ndim == 2 , " dims of psi_in should be less than or equal to 2" );
185- };
186-
187- DiagoCG<T, Device> cg (PARAM.inp .basis_type ,
188- PARAM.inp .calculation ,
189- DiagoIterAssist<T, Device>::need_subspace,
190- subspace_func,
191- DiagoIterAssist<T, Device>::PW_DIAG_THR,
192- DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
193- GlobalV::NPROC_IN_POOL);
194-
195- hamilt::MatrixBlock<T> h_mat, s_mat;
196- hm->matrix (h_mat, s_mat);
197-
198- // set h_mat & s_mat
199- for (int i = 0 ; i < h_mat.row ; i++)
200- {
201- for (int j = i; j < h_mat.col ; j++)
202- {
203- h_mat.p [h_mat.row * j + i] = hsolver::get_conj (h_mat.p [h_mat.row * i + j]);
204- s_mat.p [s_mat.row * j + i] = hsolver::get_conj (s_mat.p [s_mat.row * i + j]);
205- }
206- }
207-
208- const T *one_ = nullptr , *zero_ = nullptr ;
209- one_ = new T (static_cast <T>(1.0 ));
210- zero_ = new T (static_cast <T>(0.0 ));
211-
212- auto hpsi_func = [h_mat, one_, zero_](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
213- ModuleBase::timer::tick (" DiagoCG_New" , " hpsi_func" );
214- // psi_in should be a 2D tensor:
215- // psi_in.shape() = [nbands, nbasis]
216- const auto ndim = psi_in.shape ().ndim ();
217- REQUIRES_OK (ndim <= 2 , " dims of psi_in should be less than or equal to 2" );
218-
219- Device* ctx = {};
220-
221- gemv_op<T, Device>()(ctx,
222- ' N' ,
223- h_mat.row ,
224- h_mat.col ,
225- one_,
226- h_mat.p ,
227- h_mat.row ,
228- psi_in.data <T>(),
229- 1 ,
230- zero_,
231- hpsi_out.data <T>(),
232- 1 );
233-
234- ModuleBase::timer::tick (" DiagoCG_New" , " hpsi_func" );
235- };
236-
237- auto spsi_func = [s_mat, one_, zero_](const ct::Tensor& psi_in, ct::Tensor& spsi_out) {
238- ModuleBase::timer::tick (" DiagoCG_New" , " spsi_func" );
239- // psi_in should be a 2D tensor:
240- // psi_in.shape() = [nbands, nbasis]
241- const auto ndim = psi_in.shape ().ndim ();
242- REQUIRES_OK (ndim <= 2 , " dims of psi_in should be less than or equal to 2" );
243-
244- Device* ctx = {};
245-
246- gemv_op<T, Device>()(ctx,
247- ' N' ,
248- s_mat.row ,
249- s_mat.col ,
250- one_,
251- s_mat.p ,
252- s_mat.row ,
253- psi_in.data <T>(),
254- 1 ,
255- zero_,
256- spsi_out.data <T>(),
257- 1 );
258-
259- ModuleBase::timer::tick (" DiagoCG_New" , " spsi_func" );
260- };
261-
262- // if (this->is_first_scf)
263- // {
264- for (size_t i = 0 ; i < psi.get_nbands (); i++)
265- {
266- for (size_t j = 0 ; j < psi.get_nbasis (); j++)
267- {
268- psi (i, j) = *zero_;
269- }
270- psi (i, i) = *one_;
271- }
272- // }
273-
274- auto psi_tensor = ct::TensorMap (psi.get_pointer (),
275- ct::DataTypeToEnum<T>::value,
276- ct::DeviceTypeToEnum<ct_Device>::value,
277- ct::TensorShape ({psi.get_nbands (), psi.get_nbasis ()}))
278- .slice ({0 , 0 }, {psi.get_nbands (), psi.get_current_nbas ()});
279-
280- auto eigen_tensor = ct::TensorMap (eigenvalue,
281- ct::DataTypeToEnum<Real>::value,
282- ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
283- ct::TensorShape ({psi.get_nbands ()}));
284-
285- auto prec_tensor = ct::TensorMap (this ->precondition_lcao .data (),
286- ct::DataTypeToEnum<Real>::value,
287- ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
288- ct::TensorShape ({static_cast <int >(this ->precondition_lcao .size ())}))
289- .slice ({0 }, {psi.get_current_nbas ()});
290-
291- cg.diag (hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor);
292-
293- // TODO: Double check tensormap's potential problem
294- ct::TensorMap (psi.get_pointer (), psi_tensor, {psi.get_nbands (), psi.get_nbasis ()}).sync (psi_tensor);
156+ ModuleBase::WARNING_QUIT (" HSolverLCAO::solve" , " This method is not supported for lcao basis in ABACUS!" );
295157 }
296158
297159 ModuleBase::timer::tick (" HSolverLCAO" , " hamiltSolvePsiK" );
0 commit comments