@@ -310,7 +310,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
310310#endif
311311
312312 // / solve eigenvector and eigenvalue for H(k)
313- this ->hamiltSolvePsiK (pHamilt, psi, precondition, eigenvalues.data () + ik * psi.get_nbands ());
313+ this ->hamiltSolvePsiK (pHamilt, psi, precondition, eigenvalues.data () + ik * psi.get_nbands (), this -> wfc_basis -> nks );
314314
315315 if (skip_charge)
316316 {
@@ -361,19 +361,27 @@ template <typename T, typename Device>
361361void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
362362 psi::Psi<T, Device>& psi,
363363 std::vector<Real>& pre_condition,
364- Real* eigenvalue)
364+ Real* eigenvalue,
365+ const int & nk_nums)
365366{
366367#ifdef __MPI
367368 const diag_comm_info comm_info = {POOL_WORLD, this ->rank_in_pool , this ->nproc_in_pool };
368369#else
369370 const diag_comm_info comm_info = {this ->rank_in_pool , this ->nproc_in_pool };
370371#endif
371372
373+ auto ngk_pointer = psi.get_ngk_pointer ();
374+
375+ std::vector<int > ngk_vector (nk_nums, 0 );
376+ for (int i = 0 ; i < nk_nums; i++)
377+ {
378+ ngk_vector[i] = ngk_pointer[i];
379+ }
380+
372381 if (this ->method == " cg" )
373382 {
374383 // wrap the subspace_func into a lambda function
375- auto ngk_pointer = psi.get_ngk_pointer ();
376- auto subspace_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
384+ auto subspace_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
377385 // psi_in should be a 2D tensor:
378386 // psi_in.shape() = [nbands, nbasis]
379387 const auto ndim = psi_in.shape ().ndim ();
@@ -383,12 +391,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
383391 1 ,
384392 psi_in.shape ().dim_size (0 ),
385393 psi_in.shape ().dim_size (1 ),
386- ngk_pointer );
394+ ngk_vector );
387395 auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data <T>(),
388396 1 ,
389397 psi_out.shape ().dim_size (0 ),
390398 psi_out.shape ().dim_size (1 ),
391- ngk_pointer );
399+ ngk_vector );
392400 auto eigen = ct::Tensor (ct::DataTypeToEnum<Real>::value,
393401 ct::DeviceType::CpuDevice,
394402 ct::TensorShape ({psi_in.shape ().dim_size (0 )}));
@@ -407,7 +415,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
407415 using ct_Device = typename ct::PsiToContainer<Device>::type;
408416
409417 // wrap the hpsi_func and spsi_func into a lambda function
410- auto hpsi_func = [hm, ngk_pointer ](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
418+ auto hpsi_func = [hm, ngk_vector ](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
411419 ModuleBase::timer::tick (" DiagoCG_New" , " hpsi_func" );
412420 // psi_in should be a 2D tensor:
413421 // psi_in.shape() = [nbands, nbasis]
@@ -418,7 +426,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
418426 1 ,
419427 ndim == 1 ? 1 : psi_in.shape ().dim_size (0 ),
420428 ndim == 1 ? psi_in.NumElements () : psi_in.shape ().dim_size (1 ),
421- ngk_pointer );
429+ ngk_vector );
422430 psi::Range all_bands_range (true , psi_wrapper.get_current_k (), 0 , psi_wrapper.get_nbands () - 1 );
423431 using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
424432 hpsi_info info (&psi_wrapper, all_bands_range, hpsi_out.data <T>());
@@ -477,13 +485,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
477485 {
478486 const int nband = psi.get_nbands ();
479487 const int nbasis = psi.get_nbasis ();
480- auto ngk_pointer = psi.get_ngk_pointer ();
481488 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
482- auto hpsi_func = [hm, ngk_pointer ](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
489+ auto hpsi_func = [hm, ngk_vector ](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
483490 ModuleBase::timer::tick (" DavSubspace" , " hpsi_func" );
484491
485492 // Convert "pointer data stucture" to a psi::Psi object
486- auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_pointer );
493+ auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_vector );
487494
488495 psi::Range bands_range (true , 0 , 0 , nvec - 1 );
489496
@@ -499,13 +506,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
499506 }
500507 else if (this ->method == " dav_subspace" )
501508 {
502- auto ngk_pointer = psi.get_ngk_pointer ();
503509 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
504- auto hpsi_func = [hm, ngk_pointer ](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
510+ auto hpsi_func = [hm, ngk_vector ](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
505511 ModuleBase::timer::tick (" DavSubspace" , " hpsi_func" );
506512
507513 // Convert "pointer data stucture" to a psi::Psi object
508- auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_pointer );
514+ auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_vector );
509515
510516 psi::Range bands_range (true , 0 , 0 , nvec - 1 );
511517
@@ -550,15 +556,13 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
550556 const int ld_psi = psi.get_nbasis (); // / leading dimension of psi
551557
552558 // Davidson matrix-blockvector functions
553-
554- auto ngk_pointer = psi.get_ngk_pointer ();
555559 // / wrap hpsi into lambda function, Matrix \times blockvector
556560 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
557- auto hpsi_func = [hm, ngk_pointer ](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
561+ auto hpsi_func = [hm, ngk_vector ](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
558562 ModuleBase::timer::tick (" David" , " hpsi_func" );
559563
560564 // Convert pointer of psi_in to a psi::Psi object
561- auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_pointer );
565+ auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_vector );
562566
563567 psi::Range bands_range (true , 0 , 0 , nvec - 1 );
564568
0 commit comments