@@ -378,10 +378,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
378378 ngk_vector[i] = ngk_pointer[i];
379379 }
380380
381+ const int cur_nbasis = psi.get_current_nbas ();
382+
381383 if (this ->method == " cg" )
382384 {
383385 // wrap the subspace_func into a lambda function
384- auto subspace_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
386+ auto subspace_func = [hm, ngk_vector, cur_nbasis ](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
385387 // psi_in should be a 2D tensor:
386388 // psi_in.shape() = [nbands, nbasis]
387389 const auto ndim = psi_in.shape ().ndim ();
@@ -391,12 +393,14 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
391393 1 ,
392394 psi_in.shape ().dim_size (0 ),
393395 psi_in.shape ().dim_size (1 ),
394- ngk_vector);
396+ ngk_vector,
397+ cur_nbasis);
395398 auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data <T>(),
396399 1 ,
397400 psi_out.shape ().dim_size (0 ),
398401 psi_out.shape ().dim_size (1 ),
399- ngk_vector);
402+ ngk_vector,
403+ cur_nbasis);
400404 auto eigen = ct::Tensor (ct::DataTypeToEnum<Real>::value,
401405 ct::DeviceType::CpuDevice,
402406 ct::TensorShape ({psi_in.shape ().dim_size (0 )}));
@@ -415,7 +419,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
415419 using ct_Device = typename ct::PsiToContainer<Device>::type;
416420
417421 // wrap the hpsi_func and spsi_func into a lambda function
418- auto hpsi_func = [hm, ngk_vector](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
422+ auto hpsi_func = [hm, ngk_vector, cur_nbasis ](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) {
419423 ModuleBase::timer::tick (" DiagoCG_New" , " hpsi_func" );
420424 // psi_in should be a 2D tensor:
421425 // psi_in.shape() = [nbands, nbasis]
@@ -426,7 +430,8 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
426430 1 ,
427431 ndim == 1 ? 1 : psi_in.shape ().dim_size (0 ),
428432 ndim == 1 ? psi_in.NumElements () : psi_in.shape ().dim_size (1 ),
429- ngk_vector);
433+ ngk_vector,
434+ cur_nbasis);
430435 psi::Range all_bands_range (true , psi_wrapper.get_current_k (), 0 , psi_wrapper.get_nbands () - 1 );
431436 using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
432437 hpsi_info info (&psi_wrapper, all_bands_range, hpsi_out.data <T>());
@@ -486,11 +491,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
486491 const int nband = psi.get_nbands ();
487492 const int nbasis = psi.get_nbasis ();
488493 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
489- auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
494+ auto hpsi_func = [hm, ngk_vector, cur_nbasis ](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
490495 ModuleBase::timer::tick (" DavSubspace" , " hpsi_func" );
491496
492497 // Convert "pointer data stucture" to a psi::Psi object
493- auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_vector);
498+ auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_vector, cur_nbasis );
494499
495500 psi::Range bands_range (true , 0 , 0 , nvec - 1 );
496501
@@ -507,11 +512,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
507512 else if (this ->method == " dav_subspace" )
508513 {
509514 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
510- auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
515+ auto hpsi_func = [hm, ngk_vector, cur_nbasis ](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
511516 ModuleBase::timer::tick (" DavSubspace" , " hpsi_func" );
512517
513518 // Convert "pointer data stucture" to a psi::Psi object
514- auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_vector);
519+ auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_vector, cur_nbasis );
515520
516521 psi::Range bands_range (true , 0 , 0 , nvec - 1 );
517522
@@ -558,11 +563,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
558563 // Davidson matrix-blockvector functions
559564 // / wrap hpsi into lambda function, Matrix \times blockvector
560565 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
561- auto hpsi_func = [hm, ngk_vector](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
566+ auto hpsi_func = [hm, ngk_vector, cur_nbasis ](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
562567 ModuleBase::timer::tick (" David" , " hpsi_func" );
563568
564569 // Convert pointer of psi_in to a psi::Psi object
565- auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_vector);
570+ auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_vector, cur_nbasis );
566571
567572 psi::Range bands_range (true , 0 , 0 , nvec - 1 );
568573
0 commit comments