Skip to content

Commit 57bdc5a

Browse files
committed
change module_psi/wf_atomic.cpp
1 parent 7f0807d commit 57bdc5a

File tree

7 files changed

+91
-81
lines changed

7 files changed

+91
-81
lines changed

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,8 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
219219
this->kv.ngk.data(),
220220
this->pw_wfc->npwk_max,
221221
&this->sf,
222-
&this->ppcell);
222+
&this->ppcell,
223+
ucell);
223224

224225
this->kspw_psi = PARAM.inp.device == "gpu" || PARAM.inp.precision == "single"
225226
? new psi::Psi<T, Device>(this->psi[0])
@@ -257,7 +258,7 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
257258

258259
this->pw_wfc->collect_local_pw(PARAM.inp.erf_ecut, PARAM.inp.erf_height, PARAM.inp.erf_sigma);
259260

260-
this->p_wf_init->make_table(this->kv.get_nks(), &this->sf, &this->ppcell);
261+
this->p_wf_init->make_table(this->kv.get_nks(), &this->sf, &this->ppcell,ucell);
261262
}
262263
if (ucell.ionic_position_updated)
263264
{

source/module_psi/psi_init.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ void PSIInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
9393
const int* ngk,
9494
const int npwx,
9595
Structure_Factor* p_sf,
96-
pseudopot_cell_vnl* p_ppcell)
96+
pseudopot_cell_vnl* p_ppcell,
97+
const UnitCell& ucell)
9798
{
9899
// allocate memory for std::complex<double> datatype psi
99100
// New psi initializer in ABACUS, Developer's note:
@@ -126,7 +127,7 @@ void PSIInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
126127
// however, init_at_1 does not actually initialize the psi, instead, it is a
127128
// function to calculate a interpolate table saving overlap intergral or say
128129
// Spherical Bessel Transform of atomic orbitals.
129-
this->wf_old.init_at_1(p_sf, &p_ppcell->tab_at);
130+
this->wf_old.init_at_1(ucell,p_sf, &p_ppcell->tab_at);
130131
// similarly, wfcinit not really initialize any wavefunction, instead, it initialize
131132
// the mapping from ixy, the 1d flattened index of point on fft grid (x, y) plane,
132133
// to the index of "stick", composed of grid points.
@@ -135,15 +136,18 @@ void PSIInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
135136
}
136137

137138
template <typename T, typename Device>
138-
void PSIInit<T, Device>::make_table(const int nks, Structure_Factor* p_sf, pseudopot_cell_vnl* p_ppcell)
139+
void PSIInit<T, Device>::make_table(const int nks,
140+
Structure_Factor* p_sf,
141+
pseudopot_cell_vnl* p_ppcell,
142+
const UnitCell& ucell)
139143
{
140144
if (this->use_psiinitializer)
141145
{
142146
} // do not need to do anything because the interpolate table is unchanged
143147
else // old initialization method, used in EXX calculation
144148
{
145149
this->wf_old.init_after_vc(nks); // reallocate wanf2, the planewave expansion of lcao
146-
this->wf_old.init_at_1(p_sf, &p_ppcell->tab_at); // re-calculate tab_at, the overlap matrix between atomic pswfc and jlq
150+
this->wf_old.init_at_1(ucell,p_sf, &p_ppcell->tab_at); // re-calculate tab_at, the overlap matrix between atomic pswfc and jlq
147151
}
148152
}
149153

@@ -279,8 +283,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
279283
&this->wf_old,
280284
nlpp.tab_at,
281285
nlpp.lmaxkb,
282-
ucell.natomwfc,
283-
ucell.lmax_ppwf,
286+
ucell,
284287
p_hamilt);
285288
}
286289
}
@@ -297,8 +300,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
297300
&this->wf_old,
298301
nlpp.tab_at,
299302
nlpp.lmaxkb,
300-
ucell.natomwfc,
301-
ucell.lmax_ppwf,
303+
ucell,
302304
p_hamilt);
303305
}
304306
}

source/module_psi/psi_init.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,14 @@ class PSIInit
3636
const int* ngk, //< number of G-vectors in the current pool
3737
const int npwx, //< max number of plane waves of all pools
3838
Structure_Factor* p_sf, //< structure factor
39-
pseudopot_cell_vnl* p_ppcell); //< nonlocal pseudopotential
39+
pseudopot_cell_vnl* p_ppcell, //< nonlocal pseudopotential
40+
const UnitCell& ucell); //< unit cell
4041

4142
// make interpolate table
42-
void make_table(const int nks, Structure_Factor* p_sf, pseudopot_cell_vnl* p_ppcell);
43+
void make_table(const int nks,
44+
Structure_Factor* p_sf,
45+
pseudopot_cell_vnl* p_ppcell,
46+
const UnitCell& ucell);
4347

4448
//------------------------ only for psi_initializer --------------------
4549
/**

source/module_psi/wavefunc.cpp

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx,
194194
wavefunc* p_wf,
195195
const ModuleBase::realArray& tab_at,
196196
const int& lmaxkb,
197-
const int natomwfc,
198-
const int lmax_ppwf,
197+
const UnitCell& ucell,
199198
hamilt::Hamilt<std::complex<float>, base_device::DEVICE_CPU>* phm_in)
200199
{
201200
// TODO float func
@@ -209,8 +208,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx,
209208
wavefunc* p_wf,
210209
const ModuleBase::realArray& tab_at,
211210
const int& lmaxkb,
212-
const int natomwfc,
213-
const int lmax_ppwf,
211+
const UnitCell& ucell,
214212
hamilt::Hamilt<std::complex<double>, base_device::DEVICE_CPU>* phm_in)
215213
{
216214
ModuleBase::TITLE("wavefunc", "diago_PAO_in_pw_k2");
@@ -256,7 +254,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx,
256254
}
257255
}
258256
else if (PARAM.inp.init_wfc == "random"
259-
|| (PARAM.inp.init_wfc.substr(0, 6) == "atomic" && natomwfc == 0))
257+
|| (PARAM.inp.init_wfc.substr(0, 6) == "atomic" && ucell.natomwfc == 0))
260258
{
261259
p_wf->random(wvf.get_pointer(), 0, nbands, ik, wfc_basis);
262260

@@ -277,7 +275,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx,
277275
}
278276
else if (PARAM.inp.init_wfc.substr(0, 6) == "atomic")
279277
{
280-
const int starting_nw = p_wf->get_starting_nw(natomwfc);
278+
const int starting_nw = p_wf->get_starting_nw(ucell.natomwfc);
281279
if (starting_nw == 0)
282280
{
283281
return;
@@ -291,9 +289,10 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx,
291289
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw);
292290
}
293291

294-
p_wf->atomic_wfc(ik,
292+
p_wf->atomic_wfc(ucell,
293+
ik,
295294
current_nbasis,
296-
lmax_ppwf,
295+
ucell.lmax_ppwf,
297296
lmaxkb,
298297
wfc_basis,
299298
wfcatom,
@@ -302,7 +301,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx,
302301
PARAM.globalv.dq);
303302

304303
if (PARAM.inp.init_wfc == "atomic+random"
305-
&& starting_nw == natomwfc) // added by qianrui 2021-5-16
304+
&& starting_nw == ucell.natomwfc) // added by qianrui 2021-5-16
306305
{
307306
p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis);
308307
}
@@ -311,7 +310,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx,
311310
// If not enough atomic wfc are available, complete
312311
// with random wfcs
313312
//====================================================
314-
p_wf->random(wfcatom.c, natomwfc, nbands, ik, wfc_basis);
313+
p_wf->random(wfcatom.c, ucell.natomwfc, nbands, ik, wfc_basis);
315314

316315
// (7) Diago with cg method.
317316
// if(GlobalV::DIAGO_TYPE == "cg") xiaohui modify 2013-09-02
@@ -355,8 +354,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
355354
wavefunc* p_wf,
356355
const ModuleBase::realArray& tab_at,
357356
const int& lmaxkb,
358-
const int natomwfc,
359-
const int lmax_ppwf,
357+
const UnitCell& ucell,
360358
hamilt::Hamilt<std::complex<float>, base_device::DEVICE_GPU>* phm_in)
361359
{
362360
// TODO float
@@ -370,8 +368,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
370368
wavefunc* p_wf,
371369
const ModuleBase::realArray& tab_at,
372370
const int& lmaxkb,
373-
const int natomwfc,
374-
const int lmax_ppwf,
371+
const UnitCell& ucell,
375372
hamilt::Hamilt<std::complex<double>, base_device::DEVICE_GPU>* phm_in)
376373
{
377374
ModuleBase::TITLE("wavefunc", "diago_PAO_in_pw_k2");
@@ -399,7 +396,8 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
399396
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw);
400397
if (PARAM.inp.init_wfc.substr(0, 6) == "atomic")
401398
{
402-
p_wf->atomic_wfc(ik,
399+
p_wf->atomic_wfc(ucell,
400+
ik,
403401
current_nbasis,
404402
lmax_ppwf,
405403
lmaxkb,
@@ -409,7 +407,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
409407
PARAM.globalv.nqx,
410408
PARAM.globalv.dq);
411409
if (PARAM.inp.init_wfc == "atomic+random"
412-
&& starting_nw == natomwfc) // added by qianrui 2021-5-16
410+
&& starting_nw == ucell.natomwfc) // added by qianrui 2021-5-16
413411
{
414412
p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis);
415413
}
@@ -418,7 +416,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
418416
// If not enough atomic wfc are available, complete
419417
// with random wfcs
420418
//====================================================
421-
p_wf->random(wfcatom.c, natomwfc, nbands, ik, wfc_basis);
419+
p_wf->random(wfcatom.c, ucell.natomwfc, nbands, ik, wfc_basis);
422420
}
423421
else if (PARAM.inp.init_wfc == "random")
424422
{

source/module_psi/wavefunc.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@ void diago_PAO_in_pw_k2(const Device* ctx,
5757
wavefunc* p_wf,
5858
const ModuleBase::realArray& tab_at,
5959
const int& lmaxkb,
60-
const int natomwfc,
61-
const int lmax_ppwf,
60+
const UnitCell& ucell,
6261
hamilt::Hamilt<std::complex<FPTYPE>, Device>* phm_in = nullptr);
6362
} // namespace hamilt
6463

0 commit comments

Comments
 (0)