@@ -310,7 +310,11 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
310310#endif
311311
312312 // / solve eigenvector and eigenvalue for H(k)
313- this ->hamiltSolvePsiK (pHamilt, psi, precondition, eigenvalues.data () + ik * psi.get_nbands (), this ->wfc_basis ->nks );
313+ this ->hamiltSolvePsiK (pHamilt,
314+ psi,
315+ precondition,
316+ eigenvalues.data () + ik * psi.get_nbands (),
317+ this ->wfc_basis ->nks );
314318
315319 if (skip_charge)
316320 {
@@ -370,20 +374,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
370374 const diag_comm_info comm_info = {this ->rank_in_pool , this ->nproc_in_pool };
371375#endif
372376
373- auto ngk_pointer = psi.get_ngk_pointer ();
374-
375- std::vector<int > ngk_vector (nk_nums, 0 );
376- for (int i = 0 ; i < nk_nums; i++)
377- {
378- ngk_vector[i] = ngk_pointer[i];
379- }
380-
381377 const int cur_nbasis = psi.get_ngk (psi.get_current_k ());
382378
383379 if (this ->method == " cg" )
384380 {
385381 // wrap the subspace_func into a lambda function
386- auto subspace_func = [hm, ngk_vector, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
382+ auto subspace_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
387383 // psi_in should be a 2D tensor:
388384 // psi_in.shape() = [nbands, nbasis]
389385 const auto ndim = psi_in.shape ().ndim ();
@@ -393,13 +389,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
393389 1 ,
394390 psi_in.shape ().dim_size (0 ),
395391 psi_in.shape ().dim_size (1 ),
396- ngk_vector,
397392 cur_nbasis);
398393 auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data <T>(),
399394 1 ,
400395 psi_out.shape ().dim_size (0 ),
401396 psi_out.shape ().dim_size (1 ),
402- ngk_vector,
403397 cur_nbasis);
404398 auto eigen = ct::Tensor (ct::DataTypeToEnum<Real>::value,
405399 ct::DeviceType::CpuDevice,
@@ -419,7 +413,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
419413 using ct_Device = typename ct::PsiToContainer<Device>::type;
420414
421415 // wrap the hpsi_func and spsi_func into a lambda function
422- auto hpsi_func = [hm, ngk_vector, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
416+ auto hpsi_func = [hm, cur_nbasis](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
423417 ModuleBase::timer::tick (" DiagoCG_New" , " hpsi_func" );
424418 // psi_in should be a 2D tensor:
425419 // psi_in.shape() = [nbands, nbasis]
@@ -430,7 +424,6 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
430424 1 ,
431425 ndim == 1 ? 1 : psi_in.shape ().dim_size (0 ),
432426 ndim == 1 ? psi_in.NumElements () : psi_in.shape ().dim_size (1 ),
433- ngk_vector,
434427 cur_nbasis);
435428 psi::Range all_bands_range (true , psi_wrapper.get_current_k (), 0 , psi_wrapper.get_nbands () - 1 );
436429 using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
@@ -491,11 +484,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
491484 const int nband = psi.get_nbands ();
492485 const int nbasis = psi.get_nbasis ();
493486 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
494- auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
487+ auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
495488 ModuleBase::timer::tick (" DavSubspace" , " hpsi_func" );
496489
497490 // Convert "pointer data stucture" to a psi::Psi object
498- auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_vector, cur_nbasis);
491+ auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, cur_nbasis);
499492
500493 psi::Range bands_range (true , 0 , 0 , nvec - 1 );
501494
@@ -512,11 +505,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
512505 else if (this ->method == " dav_subspace" )
513506 {
514507 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
515- auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
508+ auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
516509 ModuleBase::timer::tick (" DavSubspace" , " hpsi_func" );
517510
518511 // Convert "pointer data stucture" to a psi::Psi object
519- auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_vector, cur_nbasis);
512+ auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, cur_nbasis);
520513
521514 psi::Range bands_range (true , 0 , 0 , nvec - 1 );
522515
@@ -557,17 +550,17 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
557550
558551 // dimensions of matrix to be solved
559552 const int dim = psi.get_cur_effective_basis (); // / dimension of matrix
560- const int nband = psi.get_nbands (); // / number of eigenpairs sought
561- const int ld_psi = psi.get_nbasis (); // / leading dimension of psi
553+ const int nband = psi.get_nbands (); // / number of eigenpairs sought
554+ const int ld_psi = psi.get_nbasis (); // / leading dimension of psi
562555
563556 // Davidson matrix-blockvector functions
564557 // / wrap hpsi into lambda function, Matrix \times blockvector
565558 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
566- auto hpsi_func = [hm, ngk_vector, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
559+ auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
567560 ModuleBase::timer::tick (" David" , " hpsi_func" );
568561
569562 // Convert pointer of psi_in to a psi::Psi object
570- auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_vector, cur_nbasis);
563+ auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, cur_nbasis);
571564
572565 psi::Range bands_range (true , 0 , 0 , nvec - 1 );
573566
0 commit comments