55#include " module_base/tool_quit.h"
66#include " module_hsolver/diago_iter_assist.h"
77#include " module_parameter/parameter.h"
8+ #include " module_psi/psi_initializer_random.h"
89#include " module_psi/psi_initializer_atomic.h"
910#include " module_psi/psi_initializer_atomic_random.h"
1011#include " module_psi/psi_initializer_nao.h"
1112#include " module_psi/psi_initializer_nao_random.h"
12- #include " module_psi/psi_initializer_random.h"
1313namespace psi
1414{
1515
@@ -46,27 +46,27 @@ void PSIInit<T, Device>::prepare_init(Structure_Factor* p_sf,
4646 ModuleBase::timer::tick (" PSIInit" , " prepare_init" );
4747 if ((this ->init_wfc .substr (0 , 6 ) == " atomic" ) && (p_ucell->natomwfc == 0 ))
4848 {
49- this ->psi_init = std::unique_ptr<psi_initializer <T, Device>>(new psi_initializer_random <T, Device>());
49+ this ->psi_init = std::unique_ptr<PsiInitializer <T, Device>>(new PsiInitializerRandom <T, Device>());
5050 }
5151 else if (this ->init_wfc == " atomic" )
5252 {
53- this ->psi_init = std::unique_ptr<psi_initializer <T, Device>>(new psi_initializer_atomic <T, Device>());
53+ this ->psi_init = std::unique_ptr<PsiInitializer <T, Device>>(new PsiInitializerAtomic <T, Device>());
5454 }
5555 else if (this ->init_wfc == " random" )
5656 {
57- this ->psi_init = std::unique_ptr<psi_initializer <T, Device>>(new psi_initializer_random <T, Device>());
57+ this ->psi_init = std::unique_ptr<PsiInitializer <T, Device>>(new PsiInitializerRandom <T, Device>());
5858 }
5959 else if (this ->init_wfc == " nao" )
6060 {
61- this ->psi_init = std::unique_ptr<psi_initializer <T, Device>>(new psi_initializer_nao <T, Device>());
61+ this ->psi_init = std::unique_ptr<PsiInitializer <T, Device>>(new PsiInitializerNAO <T, Device>());
6262 }
6363 else if (this ->init_wfc == " atomic+random" )
6464 {
65- this ->psi_init = std::unique_ptr<psi_initializer <T, Device>>(new psi_initializer_atomic_random <T, Device>());
65+ this ->psi_init = std::unique_ptr<PsiInitializer <T, Device>>(new PsiInitializerAtomicRandom <T, Device>());
6666 }
6767 else if (this ->init_wfc == " nao+random" )
6868 {
69- this ->psi_init = std::unique_ptr<psi_initializer <T, Device>>(new psi_initializer_nao_random <T, Device>());
69+ this ->psi_init = std::unique_ptr<PsiInitializer <T, Device>>(new PsiInitializerNAORandom <T, Device>());
7070 }
7171 else
7272 {
@@ -110,7 +110,7 @@ void PSIInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
110110 // is not ready yet.
111111 if (this ->use_psiinitializer ) // new method
112112 {
113- // psi_initializer drag initialization of pw wavefunction out of HSolver, make psi
113+ // PsiInitializer drag initialization of pw wavefunction out of HSolver, make psi
114114 // initialization decoupled with HSolver (diagonalization) procedure.
115115 // However, due to EXX is hard to maintain, we still use the old method for EXX.
116116 // LCAOINPW in version >= 3.5.0 uses this new method.
@@ -147,9 +147,10 @@ void PSIInit<T, Device>::make_table(const int nks, Structure_Factor* p_sf, pseud
147147 }
148148}
149149
150+ // in the following function, the psi on Device will be initialized with the CPU psi
150151template <typename T, typename Device>
151- void PSIInit<T, Device>::initialize_psi(Psi<std::complex <double >>* psi,
152- psi::Psi<T, Device>* kspw_psi,
152+ void PSIInit<T, Device>::initialize_psi(Psi<std::complex <double >>* psi, // the one always on CPU
153+ psi::Psi<T, Device>* kspw_psi, // the one may be on GPU. In CPU case, it is the same as psi
153154 hamilt::Hamilt<T, Device>* p_hamilt,
154155 const pseudopot_cell_vnl& nlpp,
155156 std::ofstream& ofs_running,
@@ -169,7 +170,8 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
169170 // like (1, nbands, npwx), in which npwx is the maximal npw of all kpoints
170171 for (int ik = 0 ; ik < this ->pw_wfc ->nks ; ik++)
171172 {
172- // ! Fix the wavefunction to initialize at given kpoint
173+ // ! Fix the wavefunction to initialize at given kpoint.
174+ // This will fix the kpoint for CPU case. For GPU, we should additionally call fix_k for kspw_psi
173175 psi->fix_k (ik);
174176
175177 // ! Update Hamiltonian from other kpoint to the given one
@@ -179,20 +181,20 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
179181 // ! and G is wavevector of the peroiodic part of the Bloch function
180182 this ->psi_init ->proj_ao_onkG (ik);
181183
182- // ! psi_initializer manages memory of psig with shared pointer,
184+ // ! PsiInitializer manages memory of psig with shared pointer,
183185 // ! its access to use is shared here via weak pointer
184- // ! therefore once the psi_initializer is destructed, psig will be destructed, too
186+ // ! therefore once the PsiInitializer is destructed, psig will be destructed, too
185187 // ! this way, we can avoid memory leak and undefined behavior
186- std::weak_ptr<psi::Psi<T, Device>> psig = this ->psi_init ->share_psig ();
187-
188- if (psig.expired ())
188+ // std::weak_ptr<psi::Psi<T, Device>> psig = this->psi_init->share_psig();
189+ psi::Psi<T, Device>* psig_ = this -> psi_init -> share_psig ();
190+ if (/* psig.expired()*/ psig_ == nullptr )
189191 {
190192 ModuleBase::WARNING_QUIT (" PSIInit::initialize_psi" , " psig lifetime is expired" );
191193 }
192194
193195 // ! to use psig, we need to lock it to get a shared pointer version,
194196 // ! then switch kpoint of psig to the given one
195- auto psig_ = psig.lock ();
197+ // auto psig_ = psig.lock();
196198 // CHANGE LOG: if not lcaoinpw, the psig will only be used in psi-initialization
197199 // so we can only allocate memory for one kpoint with the maximal number of pw
198200 // over all kpoints, then the memory space will be always enough. Then for each
@@ -210,6 +212,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
210212 if (((this ->ks_solver == " cg" ) || (this ->ks_solver == " lapack" )) && (this ->basis_type == " pw" ))
211213 {
212214 // the following function is only run serially, to be improved
215+ // For GPU: this psig_ should be on GPU before calling the following function
213216 hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init (p_hamilt,
214217 psig_->get_pointer (),
215218 psig_->get_nbands (),
@@ -218,6 +221,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
218221 etatom.data ());
219222 continue ;
220223 }
224+ // do nothing in LCAO_IN_PW case because psig is used to do transformation instead of initialization
221225 else if ((this ->ks_solver == " lapack" ) && (this ->basis_type == " lcao_in_pw" ))
222226 {
223227 if (ik == 0 )
@@ -239,6 +243,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
239243 }
240244
241245 // for the Davidson method, we just copy the wavefunction (partially)
246+ // For GPU: although this is simply the copy operation, if GPU present, this should be a data sending operation
242247 for (int iband = 0 ; iband < kspw_psi->get_nbands (); iband++)
243248 {
244249 for (int ibasis = 0 ; ibasis < kspw_psi->get_nbasis (); ibasis++)
@@ -248,7 +253,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
248253 }
249254 } // end k-point loop
250255
251- if (this ->basis_type != " lcao_in_pw" )
256+ if (this ->basis_type != " lcao_in_pw" ) // if not LCAO_IN_PW case, we can release the memory of psig after initailization is done.
252257 {
253258 this ->psi_init ->deallocate_psig ();
254259 }
0 commit comments