@@ -49,17 +49,9 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
4949 // delete Hamilt
5050 this ->deallocate_hamilt ();
5151
52- if (PARAM.inp .device == " gpu" || PARAM.inp .precision == " single" )
53- {
54- delete this ->kspw_psi ;
55- }
56- if (PARAM.inp .precision == " single" )
57- {
58- delete this ->__kspw_psi ;
59- }
52+ // mohan add 2025-10-12
53+ this ->stp .clean ();
6054
61- delete this ->psi ;
62- delete this ->p_psi_init ;
6355}
6456
6557template <typename T, typename Device>
@@ -89,18 +81,7 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
8981 this ->locpp , this ->ppcell , this ->vsep_cell , this ->pw_wfc , this ->pw_rho ,
9082 this ->pw_rhod , this ->pw_big , this ->solvent , inp);
9183
92- // ! Allocate and initialize psi
93- this ->p_psi_init = new psi::PSIInit<T, Device>(inp.init_wfc ,
94- inp.ks_solver , inp.basis_type , GlobalV::MY_RANK, ucell,
95- this ->sf , this ->kv , this ->ppcell , *this ->pw_wfc );
96-
97- allocate_psi (this ->psi , this ->kv .get_nks (), this ->kv .ngk , PARAM.globalv .nbands_l , this ->pw_wfc ->npwk_max );
98-
99- this ->p_psi_init ->prepare_init (inp.pw_seed );
100-
101- this ->kspw_psi = inp.device == " gpu" || inp.precision == " single"
102- ? new psi::Psi<T, Device>(this ->psi [0 ])
103- : reinterpret_cast <psi::Psi<T, Device>*>(this ->psi );
84+ this ->stp .before_runner (ucell, this ->kv , this ->sf , *this ->pw_wfc , this ->ppcell , PARAM.inp );
10485
10586 ModuleBase::GlobalFunc::DONE (GlobalV::ofs_running, " INIT BASIS" );
10687
@@ -142,7 +123,7 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
142123
143124 this ->pw_wfc ->collect_local_pw (PARAM.inp .erf_ecut , PARAM.inp .erf_height , PARAM.inp .erf_sigma );
144125
145- this ->p_psi_init ->prepare_init (PARAM.inp .pw_seed );
126+ this ->stp . p_psi_init ->prepare_init (PARAM.inp .pw_seed );
146127 }
147128
148129 // ! Init Hamiltonian (cell changed)
@@ -156,14 +137,10 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
156137 // ! Setup potentials (local, non-local, sc, +U, DFT-1/2)
157138 pw::setup_pot (istep, ucell, this ->kv , this ->sf , this ->pelec , this ->Pgrid ,
158139 this ->chr , this ->locpp , this ->ppcell , this ->vsep_cell ,
159- this ->kspw_psi , this ->p_hamilt , this ->pw_wfc , this ->pw_rhod , PARAM.inp );
140+ this ->stp . psi_t , this ->p_hamilt , this ->pw_wfc , this ->pw_rhod , PARAM.inp );
160141
161- // ! Initialize wave functions
162- if (!this ->already_initpsi )
163- {
164- this ->p_psi_init ->initialize_psi (this ->psi , this ->kspw_psi , this ->p_hamilt , GlobalV::ofs_running);
165- this ->already_initpsi = true ;
166- }
142+
143+ this ->stp .init (this ->p_hamilt );
167144
168145 // ! Exx calculations
169146 if (PARAM.inp .calculation == " scf" || PARAM.inp .calculation == " relax"
@@ -173,7 +150,7 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
173150 {
174151 auto hamilt_pw = reinterpret_cast <hamilt::HamiltPW<T, Device>*>(this ->p_hamilt );
175152 hamilt_pw->set_exx_helper (exx_helper);
176- exx_helper.set_psi (kspw_psi );
153+ exx_helper.set_psi (this -> stp . psi_t );
177154 }
178155 }
179156
@@ -202,7 +179,7 @@ void ESolver_KS_PW<T, Device>::iter_init(UnitCell& ucell, const int istep, const
202179 // new DFT+U method will calculate energy when evaluating the Hamiltonian
203180 if (dftu->omc != 2 )
204181 {
205- dftu->cal_occ_pw (iter, this ->kspw_psi , this ->pelec ->wg , ucell, PARAM.inp .mixing_beta );
182+ dftu->cal_occ_pw (iter, this ->stp . psi_t , this ->pelec ->wg , ucell, PARAM.inp .mixing_beta );
206183 }
207184 dftu->output (ucell);
208185 }
@@ -271,7 +248,7 @@ void ESolver_KS_PW<T, Device>::hamilt2rho_single(UnitCell& ucell, const int iste
271248 PARAM.inp .use_k_continuity );
272249
273250 hsolver_pw_obj.solve (this ->p_hamilt ,
274- this ->kspw_psi [0 ],
251+ this ->stp . psi_t [0 ],
275252 this ->pelec ,
276253 this ->pelec ->ekb .c ,
277254 GlobalV::RANK_IN_POOL,
@@ -316,7 +293,7 @@ void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int
316293 // Related to EXX
317294 if (GlobalC::exx_info.info_global .cal_exx && !exx_helper.op_exx ->first_iter )
318295 {
319- this ->pelec ->set_exx (exx_helper.cal_exx_energy (kspw_psi ));
296+ this ->pelec ->set_exx (exx_helper.cal_exx_energy (this -> stp . psi_t ));
320297 }
321298
322299 // deband is calculated from "output" charge density
@@ -347,12 +324,12 @@ void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int
347324 double dexx = 0.0 ;
348325 if (PARAM.inp .exx_thr_type == " energy" )
349326 {
350- dexx = exx_helper.cal_exx_energy (this ->kspw_psi );
327+ dexx = exx_helper.cal_exx_energy (this ->stp . psi_t );
351328 }
352- exx_helper.set_psi (this ->kspw_psi );
329+ exx_helper.set_psi (this ->stp . psi_t );
353330 if (PARAM.inp .exx_thr_type == " energy" )
354331 {
355- dexx -= exx_helper.cal_exx_energy (this ->kspw_psi );
332+ dexx -= exx_helper.cal_exx_energy (this ->stp . psi_t );
356333 // std::cout << "dexx = " << dexx << std::endl;
357334 }
358335 bool conv_ene = std::abs (dexx) < PARAM.inp .exx_ene_thr ;
@@ -373,7 +350,7 @@ void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int
373350 }
374351 else
375352 {
376- exx_helper.set_psi (this ->kspw_psi );
353+ exx_helper.set_psi (this ->stp . psi_t );
377354 }
378355 }
379356
@@ -394,7 +371,7 @@ void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int
394371 }
395372
396373 // the output quantities
397- ModuleIO::ctrl_iter_pw (istep, iter, conv_esolver, this ->psi ,
374+ ModuleIO::ctrl_iter_pw (istep, iter, conv_esolver, this ->stp . psi_cpu ,
398375 this ->kv , this ->pw_wfc , PARAM.inp );
399376}
400377
@@ -409,24 +386,16 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
409386 // sunliang 2025-04-10
410387 if (PARAM.inp .out_elf [0 ] > 0 )
411388 {
412- this ->ESolver_KS <T, Device>::psi = new psi::Psi<T>(this ->psi [0 ]);
389+ this ->ESolver_KS <T, Device>::psi = new psi::Psi<T>(this ->stp . psi_cpu [0 ]);
413390 }
414391
415392 // Call 'after_scf' of ESolver_KS
416393 ESolver_KS<T, Device>::after_scf (ucell, istep, conv_esolver);
417394
418- // Transfer data from GPU to CPU in pw basis
419- if (this ->device == base_device::GpuDevice)
420- {
421- castmem_2d_d2h_op ()(this ->psi [0 ].get_pointer () - this ->psi [0 ].get_psi_bias (),
422- this ->kspw_psi [0 ].get_pointer () - this ->kspw_psi [0 ].get_psi_bias (),
423- this ->psi [0 ].size ());
424- }
425-
426395 // Output quantities
427396 ModuleIO::ctrl_scf_pw<T, Device>(istep, ucell, this ->pelec , this ->chr , this ->kv , this ->pw_wfc ,
428- this ->pw_rho , this ->pw_rhod , this ->pw_big , this ->psi , this -> kspw_psi ,
429- this ->__kspw_psi , this ->ctx , this ->Pgrid , PARAM.inp );
397+ this ->pw_rho , this ->pw_rhod , this ->pw_big , this ->stp ,
398+ this ->ctx , this ->device , this ->Pgrid , PARAM.inp );
430399
431400 ModuleBase::timer::tick (" ESolver_KS_PW" , " after_scf" );
432401}
@@ -442,39 +411,25 @@ void ESolver_KS_PW<T, Device>::cal_force(UnitCell& ucell, ModuleBase::matrix& fo
442411{
443412 Forces<double , Device> ff (ucell.nat );
444413
445- if (this ->__kspw_psi != nullptr && PARAM.inp .precision == " single" )
446- {
447- delete reinterpret_cast <psi::Psi<std::complex <double >, Device>*>(this ->__kspw_psi );
448- }
449-
450- // Refresh __kspw_psi
451- this ->__kspw_psi = PARAM.inp .precision == " single"
452- ? new psi::Psi<std::complex <double >, Device>(this ->kspw_psi [0 ])
453- : reinterpret_cast <psi::Psi<std::complex <double >, Device>*>(this ->kspw_psi );
414+ // mohan add 2025-10-12
415+ this ->stp .update_psi_d ();
454416
455417 // Calculate forces
456418 ff.cal_force (ucell, force, *this ->pelec , this ->pw_rhod , &ucell.symm ,
457419 &this ->sf , this ->solvent , &this ->locpp , &this ->ppcell ,
458- &this ->kv , this ->pw_wfc , this ->__kspw_psi );
420+ &this ->kv , this ->pw_wfc , this ->stp . psi_d );
459421}
460422
461423template <typename T, typename Device>
462424void ESolver_KS_PW<T, Device>::cal_stress(UnitCell& ucell, ModuleBase::matrix& stress)
463425{
464426 Stress_PW<double , Device> ss (this ->pelec );
465427
466- if (this ->__kspw_psi != nullptr && PARAM.inp .precision == " single" )
467- {
468- delete reinterpret_cast <psi::Psi<std::complex <double >, Device>*>(this ->__kspw_psi );
469- }
470-
471- // Refresh __kspw_psi
472- this ->__kspw_psi = PARAM.inp .precision == " single"
473- ? new psi::Psi<std::complex <double >, Device>(this ->kspw_psi [0 ])
474- : reinterpret_cast <psi::Psi<std::complex <double >, Device>*>(this ->kspw_psi );
428+ // mohan add 2025-10-12
429+ this ->stp .update_psi_d ();
475430
476431 ss.cal_stress (stress, ucell, this ->locpp , this ->ppcell , this ->pw_rhod ,
477- &ucell.symm , &this ->sf , &this ->kv , this ->pw_wfc , this ->__kspw_psi );
432+ &ucell.symm , &this ->sf , &this ->kv , this ->pw_wfc , this ->stp . psi_d );
478433
479434 // external stress
480435 double unit_transform = 0.0 ;
@@ -492,9 +447,8 @@ void ESolver_KS_PW<T, Device>::after_all_runners(UnitCell& ucell)
492447 ESolver_KS<T, Device>::after_all_runners (ucell);
493448
494449 ModuleIO::ctrl_runner_pw<T, Device>(ucell, this ->pelec , this ->pw_wfc ,
495- this ->pw_rho , this ->pw_rhod , this ->chr , this ->kv , this ->psi ,
496- this ->kspw_psi , this ->__kspw_psi , this ->sf ,
497- this ->ppcell , this ->solvent , this ->ctx , this ->Pgrid , PARAM.inp );
450+ this ->pw_rho , this ->pw_rhod , this ->chr , this ->kv , this ->stp ,
451+ this ->sf , this ->ppcell , this ->solvent , this ->ctx , this ->Pgrid , PARAM.inp );
498452
499453 elecstate::teardown_estate_pw<T, Device>(this ->pelec , this ->vsep_cell );
500454
0 commit comments