@@ -365,10 +365,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
365365 const diag_comm_info comm_info = {this ->rank_in_pool , this ->nproc_in_pool };
366366#endif
367367
368+ auto ngk_pointer = psi.get_ngk_pointer ();
369+
368370 if (this ->method == " cg" )
369371 {
370372 // wrap the subspace_func into a lambda function
371- auto ngk_pointer = psi.get_ngk_pointer ();
372373 auto subspace_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
373374 // psi_in should be a 2D tensor:
374375 // psi_in.shape() = [nbands, nbasis]
@@ -379,12 +380,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
379380 1 ,
380381 psi_in.shape ().dim_size (0 ),
381382 psi_in.shape ().dim_size (1 ),
382- ngk_pointer );
383+ nullptr );
383384 auto psi_out_wrapper = psi::Psi<T, Device>(psi_out.data <T>(),
384385 1 ,
385386 psi_out.shape ().dim_size (0 ),
386387 psi_out.shape ().dim_size (1 ),
387- ngk_pointer );
388+ nullptr );
388389 auto eigen = ct::Tensor (ct::DataTypeToEnum<Real>::value,
389390 ct::DeviceType::CpuDevice,
390391 ct::TensorShape ({psi_in.shape ().dim_size (0 )}));
@@ -414,7 +415,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
414415 1 ,
415416 ndim == 1 ? 1 : psi_in.shape ().dim_size (0 ),
416417 ndim == 1 ? psi_in.NumElements () : psi_in.shape ().dim_size (1 ),
417- ngk_pointer );
418+ nullptr );
418419 psi::Range all_bands_range (true , psi_wrapper.get_current_k (), 0 , psi_wrapper.get_nbands () - 1 );
419420 using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
420421 hpsi_info info (&psi_wrapper, all_bands_range, hpsi_out.data <T>());
@@ -473,13 +474,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
473474 {
474475 const int nband = psi.get_nbands ();
475476 const int nbasis = psi.get_nbasis ();
476- auto ngk_pointer = psi.get_ngk_pointer ();
477477 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
478478 auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
479479 ModuleBase::timer::tick (" DavSubspace" , " hpsi_func" );
480480
481481 // Convert "pointer data stucture" to a psi::Psi object
482- auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_pointer );
482+ auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, nullptr );
483483
484484 psi::Range bands_range (true , 0 , 0 , nvec - 1 );
485485
@@ -495,13 +495,12 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
495495 }
496496 else if (this ->method == " dav_subspace" )
497497 {
498- auto ngk_pointer = psi.get_ngk_pointer ();
499498 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
500499 auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
501500 ModuleBase::timer::tick (" DavSubspace" , " hpsi_func" );
502501
503502 // Convert "pointer data stucture" to a psi::Psi object
504- auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_pointer );
503+ auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, nullptr );
505504
506505 psi::Range bands_range (true , 0 , 0 , nvec - 1 );
507506
@@ -546,15 +545,13 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
546545 const int ld_psi = psi.get_nbasis (); // / leading dimension of psi
547546
548547 // Davidson matrix-blockvector functions
549-
550- auto ngk_pointer = psi.get_ngk_pointer ();
551548 // / wrap hpsi into lambda function, Matrix \times blockvector
552549 // hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
553550 auto hpsi_func = [hm, ngk_pointer](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
554551 ModuleBase::timer::tick (" David" , " hpsi_func" );
555552
556553 // Convert pointer of psi_in to a psi::Psi object
557- auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, ngk_pointer );
554+ auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1 , nvec, ld_psi, nullptr );
558555
559556 psi::Range bands_range (true , 0 , 0 , nvec - 1 );
560557
0 commit comments