44#include " module_base/timer.h"
55#include " module_base/tool_quit.h"
66#include " module_hsolver/diago_iter_assist.h"
7+ #include " module_parameter/parameter.h"
78#include " module_psi/psi_initializer_atomic.h"
89#include " module_psi/psi_initializer_atomic_random.h"
910#include " module_psi/psi_initializer_nao.h"
@@ -38,9 +39,10 @@ void WFInit<T, Device>::prepare_init(Structure_Factor* p_sf,
3839#endif
3940 pseudopot_cell_vnl* p_ppcell)
4041{
41- if (!this ->use_psiinitializer ) {
42+ if (!this ->use_psiinitializer )
43+ {
4244 return ;
43- }
45+ }
4446 // under restriction of C++11, std::unique_ptr can not be allocate via std::make_unique
4547 // use new instead, but will cause asymmetric allocation and deallocation, in literal aspect
4648 ModuleBase::timer::tick (" WFInit" , " prepare_init" );
@@ -136,7 +138,9 @@ void WFInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
136138template <typename T, typename Device>
137139void WFInit<T, Device>::make_table(const int nks, Structure_Factor* p_sf)
138140{
139- if (this ->use_psiinitializer ) {} // do not need to do anything because the interpolate table is unchanged
141+ if (this ->use_psiinitializer )
142+ {
143+ } // do not need to do anything because the interpolate table is unchanged
140144 else // old initialization method, used in EXX calculation
141145 {
142146 this ->p_wf ->init_after_vc (nks); // reallocate wanf2, the planewave expansion of lcao
@@ -150,95 +154,121 @@ void WFInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
150154 hamilt::Hamilt<T, Device>* p_hamilt,
151155 std::ofstream& ofs_running)
152156{
153- if (!this ->use_psiinitializer ) { return ; }
154157 ModuleBase::timer::tick (" WFInit" , " initialize_psi" );
155- // if psig is not allocated before, allocate it
156- if (!this ->psi_init ->psig_use_count ()) { this ->psi_init ->allocate (/* psig_only=*/ true ); }
157-
158- // loop over kpoints, make it possible to only allocate memory for psig at the only one kpt
159- // like (1, nbands, npwx), in which npwx is the maximal npw of all kpoints
160- for (int ik = 0 ; ik < this ->pw_wfc ->nks ; ik++)
158+
159+ if (PARAM.inp .psi_initializer )
161160 {
162- // ! Fix the wavefunction to initialize at given kpoint
163- psi->fix_k (ik);
161+ // if psig is not allocated before, allocate it
162+ if (!this ->psi_init ->psig_use_count ())
163+ {
164+ this ->psi_init ->allocate (/* psig_only=*/ true );
165+ }
164166
165- // ! Update Hamiltonian from other kpoint to the given one
166- p_hamilt->updateHk (ik);
167+ // loop over kpoints, make it possible to only allocate memory for psig at the only one kpt
168+ // like (1, nbands, npwx), in which npwx is the maximal npw of all kpoints
169+ for (int ik = 0 ; ik < this ->pw_wfc ->nks ; ik++)
170+ {
171+ // ! Fix the wavefunction to initialize at given kpoint
172+ psi->fix_k (ik);
167173
168- // ! Project atomic orbitals on |k+G> planewave basis, where k is wavevector of kpoint
169- // ! and G is wavevector of the peroiodic part of the Bloch function
170- this ->psi_init ->proj_ao_onkG (ik);
174+ // ! Update Hamiltonian from other kpoint to the given one
175+ p_hamilt->updateHk (ik);
171176
172- // ! psi_initializer manages memory of psig with shared pointer,
173- // ! its access to use is shared here via weak pointer
174- // ! therefore once the psi_initializer is destructed, psig will be destructed, too
175- // ! this way, we can avoid memory leak and undefined behavior
176- std::weak_ptr<psi::Psi<T, Device>> psig = this ->psi_init ->share_psig ();
177+ // ! Project atomic orbitals on |k+G> planewave basis, where k is wavevector of kpoint
178+ // ! and G is wavevector of the peroiodic part of the Bloch function
179+ this ->psi_init ->proj_ao_onkG (ik);
177180
178- if (psig.expired ())
179- {
180- ModuleBase::WARNING_QUIT (" WFInit::initialize_psi" , " psig lifetime is expired" );
181- }
181+ // ! psi_initializer manages memory of psig with shared pointer,
182+ // ! its access to use is shared here via weak pointer
183+ // ! therefore once the psi_initializer is destructed, psig will be destructed, too
184+ // ! this way, we can avoid memory leak and undefined behavior
185+ std::weak_ptr<psi::Psi<T, Device>> psig = this ->psi_init ->share_psig ();
182186
183- // ! to use psig, we need to lock it to get a shared pointer version,
184- // ! then switch kpoint of psig to the given one
185- auto psig_ = psig.lock ();
186- // CHANGE LOG: if not lcaoinpw, the psig will only be used in psi-initialization
187- // so we can only allocate memory for one kpoint with the maximal number of pw
188- // over all kpoints, then the memory space will be always enough. Then for each
189- // kpoint, the psig is calculated in an overwrite manner.
190- const int ik_psig = (psig_->get_nk () == 1 ) ? 0 : ik;
191- psig_->fix_k (ik_psig);
192-
193- std::vector<typename GetTypeReal<T>::type> etatom (psig_->get_nbands (), 0.0 );
194-
195- // then adjust dimension from psig to psi
196- // either by matrix-multiplication or by copying-discarding
197- if (this ->psi_init ->method () != " random" )
198- {
199- // lcaoinpw and pw share the same esolver. In the future, we will have different esolver
200- if (((this ->ks_solver == " cg" ) || (this ->ks_solver == " lapack" )) && (this ->basis_type == " pw" ))
187+ if (psig.expired ())
201188 {
202- // the following function is only run serially, to be improved
203- hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init (p_hamilt,
204- psig_->get_pointer (),
205- psig_->get_nbands (),
206- psig_->get_nbasis (),
207- *(kspw_psi),
208- etatom.data ());
209- continue ;
189+ ModuleBase::WARNING_QUIT (" WFInit::initialize_psi" , " psig lifetime is expired" );
210190 }
211- else if ((this ->ks_solver == " lapack" ) && (this ->basis_type == " lcao_in_pw" ))
191+
192+ // ! to use psig, we need to lock it to get a shared pointer version,
193+ // ! then switch kpoint of psig to the given one
194+ auto psig_ = psig.lock ();
195+ // CHANGE LOG: if not lcaoinpw, the psig will only be used in psi-initialization
196+ // so we can only allocate memory for one kpoint with the maximal number of pw
197+ // over all kpoints, then the memory space will be always enough. Then for each
198+ // kpoint, the psig is calculated in an overwrite manner.
199+ const int ik_psig = (psig_->get_nk () == 1 ) ? 0 : ik;
200+ psig_->fix_k (ik_psig);
201+
202+ std::vector<typename GetTypeReal<T>::type> etatom (psig_->get_nbands (), 0.0 );
203+
204+ // then adjust dimension from psig to psi
205+ // either by matrix-multiplication or by copying-discarding
206+ if (this ->psi_init ->method () != " random" )
212207 {
213- if (ik == 0 )
208+ // lcaoinpw and pw share the same esolver. In the future, we will have different esolver
209+ if (((this ->ks_solver == " cg" ) || (this ->ks_solver == " lapack" )) && (this ->basis_type == " pw" ))
210+ {
211+ // the following function is only run serially, to be improved
212+ hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init (p_hamilt,
213+ psig_->get_pointer (),
214+ psig_->get_nbands (),
215+ psig_->get_nbasis (),
216+ *(kspw_psi),
217+ etatom.data ());
218+ continue ;
219+ }
220+ else if ((this ->ks_solver == " lapack" ) && (this ->basis_type == " lcao_in_pw" ))
214221 {
215- ofs_running << " START WAVEFUNCTION: LCAO_IN_PW, psi initialization skipped " << std::endl;
222+ if (ik == 0 )
223+ {
224+ ofs_running << " START WAVEFUNCTION: LCAO_IN_PW, psi initialization skipped " << std::endl;
225+ }
226+ continue ;
216227 }
217- continue ;
228+ // else the case is davidson
218229 }
219- // else the case is davidson
220- }
221- else
222- {
223- if (this ->ks_solver == " cg" )
230+ else
224231 {
225- hsolver::DiagoIterAssist<T, Device>::diagH_subspace (p_hamilt, *(psig_), *(kspw_psi), etatom.data ());
226- continue ;
232+ if (this ->ks_solver == " cg" )
233+ {
234+ hsolver::DiagoIterAssist<T, Device>::diagH_subspace (p_hamilt, *(psig_), *(kspw_psi), etatom.data ());
235+ continue ;
236+ }
237+ // else the case is davidson
227238 }
228- // else the case is davidson
229- }
230239
231- // for the Davidson method, we just copy the wavefunction (partially)
232- for (int iband = 0 ; iband < kspw_psi->get_nbands (); iband++)
233- {
234- for (int ibasis = 0 ; ibasis < kspw_psi->get_nbasis (); ibasis++)
240+ // for the Davidson method, we just copy the wavefunction (partially)
241+ for (int iband = 0 ; iband < kspw_psi->get_nbands (); iband++)
235242 {
236- (*(kspw_psi))(iband, ibasis) = (*psig_)(iband, ibasis);
243+ for (int ibasis = 0 ; ibasis < kspw_psi->get_nbasis (); ibasis++)
244+ {
245+ (*(kspw_psi))(iband, ibasis) = (*psig_)(iband, ibasis);
246+ }
237247 }
248+ } // end k-point loop
249+
250+ if (this ->basis_type != " lcao_in_pw" )
251+ {
252+ this ->psi_init ->deallocate_psig ();
238253 }
239- } // end k-point loop
254+ }
255+ else
256+ {
257+ // if (PARAM.inp.basis_type == "pw")
258+ // {
259+ // for (int ik = 0; ik < this->pw_wfc->nks; ++ik)
260+ // {
261+ // //! Update Hamiltonian from other kpoint to the given one
262+ // p_hamilt->updateHk(ik);
263+
264+ // //! Fix the wavefunction to initialize at given kpoint
265+ // kspw_psi->fix_k(ik);
240266
241- if (this ->basis_type != " lcao_in_pw" ) { this ->psi_init ->deallocate_psig (); }
267+ // /// for psi init guess!!!!
268+ // hamilt::diago_PAO_in_pw_k2(this->ctx, ik, *kspw_psi, this->pw_wfc, this->p_wf, p_hamilt);
269+ // }
270+ // }
271+ }
242272
243273 ModuleBase::timer::tick (" WFInit" , " initialize_psi" );
244274}
0 commit comments