@@ -38,9 +38,10 @@ void WFInit<T, Device>::prepare_init(Structure_Factor* p_sf,
3838#endif
3939 pseudopot_cell_vnl* p_ppcell)
4040{
41- if (!this ->use_psiinitializer ) {
41+ if (!this ->use_psiinitializer )
42+ {
4243 return ;
43- }
44+ }
4445 // under restriction of C++11, std::unique_ptr can not be allocate via std::make_unique
4546 // use new instead, but will cause asymmetric allocation and deallocation, in literal aspect
4647 ModuleBase::timer::tick (" WFInit" , " prepare_init" );
@@ -136,7 +137,9 @@ void WFInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
136137template <typename T, typename Device>
137138void WFInit<T, Device>::make_table(const int nks, Structure_Factor* p_sf)
138139{
139- if (this ->use_psiinitializer ) {} // do not need to do anything because the interpolate table is unchanged
140+ if (this ->use_psiinitializer )
141+ {
142+ } // do not need to do anything because the interpolate table is unchanged
140143 else // old initialization method, used in EXX calculation
141144 {
142145 this ->p_wf ->init_after_vc (nks); // reallocate wanf2, the planewave expansion of lcao
@@ -150,95 +153,118 @@ void WFInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
150153 hamilt::Hamilt<T, Device>* p_hamilt,
151154 std::ofstream& ofs_running)
152155{
153- if (!this ->use_psiinitializer ) { return ; }
154156 ModuleBase::timer::tick (" WFInit" , " initialize_psi" );
155- // if psig is not allocated before, allocate it
156- if (!this ->psi_init ->psig_use_count ()) { this ->psi_init ->allocate (/* psig_only=*/ true ); }
157-
158- // loop over kpoints, make it possible to only allocate memory for psig at the only one kpt
159- // like (1, nbands, npwx), in which npwx is the maximal npw of all kpoints
160- for (int ik = 0 ; ik < this ->pw_wfc ->nks ; ik++)
157+
158+ if (this ->use_psiinitializer )
161159 {
162- // ! Fix the wavefunction to initialize at given kpoint
163- psi->fix_k (ik);
160+ // if psig is not allocated before, allocate it
161+ if (!this ->psi_init ->psig_use_count ())
162+ {
163+ this ->psi_init ->allocate (/* psig_only=*/ true );
164+ }
164165
165- // ! Update Hamiltonian from other kpoint to the given one
166- p_hamilt->updateHk (ik);
166+ // loop over kpoints, make it possible to only allocate memory for psig at the only one kpt
167+ // like (1, nbands, npwx), in which npwx is the maximal npw of all kpoints
168+ for (int ik = 0 ; ik < this ->pw_wfc ->nks ; ik++)
169+ {
170+ // ! Fix the wavefunction to initialize at given kpoint
171+ psi->fix_k (ik);
167172
168- // ! Project atomic orbitals on |k+G> planewave basis, where k is wavevector of kpoint
169- // ! and G is wavevector of the peroiodic part of the Bloch function
170- this ->psi_init ->proj_ao_onkG (ik);
173+ // ! Update Hamiltonian from other kpoint to the given one
174+ p_hamilt->updateHk (ik);
171175
172- // ! psi_initializer manages memory of psig with shared pointer,
173- // ! its access to use is shared here via weak pointer
174- // ! therefore once the psi_initializer is destructed, psig will be destructed, too
175- // ! this way, we can avoid memory leak and undefined behavior
176- std::weak_ptr<psi::Psi<T, Device>> psig = this ->psi_init ->share_psig ();
176+ // ! Project atomic orbitals on |k+G> planewave basis, where k is wavevector of kpoint
177+ // ! and G is wavevector of the peroiodic part of the Bloch function
178+ this ->psi_init ->proj_ao_onkG (ik);
177179
178- if (psig.expired ())
179- {
180- ModuleBase::WARNING_QUIT (" WFInit::initialize_psi" , " psig lifetime is expired" );
181- }
180+ // ! psi_initializer manages memory of psig with shared pointer,
181+ // ! its access to use is shared here via weak pointer
182+ // ! therefore once the psi_initializer is destructed, psig will be destructed, too
183+ // ! this way, we can avoid memory leak and undefined behavior
184+ std::weak_ptr<psi::Psi<T, Device>> psig = this ->psi_init ->share_psig ();
182185
183- // ! to use psig, we need to lock it to get a shared pointer version,
184- // ! then switch kpoint of psig to the given one
185- auto psig_ = psig.lock ();
186- // CHANGE LOG: if not lcaoinpw, the psig will only be used in psi-initialization
187- // so we can only allocate memory for one kpoint with the maximal number of pw
188- // over all kpoints, then the memory space will be always enough. Then for each
189- // kpoint, the psig is calculated in an overwrite manner.
190- const int ik_psig = (psig_->get_nk () == 1 ) ? 0 : ik;
191- psig_->fix_k (ik_psig);
192-
193- std::vector<typename GetTypeReal<T>::type> etatom (psig_->get_nbands (), 0.0 );
194-
195- // then adjust dimension from psig to psi
196- // either by matrix-multiplication or by copying-discarding
197- if (this ->psi_init ->method () != " random" )
198- {
199- // lcaoinpw and pw share the same esolver. In the future, we will have different esolver
200- if (((this ->ks_solver == " cg" ) || (this ->ks_solver == " lapack" )) && (this ->basis_type == " pw" ))
186+ if (psig.expired ())
201187 {
202- // the following function is only run serially, to be improved
203- hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init (p_hamilt,
204- psig_->get_pointer (),
205- psig_->get_nbands (),
206- psig_->get_nbasis (),
207- *(kspw_psi),
208- etatom.data ());
209- continue ;
188+ ModuleBase::WARNING_QUIT (" WFInit::initialize_psi" , " psig lifetime is expired" );
210189 }
211- else if ((this ->ks_solver == " lapack" ) && (this ->basis_type == " lcao_in_pw" ))
190+
191+ // ! to use psig, we need to lock it to get a shared pointer version,
192+ // ! then switch kpoint of psig to the given one
193+ auto psig_ = psig.lock ();
194+ // CHANGE LOG: if not lcaoinpw, the psig will only be used in psi-initialization
195+ // so we can only allocate memory for one kpoint with the maximal number of pw
196+ // over all kpoints, then the memory space will be always enough. Then for each
197+ // kpoint, the psig is calculated in an overwrite manner.
198+ const int ik_psig = (psig_->get_nk () == 1 ) ? 0 : ik;
199+ psig_->fix_k (ik_psig);
200+
201+ std::vector<typename GetTypeReal<T>::type> etatom (psig_->get_nbands (), 0.0 );
202+
203+ // then adjust dimension from psig to psi
204+ // either by matrix-multiplication or by copying-discarding
205+ if (this ->psi_init ->method () != " random" )
212206 {
213- if (ik == 0 )
207+ // lcaoinpw and pw share the same esolver. In the future, we will have different esolver
208+ if (((this ->ks_solver == " cg" ) || (this ->ks_solver == " lapack" )) && (this ->basis_type == " pw" ))
209+ {
210+ // the following function is only run serially, to be improved
211+ hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init (p_hamilt,
212+ psig_->get_pointer (),
213+ psig_->get_nbands (),
214+ psig_->get_nbasis (),
215+ *(kspw_psi),
216+ etatom.data ());
217+ continue ;
218+ }
219+ else if ((this ->ks_solver == " lapack" ) && (this ->basis_type == " lcao_in_pw" ))
214220 {
215- ofs_running << " START WAVEFUNCTION: LCAO_IN_PW, psi initialization skipped " << std::endl;
221+ if (ik == 0 )
222+ {
223+ ofs_running << " START WAVEFUNCTION: LCAO_IN_PW, psi initialization skipped " << std::endl;
224+ }
225+ continue ;
216226 }
217- continue ;
227+ // else the case is davidson
218228 }
219- // else the case is davidson
220- }
221- else
222- {
223- if (this ->ks_solver == " cg" )
229+ else
224230 {
225- hsolver::DiagoIterAssist<T, Device>::diagH_subspace (p_hamilt, *(psig_), *(kspw_psi), etatom.data ());
226- continue ;
231+ if (this ->ks_solver == " cg" )
232+ {
233+ hsolver::DiagoIterAssist<T, Device>::diagH_subspace (p_hamilt, *(psig_), *(kspw_psi), etatom.data ());
234+ continue ;
235+ }
236+ // else the case is davidson
227237 }
228- // else the case is davidson
229- }
230238
231- // for the Davidson method, we just copy the wavefunction (partially)
232- for (int iband = 0 ; iband < kspw_psi->get_nbands (); iband++)
233- {
234- for (int ibasis = 0 ; ibasis < kspw_psi->get_nbasis (); ibasis++)
239+ // for the Davidson method, we just copy the wavefunction (partially)
240+ for (int iband = 0 ; iband < kspw_psi->get_nbands (); iband++)
235241 {
236- (*(kspw_psi))(iband, ibasis) = (*psig_)(iband, ibasis);
242+ for (int ibasis = 0 ; ibasis < kspw_psi->get_nbasis (); ibasis++)
243+ {
244+ (*(kspw_psi))(iband, ibasis) = (*psig_)(iband, ibasis);
245+ }
237246 }
247+ } // end k-point loop
248+
249+ if (this ->basis_type != " lcao_in_pw" )
250+ {
251+ this ->psi_init ->deallocate_psig ();
238252 }
239- } // end k-point loop
253+ }
254+ else
255+ {
256+ for (int ik = 0 ; ik < this ->pw_wfc ->nks ; ++ik)
257+ {
258+ // ! Update Hamiltonian from other kpoint to the given one
259+ p_hamilt->updateHk (ik);
260+
261+ // ! Fix the wavefunction to initialize at given kpoint
262+ kspw_psi->fix_k (ik);
240263
241- if (this ->basis_type != " lcao_in_pw" ) { this ->psi_init ->deallocate_psig (); }
264+ // / for psi init guess!!!!
265+ hamilt::diago_PAO_in_pw_k2 (this ->ctx , ik, *kspw_psi, this ->pw_wfc , this ->p_wf , p_hamilt);
266+ }
267+ }
242268
243269 ModuleBase::timer::tick (" WFInit" , " initialize_psi" );
244270}
0 commit comments