@@ -310,7 +310,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
310310#endif
311311
312312 // / solve eigenvector and eigenvalue for H(k)
313- this ->hamiltSolvePsiK (pHamilt, psi, precondition, eigenvalues.data () + ik * psi.get_nbands ());
313+ this ->hamiltSolvePsiK (pHamilt, psi, precondition, eigenvalues.data () + ik * psi.get_nbands (), this -> wfc_basis -> nks );
314314
315315 if (skip_charge)
316316 {
@@ -357,7 +357,8 @@ template <typename T, typename Device>
357357void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
358358 psi::Psi<T, Device>& psi,
359359 std::vector<Real>& pre_condition,
360- Real* eigenvalue)
360+ Real* eigenvalue,
361+ const int & nk_nums)
361362{
362363#ifdef __MPI
363364 const diag_comm_info comm_info = {POOL_WORLD, this ->rank_in_pool , this ->nproc_in_pool };
@@ -367,10 +368,16 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
367368
368369 auto ngk_pointer = psi.get_ngk_pointer ();
369370
371+ std::vector<int > ngk_vector (nk_nums, 0 );
372+ for (int i = 0 ; i < nk_nums; i++)
373+ {
374+ ngk_vector[i] = ngk_pointer[i];
375+ }
376+
370377 if (this ->method == " cg" )
371378 {
372379 // wrap the subspace_func into a lambda function
373- auto subspace_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
380+ auto subspace_func = [hm, ngk_pointer, ngk_vector ](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
374381 // psi_in should be a 2D tensor:
375382 // psi_in.shape() = [nbands, nbasis]
376383 const auto ndim = psi_in.shape ().ndim ();
@@ -380,12 +387,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
380387 1 ,
381388 psi_in.shape ().dim_size (0 ),
382389 psi_in.shape ().dim_size (1 ),
383- nullptr );
390+ ngk_vector );
384391 auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data <T>(),
385392 1 ,
386393 psi_out.shape ().dim_size (0 ),
387394 psi_out.shape ().dim_size (1 ),
388- nullptr );
395+ ngk_vector );
389396 auto eigen = ct::Tensor (ct::DataTypeToEnum<Real>::value,
390397 ct::DeviceType::CpuDevice,
391398 ct::TensorShape ({psi_in.shape ().dim_size (0 )}));
@@ -404,7 +411,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
404411 using ct_Device = typename ct::PsiToContainer<Device>::type;
405412
406413 // wrap the hpsi_func and spsi_func into a lambda function
407- auto hpsi_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
414+ auto hpsi_func = [hm, ngk_pointer, ngk_vector ](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
408415 ModuleBase::timer::tick (" DiagoCG_New" , " hpsi_func" );
409416 // psi_in should be a 2D tensor:
410417 // psi_in.shape() = [nbands, nbasis]
@@ -415,7 +422,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
415422 1 ,
416423 ndim == 1 ? 1 : psi_in.shape ().dim_size (0 ),
417424 ndim == 1 ? psi_in.NumElements () : psi_in.shape ().dim_size (1 ),
418- nullptr );
425+ ngk_vector );
419426 psi::Range all_bands_range (true , psi_wrapper.get_current_k (), 0 , psi_wrapper.get_nbands () - 1 );
420427 using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
421428 hpsi_info info (&psi_wrapper, all_bands_range, hpsi_out.data <T>());
@@ -475,11 +482,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
475482 const int nband = psi.get_nbands ();
476483 const int nbasis = psi.get_nbasis ();
477484 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
478- auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
485+ auto hpsi_func = [hm, ngk_pointer, ngk_vector ](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
479486 ModuleBase::timer::tick (" DavSubspace" , " hpsi_func" );
480487
481488 // Convert "pointer data stucture" to a psi::Psi object
482- auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, nullptr );
489+ auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_vector );
483490
484491 psi::Range bands_range (true , 0 , 0 , nvec - 1 );
485492
@@ -496,11 +503,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
496503 else if (this ->method == " dav_subspace" )
497504 {
498505 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
499- auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
506+ auto hpsi_func = [hm, ngk_pointer, ngk_vector ](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
500507 ModuleBase::timer::tick (" DavSubspace" , " hpsi_func" );
501508
502509 // Convert "pointer data stucture" to a psi::Psi object
503- auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, nullptr );
510+ auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_vector );
504511
505512 psi::Range bands_range (true , 0 , 0 , nvec - 1 );
506513
@@ -547,11 +554,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
547554 // Davidson matrix-blockvector functions
548555 // / wrap hpsi into lambda function, Matrix \times blockvector
549556 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
550- auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
557+ auto hpsi_func = [hm, ngk_pointer, ngk_vector ](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
551558 ModuleBase::timer::tick (" David" , " hpsi_func" );
552559
553560 // Convert pointer of psi_in to a psi::Psi object
554- auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, nullptr );
561+ auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_vector );
555562
556563 psi::Range bands_range (true , 0 , 0 , nvec - 1 );
557564
0 commit comments