1818#include " module_base/memory.h"
1919#include " module_base/module_device/device.h"
2020#include " module_elecstate/elecstate_pw.h"
21+ #include " module_elecstate/elecstate_pw_sdft.h"
2122#include " module_hamilt_general/module_vdw/vdw.h"
2223#include " module_hamilt_pw/hamilt_pwdft/elecond.h"
2324#include " module_hamilt_pw/hamilt_pwdft/hamilt_pw.h"
@@ -136,14 +137,29 @@ void ESolver_KS_PW<T, Device>::before_all_runners(const Input_para& inp, UnitCel
136137 // 3) initialize ElecState,
137138 if (this ->pelec == nullptr )
138139 {
139- this ->pelec = new elecstate::ElecStatePW<T, Device>(this ->pw_wfc ,
140- &(this ->chr ),
141- &(this ->kv ),
142- &ucell,
143- &GlobalC::ppcell,
144- this ->pw_rhod ,
145- this ->pw_rho ,
146- this ->pw_big );
140+ if (inp.esolver_type == " sdft" )
141+ {
142+ // ! SDFT only supports double precision currently
143+ this ->pelec = new elecstate::ElecStatePW_SDFT<std::complex <double >, Device>(this ->pw_wfc ,
144+ &(this ->chr ),
145+ &(this ->kv ),
146+ &ucell,
147+ &(GlobalC::ppcell),
148+ this ->pw_rhod ,
149+ this ->pw_rho ,
150+ this ->pw_big );
151+ }
152+ else
153+ {
154+ this ->pelec = new elecstate::ElecStatePW<T, Device>(this ->pw_wfc ,
155+ &(this ->chr ),
156+ &(this ->kv ),
157+ &ucell,
158+ &GlobalC::ppcell,
159+ this ->pw_rhod ,
160+ this ->pw_rho ,
161+ this ->pw_big );
162+ }
147163 }
148164
149165 // ! 4) inititlize the charge density.
@@ -165,12 +181,12 @@ void ESolver_KS_PW<T, Device>::before_all_runners(const Input_para& inp, UnitCel
165181 }
166182
167183 // ! 7) prepare some parameters for electronic wave functions initilization
168- this ->p_wf_init = new psi::WFInit <T, Device>(PARAM.inp .init_wfc ,
169- PARAM.inp .ks_solver ,
170- PARAM.inp .basis_type ,
171- PARAM.inp .psi_initializer ,
172- &this ->wf ,
173- this ->pw_wfc );
184+ this ->p_wf_init = new psi::PSIInit <T, Device>(PARAM.inp .init_wfc ,
185+ PARAM.inp .ks_solver ,
186+ PARAM.inp .basis_type ,
187+ PARAM.inp .psi_initializer ,
188+ &this ->wf ,
189+ this ->pw_wfc );
174190 this ->p_wf_init ->prepare_init (&(this ->sf ),
175191 &ucell,
176192 1 ,
@@ -180,8 +196,39 @@ void ESolver_KS_PW<T, Device>::before_all_runners(const Input_para& inp, UnitCel
180196#endif
181197 &GlobalC::ppcell);
182198
183- // ! 8) setup global classes
184- this ->Init_GlobalC (inp, ucell, GlobalC::ppcell);
199+ if (this ->psi != nullptr )
200+ {
201+ delete this ->psi ;
202+ }
203+
204+ // ! init pseudopotential
205+ GlobalC::ppcell.init (ucell.ntype , &this ->sf , this ->pw_wfc );
206+
207+ // ! initalize local pseudopotential
208+ GlobalC::ppcell.init_vloc (GlobalC::ppcell.vloc , this ->pw_rhod );
209+ ModuleBase::GlobalFunc::DONE (GlobalV::ofs_running, " LOCAL POTENTIAL" );
210+
211+ // ! Initalize non-local pseudopotential
212+ GlobalC::ppcell.init_vnl (ucell, this ->pw_rhod );
213+ ModuleBase::GlobalFunc::DONE (GlobalV::ofs_running, " NON-LOCAL POTENTIAL" );
214+
215+ // ! Allocate psi
216+ this ->p_wf_init ->allocate_psi (this ->psi ,
217+ this ->kv .get_nkstot (),
218+ this ->kv .get_nks (),
219+ this ->kv .ngk .data (),
220+ this ->pw_wfc ->npwk_max ,
221+ &(this ->sf ));
222+
223+ this ->kspw_psi = PARAM.inp .device == " gpu" || PARAM.inp .precision == " single"
224+ ? new psi::Psi<T, Device>(this ->psi [0 ])
225+ : reinterpret_cast <psi::Psi<T, Device>*>(this ->psi );
226+
227+ if (PARAM.inp .precision == " single" )
228+ {
229+ ModuleBase::Memory::record (" Psi_single" , sizeof (T) * this ->psi [0 ].size ());
230+ }
231+ ModuleBase::GlobalFunc::DONE (GlobalV::ofs_running, " INIT BASIS" );
185232
186233 // ! 9) setup occupations
187234 if (PARAM.inp .ocp )
@@ -500,7 +547,7 @@ void ESolver_KS_PW<T, Device>::iter_finish(const int istep, int& iter)
500547 }
501548
502549 // 4) Print out electronic wavefunctions
503- if (this -> wf . out_wfc_pw == 1 || this -> wf .out_wfc_pw == 2 )
550+ if (PARAM. inp . out_wfc_pw == 1 || PARAM. inp .out_wfc_pw == 2 )
504551 {
505552 std::stringstream ssw;
506553 ssw << PARAM.globalv .global_out_dir << " WAVEFUNC" ;
@@ -526,7 +573,7 @@ void ESolver_KS_PW<T, Device>::after_scf(const int istep)
526573 ESolver_KS<T, Device>::after_scf (istep);
527574
528575 // 3) output wavefunctions
529- if (this -> wf . out_wfc_pw == 1 || this -> wf .out_wfc_pw == 2 )
576+ if (PARAM. inp . out_wfc_pw == 1 || PARAM. inp .out_wfc_pw == 2 )
530577 {
531578 std::stringstream ssw;
532579 ssw << PARAM.globalv .global_out_dir << " WAVEFUNC" ;
@@ -774,7 +821,7 @@ void ESolver_KS_PW<T, Device>::after_all_runners()
774821 }
775822
776823 // ! 6) Print out electronic wave functions in real space
777- if (this -> wf .out_wfc_r == 1 ) // Peize Lin add 2021.11.21
824+ if (PARAM. inp .out_wfc_r == 1 ) // Peize Lin add 2021.11.21
778825 {
779826 ModuleIO::write_psi_r_1 (this ->psi [0 ], this ->pw_wfc , " wfc_realspace" , true , this ->kv );
780827 }
0 commit comments