Skip to content

Commit e0202e0

Browse files
authored
Refactor: remove init_wfc&mem_saver&out_wfc_pw&out_wfc_r of wavefunc in abacus (#5557)
* remove wavefunc. init_wfc mem_saver out_wfc_pw out_wfc_r in abacus * replace WFInit by PSIInit
1 parent 54b044b commit e0202e0

File tree

13 files changed

+135
-145
lines changed

13 files changed

+135
-145
lines changed

source/module_esolver/esolver_ks.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,6 @@ ESolver_KS<T, Device>::ESolver_KS()
6767
///----------------------------------------------------------
6868
p_chgmix = new Charge_Mixing();
6969
p_chgmix->set_rhopw(this->pw_rho, this->pw_rhod);
70-
71-
///----------------------------------------------------------
72-
/// wavefunc
73-
///----------------------------------------------------------
74-
this->wf.init_wfc = PARAM.inp.init_wfc;
75-
this->wf.mem_saver = PARAM.inp.mem_saver;
76-
this->wf.out_wfc_pw = PARAM.inp.out_wfc_pw;
77-
this->wf.out_wfc_r = PARAM.inp.out_wfc_r;
7870
}
7971

8072
//------------------------------------------------------------------------------

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ void ESolver_KS_LCAO<TK, TR>::iter_init(const int istep, const int iter)
548548
// mohan move it outside 2011-01-13
549549
// first need to calculate the weight according to
550550
// electrons number.
551-
if (istep == 0 && this->wf.init_wfc == "file")
551+
if (istep == 0 && PARAM.inp.init_wfc == "file")
552552
{
553553
if (iter == 1)
554554
{

source/module_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ void ESolver_KS_LCAO_TDDFT::before_all_runners(const Input_para& inp, UnitCell&
7676

7777
void ESolver_KS_LCAO_TDDFT::hamilt2density_single(const int istep, const int iter, const double ethr)
7878
{
79-
if (wf.init_wfc == "file")
79+
if (PARAM.inp.init_wfc == "file")
8080
{
8181
if (istep >= 1)
8282
{
@@ -256,7 +256,7 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter)
256256
const int nlocal = PARAM.globalv.nlocal;
257257

258258
// store wfc and Hk laststep
259-
if (istep >= (wf.init_wfc == "file" ? 0 : 1) && this->conv_esolver)
259+
if (istep >= (PARAM.inp.init_wfc == "file" ? 0 : 1) && this->conv_esolver)
260260
{
261261
if (this->psi_laststep == nullptr)
262262
{
@@ -311,7 +311,7 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter)
311311
}
312312

313313
// calculate energy density matrix for tddft
314-
if (istep >= (wf.init_wfc == "file" ? 0 : 2) && module_tddft::Evolve_elec::td_edm == 0)
314+
if (istep >= (PARAM.inp.init_wfc == "file" ? 0 : 2) && module_tddft::Evolve_elec::td_edm == 0)
315315
{
316316
elecstate::cal_edm_tddft(this->pv, this->pelec, this->kv, this->p_hamilt);
317317
}

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,12 @@ void ESolver_KS_PW<T, Device>::before_all_runners(const Input_para& inp, UnitCel
181181
}
182182

183183
//! 7) prepare some parameters for electronic wave functions initilization
184-
this->p_wf_init = new psi::WFInit<T, Device>(PARAM.inp.init_wfc,
185-
PARAM.inp.ks_solver,
186-
PARAM.inp.basis_type,
187-
PARAM.inp.psi_initializer,
188-
&this->wf,
189-
this->pw_wfc);
184+
this->p_wf_init = new psi::PSIInit<T, Device>(PARAM.inp.init_wfc,
185+
PARAM.inp.ks_solver,
186+
PARAM.inp.basis_type,
187+
PARAM.inp.psi_initializer,
188+
&this->wf,
189+
this->pw_wfc);
190190
this->p_wf_init->prepare_init(&(this->sf),
191191
&ucell,
192192
1,
@@ -547,7 +547,7 @@ void ESolver_KS_PW<T, Device>::iter_finish(const int istep, int& iter)
547547
}
548548

549549
// 4) Print out electronic wavefunctions
550-
if (this->wf.out_wfc_pw == 1 || this->wf.out_wfc_pw == 2)
550+
if (PARAM.inp.out_wfc_pw == 1 || PARAM.inp.out_wfc_pw == 2)
551551
{
552552
std::stringstream ssw;
553553
ssw << PARAM.globalv.global_out_dir << "WAVEFUNC";
@@ -573,7 +573,7 @@ void ESolver_KS_PW<T, Device>::after_scf(const int istep)
573573
ESolver_KS<T, Device>::after_scf(istep);
574574

575575
// 3) output wavefunctions
576-
if (this->wf.out_wfc_pw == 1 || this->wf.out_wfc_pw == 2)
576+
if (PARAM.inp.out_wfc_pw == 1 || PARAM.inp.out_wfc_pw == 2)
577577
{
578578
std::stringstream ssw;
579579
ssw << PARAM.globalv.global_out_dir << "WAVEFUNC";
@@ -821,7 +821,7 @@ void ESolver_KS_PW<T, Device>::after_all_runners()
821821
}
822822

823823
//! 6) Print out electronic wave functions in real space
824-
if (this->wf.out_wfc_r == 1) // Peize Lin add 2021.11.21
824+
if (PARAM.inp.out_wfc_r == 1) // Peize Lin add 2021.11.21
825825
{
826826
ModuleIO::write_psi_r_1(this->psi[0], this->pw_wfc, "wfc_realspace", true, this->kv);
827827
}

source/module_esolver/esolver_ks_pw.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class ESolver_KS_PW : public ESolver_KS<T, Device>
5353
psi::Psi<std::complex<double>, base_device::DEVICE_CPU>* psi = nullptr;
5454

5555
// psi_initializer controller
56-
psi::WFInit<T, Device>* p_wf_init = nullptr;
56+
psi::PSIInit<T, Device>* p_wf_init = nullptr;
5757

5858
Device* ctx = {};
5959

source/module_esolver/esolver_sdft_pw.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ void ESolver_SDFT_PW<T, Device>::before_all_runners(const Input_para& inp, UnitC
5252
// 2) run "before_all_runners" in ESolver_KS
5353
ESolver_KS_PW<T, Device>::before_all_runners(inp, ucell);
5454

55-
// 9) initialize the stochastic wave functions
55+
// 3) initialize the stochastic wave functions
5656
this->stowf.init(&this->kv, this->pw_wfc->npwk_max);
5757
if (inp.nbands_sto != 0)
5858
{
@@ -75,7 +75,7 @@ void ESolver_SDFT_PW<T, Device>::before_all_runners(const Input_para& inp, UnitC
7575
}
7676
this->stowf.sync_chi0();
7777

78-
// 10) allocate spaces for \sqrt(f(H))|chi> and |\tilde{chi}>
78+
// 4) allocate spaces for \sqrt(f(H))|chi> and |\tilde{chi}>
7979
size_t size = stowf.chi0->size();
8080
this->stowf.shchi
8181
= new psi::Psi<T, Device>(this->kv.get_nks(), this->stowf.nchip_max, this->wf.npwx, this->kv.ngk.data());

source/module_esolver/lcao_others.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ void ESolver_KS_LCAO<TK, TR>::others(const int istep)
331331
this->pv,
332332
this->GG,
333333
PARAM.inp.out_wfc_pw,
334-
this->wf.out_wfc_r,
334+
PARAM.inp.out_wfc_r,
335335
this->kv,
336336
PARAM.inp.nelec,
337337
PARAM.inp.nbands_istate,
@@ -351,7 +351,7 @@ void ESolver_KS_LCAO<TK, TR>::others(const int istep)
351351
this->pv,
352352
this->GK,
353353
PARAM.inp.out_wfc_pw,
354-
this->wf.out_wfc_r,
354+
PARAM.inp.out_wfc_r,
355355
this->kv,
356356
PARAM.inp.nelec,
357357
PARAM.inp.nbands_istate,

source/module_hamilt_pw/hamilt_pwdft/wavefunc.cpp

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ psi::Psi<std::complex<double>>* wavefunc::allocate(const int nkstot, const int n
5252
const int nks2 = nks;
5353

5454
psi::Psi<std::complex<double>>* psi_out = nullptr;
55-
if (PARAM.inp.calculation == "nscf" && this->mem_saver == 1)
55+
if (PARAM.inp.calculation == "nscf" && PARAM.inp.mem_saver == 1)
5656
{
5757
// initial psi rather than evc
5858
psi_out = new psi::Psi<std::complex<double>>(1, PARAM.inp.nbands, npwx * PARAM.globalv.npol, ngk);
@@ -140,11 +140,11 @@ void wavefunc::wfcinit(psi::Psi<std::complex<double>>* psi_in, ModulePW::PW_Basi
140140

141141
int wavefunc::get_starting_nw() const
142142
{
143-
if (init_wfc == "file")
143+
if (PARAM.inp.init_wfc == "file")
144144
{
145145
return PARAM.inp.nbands;
146146
}
147-
else if (init_wfc.substr(0, 6) == "atomic")
147+
else if (PARAM.inp.init_wfc.substr(0, 6) == "atomic")
148148
{
149149
if (GlobalC::ucell.natomwfc >= PARAM.inp.nbands)
150150
{
@@ -164,7 +164,7 @@ int wavefunc::get_starting_nw() const
164164
}
165165
return std::max(GlobalC::ucell.natomwfc, PARAM.inp.nbands);
166166
}
167-
else if (init_wfc == "random")
167+
else if (PARAM.inp.init_wfc == "random")
168168
{
169169
if (PARAM.inp.test_wf)
170170
{
@@ -196,7 +196,7 @@ void diago_PAO_in_pw_k2(const int& ik,
196196
const int nbands = wvf.get_nbands();
197197
const int current_nbasis = wfc_basis->npwk[ik];
198198

199-
if (p_wf->init_wfc == "file")
199+
if (PARAM.inp.init_wfc == "file")
200200
{
201201
ModuleBase::ComplexMatrix wfcatom(nbands, nbasis);
202202
std::stringstream filename;
@@ -263,7 +263,7 @@ void diago_PAO_in_pw_k2(const int& ik,
263263
}
264264
*/
265265

266-
if (p_wf->init_wfc == "random" || (p_wf->init_wfc.substr(0, 6) == "atomic" && GlobalC::ucell.natomwfc == 0))
266+
if (PARAM.inp.init_wfc == "random" || (PARAM.inp.init_wfc.substr(0, 6) == "atomic" && GlobalC::ucell.natomwfc == 0))
267267
{
268268
p_wf->random(wvf.get_pointer(), 0, nbands, ik, wfc_basis);
269269

@@ -280,7 +280,7 @@ void diago_PAO_in_pw_k2(const int& ik,
280280
}
281281
}
282282
}
283-
else if (p_wf->init_wfc.substr(0, 6) == "atomic")
283+
else if (PARAM.inp.init_wfc.substr(0, 6) == "atomic")
284284
{
285285
ModuleBase::ComplexMatrix wfcatom(starting_nw, nbasis); // added by zhengdy-soc
286286
if (PARAM.inp.test_wf) {
@@ -296,7 +296,7 @@ void diago_PAO_in_pw_k2(const int& ik,
296296
PARAM.globalv.nqx,
297297
PARAM.globalv.dq);
298298

299-
if (p_wf->init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
299+
if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
300300
{
301301
p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis);
302302
}
@@ -355,7 +355,7 @@ void diago_PAO_in_pw_k2(const int& ik,
355355
const int nbands = wvf.get_nbands();
356356
const int current_nbasis = wfc_basis->npwk[ik];
357357

358-
if (p_wf->init_wfc == "file")
358+
if (PARAM.inp.init_wfc == "file")
359359
{
360360
ModuleBase::ComplexMatrix wfcatom(nbands, nbasis);
361361
std::stringstream filename;
@@ -420,7 +420,7 @@ void diago_PAO_in_pw_k2(const int& ik,
420420
assert(starting_nw > 0);
421421
std::vector<double> etatom(starting_nw, 0.0);
422422

423-
if (p_wf->init_wfc == "random" || (p_wf->init_wfc.substr(0, 6) == "atomic" && GlobalC::ucell.natomwfc == 0))
423+
if (PARAM.inp.init_wfc == "random" || (PARAM.inp.init_wfc.substr(0, 6) == "atomic" && GlobalC::ucell.natomwfc == 0))
424424
{
425425
p_wf->random(wvf.get_pointer(), 0, nbands, ik, wfc_basis);
426426
if (PARAM.inp.ks_solver == "cg") // xiaohui add 2013-09-02
@@ -436,7 +436,7 @@ void diago_PAO_in_pw_k2(const int& ik,
436436
}
437437
}
438438
}
439-
else if (p_wf->init_wfc.substr(0, 6) == "atomic")
439+
else if (PARAM.inp.init_wfc.substr(0, 6) == "atomic")
440440
{
441441
ModuleBase::ComplexMatrix wfcatom(starting_nw, nbasis); // added by zhengdy-soc
442442
if (PARAM.inp.test_wf)
@@ -453,7 +453,7 @@ void diago_PAO_in_pw_k2(const int& ik,
453453
PARAM.globalv.nqx,
454454
PARAM.globalv.dq);
455455

456-
if (p_wf->init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
456+
if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
457457
{
458458
p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis);
459459
}
@@ -534,7 +534,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
534534
int starting_nw = nbands;
535535

536536
ModuleBase::ComplexMatrix wfcatom(nbands, nbasis);
537-
if (p_wf->init_wfc == "file")
537+
if (PARAM.inp.init_wfc == "file")
538538
{
539539
std::stringstream filename;
540540
int ik_tot = K_Vectors::get_ik_global(ik, p_wf->nkstot);
@@ -550,7 +550,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
550550
if (PARAM.inp.test_wf)
551551
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw);
552552

553-
if (p_wf->init_wfc.substr(0, 6) == "atomic")
553+
if (PARAM.inp.init_wfc.substr(0, 6) == "atomic")
554554
{
555555
p_wf->atomic_wfc(ik,
556556
current_nbasis,
@@ -560,7 +560,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
560560
GlobalC::ppcell.tab_at,
561561
PARAM.globalv.nqx,
562562
PARAM.globalv.dq);
563-
if (p_wf->init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
563+
if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
564564
{
565565
p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis);
566566
}
@@ -571,7 +571,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
571571
//====================================================
572572
p_wf->random(wfcatom.c, GlobalC::ucell.natomwfc, nbands, ik, wfc_basis);
573573
}
574-
else if (p_wf->init_wfc == "random")
574+
else if (PARAM.inp.init_wfc == "random")
575575
{
576576
p_wf->random(wfcatom.c, 0, nbands, ik, wfc_basis);
577577
}
@@ -638,7 +638,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
638638
int starting_nw = nbands;
639639

640640
ModuleBase::ComplexMatrix wfcatom(nbands, nbasis);
641-
if (p_wf->init_wfc == "file")
641+
if (PARAM.inp.init_wfc == "file")
642642
{
643643
std::stringstream filename;
644644
int ik_tot = K_Vectors::get_ik_global(ik, p_wf->nkstot);
@@ -653,7 +653,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
653653
wfcatom.create(starting_nw, nbasis); // added by zhengdy-soc
654654
if (PARAM.inp.test_wf)
655655
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw);
656-
if (p_wf->init_wfc.substr(0, 6) == "atomic")
656+
if (PARAM.inp.init_wfc.substr(0, 6) == "atomic")
657657
{
658658
p_wf->atomic_wfc(ik,
659659
current_nbasis,
@@ -663,7 +663,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
663663
GlobalC::ppcell.tab_at,
664664
PARAM.globalv.nqx,
665665
PARAM.globalv.dq);
666-
if (p_wf->init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
666+
if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
667667
{
668668
p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis);
669669
}
@@ -674,7 +674,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
674674
//====================================================
675675
p_wf->random(wfcatom.c, GlobalC::ucell.natomwfc, nbands, ik, wfc_basis);
676676
}
677-
else if (p_wf->init_wfc == "random")
677+
else if (PARAM.inp.init_wfc == "random")
678678
{
679679
p_wf->random(wfcatom.c, 0, nbands, ik, wfc_basis);
680680
}

source/module_hamilt_pw/hamilt_pwdft/wavefunc.h

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,50 +10,44 @@
1010

1111
class wavefunc : public WF_atomic
1212
{
13-
public:
14-
13+
public:
1514
wavefunc();
1615
~wavefunc();
1716

1817
// allocate memory
1918
psi::Psi<std::complex<double>>* allocate(const int nkstot, const int nks, const int* ngk, const int npwx);
2019

21-
int out_wfc_pw = 0; //qianrui modify 2020-10-19
22-
int out_wfc_r = 0; // Peize Lin add 2021.11.21
20+
int nkstot = 0; // total number of k-points for all pools
2321

24-
// init_wfc : "random",or "atomic" or "file"
25-
std::string init_wfc;
26-
int nkstot = 0; // total number of k-points for all pools
27-
int mem_saver = 0; // 1: save evc when doing nscf calculation.
2822
void wfcinit(psi::Psi<std::complex<double>>* psi_in, ModulePW::PW_Basis_K* wfc_basis);
29-
int get_starting_nw(void)const;
3023

31-
void init_after_vc(const int nks); //LiuXh 20180515
32-
};
24+
int get_starting_nw(void) const;
3325

26+
void init_after_vc(const int nks); // LiuXh 20180515
27+
};
3428

3529
namespace hamilt
3630
{
3731

38-
void diago_PAO_in_pw_k2(const int &ik,
39-
psi::Psi<std::complex<float>> &wvf,
40-
ModulePW::PW_Basis_K *wfc_basis,
41-
wavefunc *p_wf,
42-
hamilt::Hamilt<std::complex<float>> *phm_in = nullptr);
43-
void diago_PAO_in_pw_k2(const int &ik,
44-
psi::Psi<std::complex<double>> &wvf,
45-
ModulePW::PW_Basis_K *wfc_basis,
46-
wavefunc *p_wf,
47-
hamilt::Hamilt<std::complex<double>> *phm_in = nullptr);
48-
void diago_PAO_in_pw_k2(const int &ik, ModuleBase::ComplexMatrix &wvf, wavefunc *p_wf);
32+
void diago_PAO_in_pw_k2(const int& ik,
33+
psi::Psi<std::complex<float>>& wvf,
34+
ModulePW::PW_Basis_K* wfc_basis,
35+
wavefunc* p_wf,
36+
hamilt::Hamilt<std::complex<float>>* phm_in = nullptr);
37+
void diago_PAO_in_pw_k2(const int& ik,
38+
psi::Psi<std::complex<double>>& wvf,
39+
ModulePW::PW_Basis_K* wfc_basis,
40+
wavefunc* p_wf,
41+
hamilt::Hamilt<std::complex<double>>* phm_in = nullptr);
42+
void diago_PAO_in_pw_k2(const int& ik, ModuleBase::ComplexMatrix& wvf, wavefunc* p_wf);
4943

5044
template <typename FPTYPE, typename Device>
51-
void diago_PAO_in_pw_k2(const Device *ctx,
52-
const int &ik,
53-
psi::Psi<std::complex<FPTYPE>, Device> &wvf,
54-
ModulePW::PW_Basis_K *wfc_basis,
55-
wavefunc *p_wf,
56-
hamilt::Hamilt<std::complex<FPTYPE>, Device> *phm_in = nullptr);
57-
}
58-
59-
#endif //wavefunc
45+
void diago_PAO_in_pw_k2(const Device* ctx,
46+
const int& ik,
47+
psi::Psi<std::complex<FPTYPE>, Device>& wvf,
48+
ModulePW::PW_Basis_K* wfc_basis,
49+
wavefunc* p_wf,
50+
hamilt::Hamilt<std::complex<FPTYPE>, Device>* phm_in = nullptr);
51+
} // namespace hamilt
52+
53+
#endif // wavefunc

0 commit comments

Comments
 (0)