22
33#ifdef __MPI
44#include " diago_scalapack.h"
5+ #include " module_base/scalapack_connector.h"
56#else
67#include " diago_lapack.h"
78#endif
2425#include " module_elecstate/elecstate_lcao.h"
2526#endif
2627
27- #include " diago_cg.h"
2828#include " module_base/global_variable.h"
2929#include " module_base/memory.h"
30- #include " module_base/scalapack_connector.h"
3130#include " module_base/timer.h"
32- #include " module_hsolver/diago_iter_assist.h"
33- #include " module_hsolver/kernels/math_kernel_op.h"
3431#include " module_hsolver/parallel_k2d.h"
35- #include " module_io/write_HS.h"
3632#include " module_parameter/parameter.h"
3733
38- #include < ATen/core/tensor.h>
39- #include < ATen/core/tensor_map.h>
40- #include < ATen/core/tensor_types.h>
41- #include < unistd.h>
42-
4334namespace hsolver
4435{
4536
@@ -72,19 +63,8 @@ void HSolverLCAO<T, Device>::solve(hamilt::Hamilt<T>* pHamilt,
7263 ModuleBase::timer::tick (" HSolverLCAO" , " solve" );
7364 return ;
7465 }
75- #endif
7666
77- if (this ->method == " cg_in_lcao" )
78- {
79- this ->precondition_lcao .resize (psi.get_nbasis ());
80-
81- using Real = typename GetTypeReal<T>::type;
82- // set precondition
83- for (size_t i = 0 ; i < precondition_lcao.size (); i++)
84- {
85- precondition_lcao[i] = 1.0 ;
86- }
87- }
67+ #endif
8868
8969#ifdef __MPI
9070 if (GlobalV::KPAR_LCAO > 1
@@ -103,7 +83,7 @@ void HSolverLCAO<T, Device>::solve(hamilt::Hamilt<T>* pHamilt,
10383
10484 psi.fix_k (ik);
10585
106- // solve eigenvector and eigenvalue for H(k)
86+ // / solve eigenvector and eigenvalue for H(k)
10787 this ->hamiltSolvePsiK (pHamilt, psi, &(pes->ekb (ik, 0 )));
10888 }
10989 }
@@ -135,6 +115,7 @@ void HSolverLCAO<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T>* hm, psi::Psi<T>&
135115 sa.diag (hm, psi, eigenvalue);
136116#endif
137117 }
118+
138119#ifdef __ELPA
139120 else if (this ->method == " genelpa" )
140121 {
@@ -147,151 +128,33 @@ void HSolverLCAO<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T>* hm, psi::Psi<T>&
147128 el.diag (hm, psi, eigenvalue);
148129 }
149130#endif
131+
150132#ifdef __CUDA
151133 else if (this ->method == " cusolver" )
152134 {
153135 DiagoCusolver<T> cs (this ->ParaV );
154136 cs.diag (hm, psi, eigenvalue);
155137 }
138+ #ifdef __CUSOLVERMP
156139 else if (this ->method == " cusolvermp" )
157140 {
158- #ifdef __CUSOLVERMP
159141 DiagoCusolverMP<T> cm;
160142 cm.diag (hm, psi, eigenvalue);
161- #else
162- ModuleBase::WARNING_QUIT (" HSolverLCAO" , " CUSOLVERMP did not compiled!" );
163- #endif
164143 }
165144#endif
166- else if ( this -> method == " lapack " )
167- {
145+ # endif
146+
168147#ifndef __MPI
148+ else if (this ->method == " lapack" ) // only for single core
149+ {
169150 DiagoLapack<T> la;
170151 la.diag (hm, psi, eigenvalue);
171- #else
172- ModuleBase::WARNING_QUIT (" HSolverLCAO::solve" , " This method of DiagH is not supported!" );
173- #endif
174152 }
153+ #endif
154+
175155 else
176156 {
177-
178- using ct_Device = typename ct::PsiToContainer<base_device::DEVICE_CPU>::type;
179-
180- auto subspace_func = [](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
181- // psi_in should be a 2D tensor:
182- // psi_in.shape() = [nbands, nbasis]
183- const auto ndim = psi_in.shape ().ndim ();
184- REQUIRES_OK (ndim == 2 , " dims of psi_in should be less than or equal to 2" );
185- };
186-
187- DiagoCG<T, Device> cg (PARAM.inp .basis_type ,
188- PARAM.inp .calculation ,
189- DiagoIterAssist<T, Device>::need_subspace,
190- subspace_func,
191- DiagoIterAssist<T, Device>::PW_DIAG_THR,
192- DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
193- GlobalV::NPROC_IN_POOL);
194-
195- hamilt::MatrixBlock<T> h_mat, s_mat;
196- hm->matrix (h_mat, s_mat);
197-
198- // set h_mat & s_mat
199- for (int i = 0 ; i < h_mat.row ; i++)
200- {
201- for (int j = i; j < h_mat.col ; j++)
202- {
203- h_mat.p [h_mat.row * j + i] = hsolver::get_conj (h_mat.p [h_mat.row * i + j]);
204- s_mat.p [s_mat.row * j + i] = hsolver::get_conj (s_mat.p [s_mat.row * i + j]);
205- }
206- }
207-
208- const T *one_ = nullptr , *zero_ = nullptr ;
209- one_ = new T (static_cast <T>(1.0 ));
210- zero_ = new T (static_cast <T>(0.0 ));
211-
212- auto hpsi_func = [h_mat, one_, zero_](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
213- ModuleBase::timer::tick (" DiagoCG_New" , " hpsi_func" );
214- // psi_in should be a 2D tensor:
215- // psi_in.shape() = [nbands, nbasis]
216- const auto ndim = psi_in.shape ().ndim ();
217- REQUIRES_OK (ndim <= 2 , " dims of psi_in should be less than or equal to 2" );
218-
219- Device* ctx = {};
220-
221- gemv_op<T, Device>()(ctx,
222- ' N' ,
223- h_mat.row ,
224- h_mat.col ,
225- one_,
226- h_mat.p ,
227- h_mat.row ,
228- psi_in.data <T>(),
229- 1 ,
230- zero_,
231- hpsi_out.data <T>(),
232- 1 );
233-
234- ModuleBase::timer::tick (" DiagoCG_New" , " hpsi_func" );
235- };
236-
237- auto spsi_func = [s_mat, one_, zero_](const ct::Tensor& psi_in, ct::Tensor& spsi_out) {
238- ModuleBase::timer::tick (" DiagoCG_New" , " spsi_func" );
239- // psi_in should be a 2D tensor:
240- // psi_in.shape() = [nbands, nbasis]
241- const auto ndim = psi_in.shape ().ndim ();
242- REQUIRES_OK (ndim <= 2 , " dims of psi_in should be less than or equal to 2" );
243-
244- Device* ctx = {};
245-
246- gemv_op<T, Device>()(ctx,
247- ' N' ,
248- s_mat.row ,
249- s_mat.col ,
250- one_,
251- s_mat.p ,
252- s_mat.row ,
253- psi_in.data <T>(),
254- 1 ,
255- zero_,
256- spsi_out.data <T>(),
257- 1 );
258-
259- ModuleBase::timer::tick (" DiagoCG_New" , " spsi_func" );
260- };
261-
262- // if (this->is_first_scf)
263- // {
264- for (size_t i = 0 ; i < psi.get_nbands (); i++)
265- {
266- for (size_t j = 0 ; j < psi.get_nbasis (); j++)
267- {
268- psi (i, j) = *zero_;
269- }
270- psi (i, i) = *one_;
271- }
272- // }
273-
274- auto psi_tensor = ct::TensorMap (psi.get_pointer (),
275- ct::DataTypeToEnum<T>::value,
276- ct::DeviceTypeToEnum<ct_Device>::value,
277- ct::TensorShape ({psi.get_nbands (), psi.get_nbasis ()}))
278- .slice ({0 , 0 }, {psi.get_nbands (), psi.get_current_nbas ()});
279-
280- auto eigen_tensor = ct::TensorMap (eigenvalue,
281- ct::DataTypeToEnum<Real>::value,
282- ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
283- ct::TensorShape ({psi.get_nbands ()}));
284-
285- auto prec_tensor = ct::TensorMap (this ->precondition_lcao .data (),
286- ct::DataTypeToEnum<Real>::value,
287- ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
288- ct::TensorShape ({static_cast <int >(this ->precondition_lcao .size ())}))
289- .slice ({0 }, {psi.get_current_nbas ()});
290-
291- cg.diag (hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor);
292-
293- // TODO: Double check tensormap's potential problem
294- ct::TensorMap (psi.get_pointer (), psi_tensor, {psi.get_nbands (), psi.get_nbasis ()}).sync (psi_tensor);
157+ ModuleBase::WARNING_QUIT (" HSolverLCAO::solve" , " This method is not supported for lcao basis in ABACUS!" );
295158 }
296159
297160 ModuleBase::timer::tick (" HSolverLCAO" , " hamiltSolvePsiK" );
0 commit comments