Skip to content

Commit 17c48f1

Browse files
A-006Fisherd99
authored andcommitted
Refactor:Remove GlobalC::ucell in module_lr,module_psi (deepmodeling#5691)
* change ucell in module_lr * change ucell in module_psi/wavefunc.cpp * change module_psi/wf_atomic.cpp * change module_pwdft/structure_factor_k.cpp * fix bug in wavefunc
1 parent 3d7ac14 commit 17c48f1

File tree

14 files changed

+141
-112
lines changed

14 files changed

+141
-112
lines changed

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,8 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
219219
this->kv.ngk.data(),
220220
this->pw_wfc->npwk_max,
221221
&this->sf,
222-
&this->ppcell);
222+
&this->ppcell,
223+
ucell);
223224

224225
this->kspw_psi = PARAM.inp.device == "gpu" || PARAM.inp.precision == "single"
225226
? new psi::Psi<T, Device>(this->psi[0])
@@ -257,7 +258,7 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
257258

258259
this->pw_wfc->collect_local_pw(PARAM.inp.erf_ecut, PARAM.inp.erf_height, PARAM.inp.erf_sigma);
259260

260-
this->p_wf_init->make_table(this->kv.get_nks(), &this->sf, &this->ppcell);
261+
this->p_wf_init->make_table(this->kv.get_nks(), &this->sf, &this->ppcell,ucell);
261262
}
262263
if (ucell.ionic_position_updated)
263264
{
@@ -373,6 +374,7 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
373374
this->kspw_psi,
374375
this->p_hamilt,
375376
this->ppcell,
377+
ucell,
376378
GlobalV::ofs_running,
377379
this->already_initpsi);
378380

source/module_hamilt_pw/hamilt_pwdft/structure_factor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ void Structure_Factor::setup_structure_factor(const UnitCell* Ucell, const Modul
6161
ModuleBase::TITLE("PW_Basis","setup_structure_factor");
6262
ModuleBase::timer::tick("PW_Basis","setup_struc_factor");
6363
const std::complex<double> ci_tpi = ModuleBase::NEG_IMAG_UNIT * ModuleBase::TWO_PI;
64-
64+
this->ucell = Ucell;
6565
this->strucFac.create(Ucell->ntype, rho_basis->npw);
6666
ModuleBase::Memory::record("SF::strucFac", sizeof(std::complex<double>) * Ucell->ntype*rho_basis->npw);
6767

source/module_hamilt_pw/hamilt_pwdft/structure_factor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class Structure_Factor
5353
ModuleBase::Vector3<double> q);
5454

5555
private:
56+
const UnitCell* ucell;
5657
std::complex<float> * c_eigts1 = nullptr, * c_eigts2 = nullptr, * c_eigts3 = nullptr;
5758
std::complex<double> * z_eigts1 = nullptr, * z_eigts2 = nullptr, * z_eigts3 = nullptr;
5859
const ModulePW::PW_Basis* rho_basis = nullptr;

source/module_hamilt_pw/hamilt_pwdft/structure_factor_k.cpp

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ std::complex<double>* Structure_Factor::get_sk(const int ik,
1010
const ModulePW::PW_Basis_K* wfc_basis) const
1111
{
1212
ModuleBase::timer::tick("Structure_Factor", "get_sk");
13-
const double arg = (wfc_basis->kvec_c[ik] * GlobalC::ucell.atoms[it].tau[ia]) * ModuleBase::TWO_PI;
13+
const double arg = (wfc_basis->kvec_c[ik] * ucell->atoms[it].tau[ia]) * ModuleBase::TWO_PI;
1414
const std::complex<double> kphase = std::complex<double>(cos(arg), -sin(arg));
1515
const int npw = wfc_basis->npwk[ik];
1616
std::complex<double> *sk = new std::complex<double>[npw];
@@ -26,19 +26,22 @@ std::complex<double>* Structure_Factor::get_sk(const int ik,
2626
const int ixy = wfc_basis->is2fftixy[is];
2727
int ix = ixy / wfc_basis->fftny;
2828
int iy = ixy % wfc_basis->fftny;
29-
if (ix >= int(nx / 2) + 1) {
29+
if (ix >= int(nx / 2) + 1)
30+
{
3031
ix -= nx;
31-
}
32-
if (iy >= int(ny / 2) + 1) {
32+
}
33+
if (iy >= int(ny / 2) + 1)
34+
{
3335
iy -= ny;
34-
}
35-
if (iz >= int(nz / 2) + 1) {
36+
}
37+
if (iz >= int(nz / 2) + 1)
38+
{
3639
iz -= nz;
37-
}
40+
}
3841
ix += this->rho_basis->nx;
3942
iy += this->rho_basis->ny;
4043
iz += this->rho_basis->nz;
41-
const int iat = GlobalC::ucell.itia2iat(it, ia);
44+
const int iat = ucell->itia2iat(it, ia);
4245
sk[igl] = kphase * this->eigts1(iat, ix) * this->eigts2(iat, iy) * this->eigts3(iat, iz);
4346
}
4447
ModuleBase::timer::tick("Structure_Factor", "get_sk");
@@ -66,33 +69,33 @@ void Structure_Factor::get_sk(Device* ctx,
6669

6770
int iat = 0, _npw = wfc_basis->npwk[ik], eigts1_nc = this->eigts1.nc, eigts2_nc = this->eigts2.nc,
6871
eigts3_nc = this->eigts3.nc;
69-
int *igl2isz = nullptr, *is2fftixy = nullptr, *atom_na = nullptr, *h_atom_na = new int[GlobalC::ucell.ntype];
70-
FPTYPE *atom_tau = nullptr, *h_atom_tau = new FPTYPE[GlobalC::ucell.nat * 3], *kvec = wfc_basis->get_kvec_c_data<FPTYPE>();
72+
int *igl2isz = nullptr, *is2fftixy = nullptr, *atom_na = nullptr, *h_atom_na = new int[ucell->ntype];
73+
FPTYPE *atom_tau = nullptr, *h_atom_tau = new FPTYPE[ucell->nat * 3], *kvec = wfc_basis->get_kvec_c_data<FPTYPE>();
7174
std::complex<FPTYPE> *eigts1 = this->get_eigts1_data<FPTYPE>(), *eigts2 = this->get_eigts2_data<FPTYPE>(),
7275
*eigts3 = this->get_eigts3_data<FPTYPE>();
73-
for (int it = 0; it < GlobalC::ucell.ntype; it++)
76+
for (int it = 0; it < ucell->ntype; it++)
7477
{
75-
h_atom_na[it] = GlobalC::ucell.atoms[it].na;
78+
h_atom_na[it] = ucell->atoms[it].na;
7679
}
7780
#ifdef _OPENMP
7881
#pragma omp parallel for
7982
#endif
80-
for (int iat = 0; iat < GlobalC::ucell.nat; iat++)
83+
for (int iat = 0; iat < ucell->nat; iat++)
8184
{
82-
int it = GlobalC::ucell.iat2it[iat];
83-
int ia = GlobalC::ucell.iat2ia[iat];
84-
auto *tau = reinterpret_cast<double *>(GlobalC::ucell.atoms[it].tau.data());
85+
int it = ucell->iat2it[iat];
86+
int ia = ucell->iat2ia[iat];
87+
auto *tau = reinterpret_cast<double *>(ucell->atoms[it].tau.data());
8588
h_atom_tau[iat * 3 + 0] = static_cast<FPTYPE>(tau[ia * 3 + 0]);
8689
h_atom_tau[iat * 3 + 1] = static_cast<FPTYPE>(tau[ia * 3 + 1]);
8790
h_atom_tau[iat * 3 + 2] = static_cast<FPTYPE>(tau[ia * 3 + 2]);
8891
}
8992
if (device == base_device::GpuDevice)
9093
{
91-
resmem_int_op()(ctx, atom_na, GlobalC::ucell.ntype);
92-
syncmem_int_op()(ctx, cpu_ctx, atom_na, h_atom_na, GlobalC::ucell.ntype);
94+
resmem_int_op()(ctx, atom_na, ucell->ntype);
95+
syncmem_int_op()(ctx, cpu_ctx, atom_na, h_atom_na, ucell->ntype);
9396

94-
resmem_var_op()(ctx, atom_tau, GlobalC::ucell.nat * 3);
95-
syncmem_var_op()(ctx, cpu_ctx, atom_tau, h_atom_tau, GlobalC::ucell.nat * 3);
97+
resmem_var_op()(ctx, atom_tau, ucell->nat * 3);
98+
syncmem_var_op()(ctx, cpu_ctx, atom_tau, h_atom_tau, ucell->nat * 3);
9699

97100
igl2isz = wfc_basis->d_igl2isz_k;
98101
is2fftixy = wfc_basis->d_is2fftixy;
@@ -107,7 +110,7 @@ void Structure_Factor::get_sk(Device* ctx,
107110

108111
cal_sk_op()(ctx,
109112
ik,
110-
GlobalC::ucell.ntype,
113+
ucell->ntype,
111114
wfc_basis->nx,
112115
wfc_basis->ny,
113116
wfc_basis->nz,
@@ -152,7 +155,7 @@ std::complex<double>* Structure_Factor::get_skq(int ik,
152155
for (int ig = 0; ig < npw; ig++)
153156
{
154157
ModuleBase::Vector3<double> qkq = wfc_basis->getgpluskcar(ik, ig) + q;
155-
double arg = (qkq * GlobalC::ucell.atoms[it].tau[ia]) * ModuleBase::TWO_PI;
158+
double arg = (qkq * ucell->atoms[it].tau[ia]) * ModuleBase::TWO_PI;
156159
skq[ig] = std::complex<double>(cos(arg), -sin(arg));
157160
}
158161

source/module_hsolver/test/hsolver_pw_sup.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ void diago_PAO_in_pw_k2(
192192
wavefunc* p_wf,
193193
const ModuleBase::realArray& tab_at,
194194
const int& lmaxkb,
195+
const UnitCell& ucell,
195196
hamilt::Hamilt<std::complex<float>, base_device::DEVICE_CPU>* phm_in) {
196197
for (int i = 0; i < wvf.size(); i++) {
197198
wvf.get_pointer()[i] = std::complex<float>((float)i + 1, 0);
@@ -207,6 +208,7 @@ void diago_PAO_in_pw_k2(
207208
wavefunc* p_wf,
208209
const ModuleBase::realArray& tab_at,
209210
const int& lmaxkb,
211+
const UnitCell& ucell,
210212
hamilt::Hamilt<std::complex<double>, base_device::DEVICE_CPU>* phm_in) {
211213
for (int i = 0; i < wvf.size(); i++) {
212214
wvf.get_pointer()[i] = std::complex<double>((double)i + 1, 0);

source/module_lr/esolver_lrtd_lcao.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ LR::ESolver_LR<T, TR>::ESolver_LR(const Input_para& inp, UnitCell& ucell) : inpu
255255
// necessary steps in ESolver_KS::before_all_runners : symmetry and k-points
256256
if (ModuleSymmetry::Symmetry::symm_flag == 1)
257257
{
258-
GlobalC::ucell.symm.analy_sys(ucell.lat, ucell.st, ucell.atoms, GlobalV::ofs_running);
258+
ucell.symm.analy_sys(ucell.lat, ucell.st, ucell.atoms, GlobalV::ofs_running);
259259
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "SYMMETRY");
260260
}
261261
this->kv.set(ucell,ucell.symm, PARAM.inp.kpoint_file, PARAM.inp.nspin, ucell.G, ucell.latvec, GlobalV::ofs_running);
@@ -318,12 +318,12 @@ LR::ESolver_LR<T, TR>::ESolver_LR(const Input_para& inp, UnitCell& ucell) : inpu
318318
this->init_pot(chg_gs);
319319

320320
// search adjacent atoms and init Gint
321-
std::cout << "ucell.infoNL.get_rcutmax_Beta(): " << GlobalC::ucell.infoNL.get_rcutmax_Beta() << std::endl;
321+
std::cout << "ucell.infoNL.get_rcutmax_Beta(): " << ucell.infoNL.get_rcutmax_Beta() << std::endl;
322322
double search_radius = -1.0;
323323
search_radius = atom_arrange::set_sr_NL(GlobalV::ofs_running,
324324
PARAM.inp.out_level,
325325
orb.get_rcutmax_Phi(),
326-
GlobalC::ucell.infoNL.get_rcutmax_Beta(),
326+
ucell.infoNL.get_rcutmax_Beta(),
327327
PARAM.globalv.gamma_only_local);
328328
atom_arrange::search(PARAM.inp.search_pbc,
329329
GlobalV::ofs_running,
@@ -341,7 +341,7 @@ LR::ESolver_LR<T, TR>::ESolver_LR(const Input_para& inp, UnitCell& ucell) : inpu
341341
std::vector<std::vector<double>> dpsi_u;
342342
std::vector<std::vector<double>> d2psi_u;
343343

344-
Gint_Tools::init_orb(dr_uniform, rcuts, GlobalC::ucell, orb, psi_u, dpsi_u, d2psi_u);
344+
Gint_Tools::init_orb(dr_uniform, rcuts, ucell, orb, psi_u, dpsi_u, d2psi_u);
345345
this->gt_.set_pbc_grid(this->pw_rho->nx,
346346
this->pw_rho->ny,
347347
this->pw_rho->nz,
@@ -357,7 +357,7 @@ LR::ESolver_LR<T, TR>::ESolver_LR(const Input_para& inp, UnitCell& ucell) : inpu
357357
this->pw_rho->ny,
358358
this->pw_rho->nplane,
359359
this->pw_rho->startz_current,
360-
GlobalC::ucell,
360+
ucell,
361361
GlobalC::GridD,
362362
dr_uniform,
363363
rcuts,

source/module_lr/operator_casida/operator_lr_exx.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ namespace LR
6565
for (auto cell : this->BvK_cells)
6666
{
6767
std::complex<double> frac = RI::Global_Func::convert<std::complex<double>>(std::exp(
68-
-ModuleBase::TWO_PI * ModuleBase::IMAG_UNIT * (this->kv.kvec_c.at(ik) * (RI_Util::array3_to_Vector3(cell) * GlobalC::ucell.latvec))));
68+
-ModuleBase::TWO_PI * ModuleBase::IMAG_UNIT * (this->kv.kvec_c.at(ik) * (RI_Util::array3_to_Vector3(cell) * ucell.latvec))));
6969
for (int it1 = 0;it1 < ucell.ntype;++it1)
7070
for (int ia1 = 0; ia1 < ucell.atoms[it1].na; ++ia1)
7171
for (int it2 = 0;it2 < ucell.ntype;++it2)

source/module_lr/operator_casida/operator_lr_hxc.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,15 @@ namespace LR
6868

6969
// 3. v_hxc = f_hxc * rho_trans
7070
ModuleBase::matrix vr_hxc(1, nrxx); //grid
71-
this->pot.lock()->cal_v_eff(rho_trans, GlobalC::ucell, vr_hxc, ispin_ks);
71+
this->pot.lock()->cal_v_eff(rho_trans, ucell, vr_hxc, ispin_ks);
7272
LR_Util::_deallocate_2order_nested_ptr(rho_trans, 1);
7373

7474
// 4. V^{Hxc}_{\mu,\nu}=\int{dr} \phi_\mu(r) v_{Hxc}(r) \phi_\mu(r)
7575
Gint_inout inout_vlocal(vr_hxc.c, 0, Gint_Tools::job_type::vlocal);
7676
this->gint->get_hRGint()->set_zero();
7777
this->gint->cal_gint(&inout_vlocal);
7878
this->hR->set_zero(); // clear hR for each bands
79-
this->gint->transfer_pvpR(&*this->hR, &GlobalC::ucell); //grid to 2d block
79+
this->gint->transfer_pvpR(&*this->hR, &ucell); //grid to 2d block
8080
ModuleBase::timer::tick("OperatorLRHxc", "grid_calculation");
8181
}
8282

@@ -88,7 +88,7 @@ namespace LR
8888

8989
elecstate::DensityMatrix<std::complex<double>, double> DM_trans_real_imag(&pmat, 1, kv.kvec_d, kv.get_nks() / nspin);
9090
DM_trans_real_imag.init_DMR(*this->hR);
91-
hamilt::HContainer<double> HR_real_imag(GlobalC::ucell, &this->pmat);
91+
hamilt::HContainer<double> HR_real_imag(ucell, &this->pmat);
9292
LR_Util::initialize_HR<std::complex<double>, double>(HR_real_imag, ucell, gd, orb_cutoff_);
9393

9494
auto dmR_to_hR = [&, this](const char& type) -> void
@@ -111,7 +111,7 @@ namespace LR
111111

112112
// 3. v_hxc = f_hxc * rho_trans
113113
ModuleBase::matrix vr_hxc(1, nrxx); //grid
114-
this->pot.lock()->cal_v_eff(rho_trans, GlobalC::ucell, vr_hxc, ispin_ks);
114+
this->pot.lock()->cal_v_eff(rho_trans, ucell, vr_hxc, ispin_ks);
115115
// print_grid_nonzero(vr_hxc.c, this->poticab->nrxx, 10, "vr_hxc");
116116

117117
LR_Util::_deallocate_2order_nested_ptr(rho_trans, 1);
@@ -123,9 +123,9 @@ namespace LR
123123

124124
// LR_Util::print_HR(*this->gint->get_hRGint(), this->ucell.nat, "VR(grid)");
125125
HR_real_imag.set_zero();
126-
this->gint->transfer_pvpR(&HR_real_imag, &GlobalC::ucell, &GlobalC::GridD);
126+
this->gint->transfer_pvpR(&HR_real_imag, &ucell, &GlobalC::GridD);
127127
// LR_Util::print_HR(HR_real_imag, this->ucell.nat, "VR(real, 2d)");
128-
LR_Util::set_HR_real_imag_part(HR_real_imag, *this->hR, GlobalC::ucell.nat, type);
128+
LR_Util::set_HR_real_imag_part(HR_real_imag, *this->hR, ucell.nat, type);
129129
};
130130
this->hR->set_zero();
131131
dmR_to_hR('R'); //real

source/module_psi/psi_init.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ void PSIInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
9393
const int* ngk,
9494
const int npwx,
9595
Structure_Factor* p_sf,
96-
pseudopot_cell_vnl* p_ppcell)
96+
pseudopot_cell_vnl* p_ppcell,
97+
const UnitCell& ucell)
9798
{
9899
// allocate memory for std::complex<double> datatype psi
99100
// New psi initializer in ABACUS, Developer's note:
@@ -126,7 +127,7 @@ void PSIInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
126127
// however, init_at_1 does not actually initialize the psi, instead, it is a
127128
// function to calculate a interpolate table saving overlap intergral or say
128129
// Spherical Bessel Transform of atomic orbitals.
129-
this->wf_old.init_at_1(p_sf, &p_ppcell->tab_at);
130+
this->wf_old.init_at_1(ucell,p_sf, &p_ppcell->tab_at);
130131
// similarly, wfcinit not really initialize any wavefunction, instead, it initialize
131132
// the mapping from ixy, the 1d flattened index of point on fft grid (x, y) plane,
132133
// to the index of "stick", composed of grid points.
@@ -135,15 +136,18 @@ void PSIInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
135136
}
136137

137138
template <typename T, typename Device>
138-
void PSIInit<T, Device>::make_table(const int nks, Structure_Factor* p_sf, pseudopot_cell_vnl* p_ppcell)
139+
void PSIInit<T, Device>::make_table(const int nks,
140+
Structure_Factor* p_sf,
141+
pseudopot_cell_vnl* p_ppcell,
142+
const UnitCell& ucell)
139143
{
140144
if (this->use_psiinitializer)
141145
{
142146
} // do not need to do anything because the interpolate table is unchanged
143147
else // old initialization method, used in EXX calculation
144148
{
145149
this->wf_old.init_after_vc(nks); // reallocate wanf2, the planewave expansion of lcao
146-
this->wf_old.init_at_1(p_sf, &p_ppcell->tab_at); // re-calculate tab_at, the overlap matrix between atomic pswfc and jlq
150+
this->wf_old.init_at_1(ucell,p_sf, &p_ppcell->tab_at); // re-calculate tab_at, the overlap matrix between atomic pswfc and jlq
147151
}
148152
}
149153

@@ -152,6 +156,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
152156
psi::Psi<T, Device>* kspw_psi,
153157
hamilt::Hamilt<T, Device>* p_hamilt,
154158
const pseudopot_cell_vnl& nlpp,
159+
const UnitCell& ucell,
155160
std::ofstream& ofs_running,
156161
const bool is_already_initpsi)
157162
{
@@ -278,6 +283,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
278283
&this->wf_old,
279284
nlpp.tab_at,
280285
nlpp.lmaxkb,
286+
ucell,
281287
p_hamilt);
282288
}
283289
}
@@ -294,6 +300,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
294300
&this->wf_old,
295301
nlpp.tab_at,
296302
nlpp.lmaxkb,
303+
ucell,
297304
p_hamilt);
298305
}
299306
}

source/module_psi/psi_init.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,14 @@ class PSIInit
3636
const int* ngk, //< number of G-vectors in the current pool
3737
const int npwx, //< max number of plane waves of all pools
3838
Structure_Factor* p_sf, //< structure factor
39-
pseudopot_cell_vnl* p_ppcell); //< nonlocal pseudopotential
39+
pseudopot_cell_vnl* p_ppcell, //< nonlocal pseudopotential
40+
const UnitCell& ucell); //< unit cell
4041

4142
// make interpolate table
42-
void make_table(const int nks, Structure_Factor* p_sf, pseudopot_cell_vnl* p_ppcell);
43+
void make_table(const int nks,
44+
Structure_Factor* p_sf,
45+
pseudopot_cell_vnl* p_ppcell,
46+
const UnitCell& ucell);
4347

4448
//------------------------ only for psi_initializer --------------------
4549
/**
@@ -54,6 +58,7 @@ class PSIInit
5458
psi::Psi<T, Device>* kspw_psi,
5559
hamilt::Hamilt<T, Device>* p_hamilt,
5660
const pseudopot_cell_vnl& nlpp,
61+
const UnitCell& ucell,
5762
std::ofstream& ofs_running,
5863
const bool is_already_initpsi);
5964

0 commit comments

Comments
 (0)