@@ -310,7 +310,11 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
310310#endif
311311
312312 // / solve eigenvector and eigenvalue for H(k)
313- this ->hamiltSolvePsiK (pHamilt, psi, precondition, eigenvalues.data () + ik * psi.get_nbands ());
313+ this ->hamiltSolvePsiK (pHamilt,
314+ psi,
315+ precondition,
316+ eigenvalues.data () + ik * psi.get_nbands (),
317+ this ->wfc_basis ->nks );
314318
315319 if (skip_charge)
316320 {
@@ -357,19 +361,28 @@ template <typename T, typename Device>
357361void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
358362 psi::Psi<T, Device>& psi,
359363 std::vector<Real>& pre_condition,
360- Real* eigenvalue)
364+ Real* eigenvalue,
365+ const int & nk_nums)
361366{
362367#ifdef __MPI
363368 const diag_comm_info comm_info = {POOL_WORLD, this ->rank_in_pool , this ->nproc_in_pool };
364369#else
365370 const diag_comm_info comm_info = {this ->rank_in_pool , this ->nproc_in_pool };
366371#endif
367372
373+ auto ngk_pointer = psi.get_ngk_pointer ();
374+
375+ std::vector<int > ngk_vector_temp (nk_nums, 0 );
376+
377+ for (size_t i = 0 ; i < nk_nums; i++)
378+ {
379+ ngk_vector_temp[i] = ngk_pointer[i];
380+ }
381+
368382 if (this ->method == " cg" )
369383 {
370384 // wrap the subspace_func into a lambda function
371- auto ngk_pointer = psi.get_ngk_pointer ();
372- auto subspace_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
385+ auto subspace_func = [hm, ngk_vector_temp](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
373386 // psi_in should be a 2D tensor:
374387 // psi_in.shape() = [nbands, nbasis]
375388 const auto ndim = psi_in.shape ().ndim ();
@@ -379,12 +392,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
379392 1 ,
380393 psi_in.shape ().dim_size (0 ),
381394 psi_in.shape ().dim_size (1 ),
382- ngk_pointer );
395+ ngk_vector_temp );
383396 auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data <T>(),
384397 1 ,
385398 psi_out.shape ().dim_size (0 ),
386399 psi_out.shape ().dim_size (1 ),
387- ngk_pointer );
400+ ngk_vector_temp );
388401 auto eigen = ct::Tensor (ct::DataTypeToEnum<Real>::value,
389402 ct::DeviceType::CpuDevice,
390403 ct::TensorShape ({psi_in.shape ().dim_size (0 )}));
@@ -403,7 +416,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
403416 using ct_Device = typename ct::PsiToContainer<Device>::type;
404417
405418 // wrap the hpsi_func and spsi_func into a lambda function
406- auto hpsi_func = [hm, ngk_pointer ](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
419+ auto hpsi_func = [hm, ngk_vector_temp ](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
407420 ModuleBase::timer::tick (" DiagoCG_New" , " hpsi_func" );
408421 // psi_in should be a 2D tensor:
409422 // psi_in.shape() = [nbands, nbasis]
@@ -414,7 +427,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
414427 1 ,
415428 ndim == 1 ? 1 : psi_in.shape ().dim_size (0 ),
416429 ndim == 1 ? psi_in.NumElements () : psi_in.shape ().dim_size (1 ),
417- ngk_pointer );
430+ ngk_vector_temp );
418431 psi::Range all_bands_range (true , psi_wrapper.get_current_k (), 0 , psi_wrapper.get_nbands () - 1 );
419432 using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
420433 hpsi_info info (&psi_wrapper, all_bands_range, hpsi_out.data <T>());
@@ -473,13 +486,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
473486 {
474487 const int nband = psi.get_nbands ();
475488 const int nbasis = psi.get_nbasis ();
476- auto ngk_pointer = psi.get_ngk_pointer ();
477489 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
478- auto hpsi_func = [hm, ngk_pointer ](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
490+ auto hpsi_func = [hm, ngk_vector_temp ](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
479491 ModuleBase::timer::tick (" DavSubspace" , " hpsi_func" );
480492
481493 // Convert "pointer data stucture" to a psi::Psi object
482- auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_pointer );
494+ auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_vector_temp );
483495
484496 psi::Range bands_range (true , 0 , 0 , nvec - 1 );
485497
@@ -495,13 +507,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
495507 }
496508 else if (this ->method == " dav_subspace" )
497509 {
498- auto ngk_pointer = psi.get_ngk_pointer ();
499510 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
500- auto hpsi_func = [hm, ngk_pointer ](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
511+ auto hpsi_func = [hm, ngk_vector_temp ](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
501512 ModuleBase::timer::tick (" DavSubspace" , " hpsi_func" );
502513
503514 // Convert "pointer data stucture" to a psi::Psi object
504- auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_pointer );
515+ auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_vector_temp );
505516
506517 psi::Range bands_range (true , 0 , 0 , nvec - 1 );
507518
@@ -546,15 +557,13 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
546557 const int ld_psi = psi.get_nbasis (); // / leading dimension of psi
547558
548559 // Davidson matrix-blockvector functions
549-
550- auto ngk_pointer = psi.get_ngk_pointer ();
551560 // / wrap hpsi into lambda function, Matrix \times blockvector
552561 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
553- auto hpsi_func = [hm, ngk_pointer ](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
562+ auto hpsi_func = [hm, ngk_vector_temp ](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
554563 ModuleBase::timer::tick (" David" , " hpsi_func" );
555564
556565 // Convert pointer of psi_in to a psi::Psi object
557- auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_pointer );
566+ auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_vector_temp );
558567
559568 psi::Range bands_range (true , 0 , 0 , nvec - 1 );
560569
0 commit comments