Skip to content

Commit e32341a

Browse files
authored
Fix: init_chg wfc support npsin = 4 now (#5166)
Fix write_wfc_pw and read_wfc_pw when nspin = 4
1 parent fb5beba commit e32341a

File tree

11 files changed

+371
-167
lines changed

11 files changed

+371
-167
lines changed

source/module_io/read_wfc_pw.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ void ModuleIO::read_wfc_pw(const std::string& filename,
8282
ikstot = ik;
8383
#endif
8484

85-
npwtot *= PARAM.globalv.npol;
85+
int npwtot_npol = npwtot * PARAM.globalv.npol;
8686

8787

8888

@@ -178,7 +178,7 @@ void ModuleIO::read_wfc_pw(const std::string& filename,
178178
}
179179

180180
// read in miller index
181-
ModuleBase::Vector3<int>* miller = new ModuleBase::Vector3<int>[npwtot_in];
181+
ModuleBase::Vector3<int>* miller = new ModuleBase::Vector3<int>[npwtot];
182182
int* glo_order = nullptr;
183183
if (GlobalV::RANK_IN_POOL == 0)
184184
{
@@ -188,7 +188,7 @@ void ModuleIO::read_wfc_pw(const std::string& filename,
188188
else if (filetype == "dat")
189189
{
190190
rfs >> size;
191-
for (int i = 0; i < npwtot_in; ++i)
191+
for (int i = 0; i < npwtot; ++i)
192192
{
193193
rfs >> miller[i].x >> miller[i].y >> miller[i].z;
194194
}
@@ -201,7 +201,7 @@ void ModuleIO::read_wfc_pw(const std::string& filename,
201201
{
202202
glo_order[i] = -1;
203203
}
204-
for (int i = 0; i < npwtot_in / PARAM.globalv.npol; ++i)
204+
for (int i = 0; i < npwtot; ++i)
205205
{
206206
int index = (miller[i].x * ny + miller[i].y) * nz + miller[i].z;
207207
glo_order[index] = i;
@@ -221,7 +221,7 @@ void ModuleIO::read_wfc_pw(const std::string& filename,
221221
}
222222

223223
// read in wfc
224-
std::complex<double>* wfc_in = new std::complex<double>[npwtot_in];
224+
std::complex<double>* wfc_in = new std::complex<double>[npwtot_npol];
225225
for (int ib = 0; ib < nbands_in; ib++)
226226
{
227227
if (GlobalV::RANK_IN_POOL == 0)
@@ -232,7 +232,7 @@ void ModuleIO::read_wfc_pw(const std::string& filename,
232232
else if (filetype == "dat")
233233
{
234234
rfs >> size;
235-
for (int i = 0; i < npwtot_in; ++i)
235+
for (int i = 0; i < npwtot_npol; ++i)
236236
{
237237
rfs >> wfc_in[i];
238238
}
@@ -285,7 +285,7 @@ void ModuleIO::read_wfc_pw(const std::string& filename,
285285
{
286286
for (int i = 0; i < size; i++)
287287
{
288-
wfc_ip[i] = wfc_in[glo_order[ig_ip[i]] + npwtot_in / 2];
288+
wfc_ip[i] = wfc_in[glo_order[ig_ip[i]] + npwtot];
289289
}
290290
MPI_Send(wfc_ip, size, MPI_DOUBLE_COMPLEX, ip, ip + 2 * GlobalV::NPROC_IN_POOL, POOL_WORLD);
291291
}
@@ -305,7 +305,7 @@ void ModuleIO::read_wfc_pw(const std::string& filename,
305305
{
306306
for (int i = 0; i < pw_wfc->npwk[ik]; ++i)
307307
{
308-
wfc(ib, i + npwk_max) = wfc_in[glo_order[l2g_pw[i]] + npwtot_in / 2];
308+
wfc(ib, i + npwk_max) = wfc_in[glo_order[l2g_pw[i]] + npwtot];
309309
}
310310
}
311311
}
@@ -321,7 +321,7 @@ void ModuleIO::read_wfc_pw(const std::string& filename,
321321
{
322322
for (int i = 0; i < pw_wfc->npwk[ik]; ++i)
323323
{
324-
wfc(ib, i + npwk_max) = wfc_in[glo_order[l2g_pw[i]] + npwtot_in / 2];
324+
wfc(ib, i + npwk_max) = wfc_in[glo_order[l2g_pw[i]] + npwtot];
325325
}
326326
}
327327
#endif

source/module_io/read_wfc_to_rho.cpp

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "module_hamilt_pw/hamilt_pwdft/global.h"
66
#include "module_elecstate/module_charge/symmetry_rho.h"
77
#include "module_parameter/parameter.h"
8+
#include "module_elecstate/kernels/elecstate_op.h"
89

910
void ModuleIO::read_wfc_to_rho(const ModulePW::PW_Basis_K* pw_wfc,
1011
ModuleSymmetry::Symmetry& symm,
@@ -21,14 +22,14 @@ void ModuleIO::read_wfc_to_rho(const ModulePW::PW_Basis_K* pw_wfc,
2122
const int nbands = GlobalV::NBANDS;
2223
const int nspin = PARAM.inp.nspin;
2324

24-
const int npwk_max = pw_wfc->npwk_max;
25+
const int ng_npol = pw_wfc->npwk_max * PARAM.globalv.npol;
2526
const int nrxx = pw_wfc->nrxx;
2627
for (int is = 0; is < nspin; ++is)
2728
{
2829
ModuleBase::GlobalFunc::ZEROS(chg.rho[is], nrxx);
2930
}
3031

31-
ModuleBase::ComplexMatrix wfc_tmp(nbands, npwk_max);
32+
ModuleBase::ComplexMatrix wfc_tmp(nbands, ng_npol);
3233
std::vector<std::complex<double>> rho_tmp(nrxx);
3334

3435
// read occupation numbers
@@ -78,21 +79,44 @@ void ModuleIO::read_wfc_to_rho(const ModulePW::PW_Basis_K* pw_wfc,
7879
std::stringstream filename;
7980
filename << PARAM.globalv.global_readin_dir << "WAVEFUNC" << ikstot + 1 << ".dat";
8081
ModuleIO::read_wfc_pw(filename.str(), pw_wfc, ik, nkstot, wfc_tmp);
81-
for (int ib = 0; ib < nbands; ++ib)
82+
if (PARAM.inp.nspin == 4)
8283
{
83-
const std::complex<double>* wfc_ib = wfc_tmp.c + ib * npwk_max;
84-
pw_wfc->recip2real(wfc_ib, rho_tmp.data(), ik);
85-
86-
const double w1 = wg_tmp(ikstot, ib) / pw_wfc->omega;
84+
std::vector<std::complex<double>> rho_tmp2(nrxx);
85+
for (int ib = 0; ib < nbands; ++ib)
86+
{
87+
const std::complex<double>* wfc_ib = wfc_tmp.c + ib * ng_npol;
88+
const std::complex<double>* wfc_ib2 = wfc_tmp.c + ib * ng_npol + ng_npol / 2;
89+
pw_wfc->recip2real(wfc_ib, rho_tmp.data(), ik);
90+
pw_wfc->recip2real(wfc_ib2, rho_tmp2.data(), ik);
91+
const double w1 = wg_tmp(ikstot, ib) / pw_wfc->omega;
8792

88-
if (w1 != 0.0)
93+
if (w1 != 0.0)
94+
{
95+
base_device::DEVICE_CPU* ctx = nullptr;
96+
elecstate::elecstate_pw_op<double, base_device::DEVICE_CPU>()(ctx,
97+
PARAM.globalv.domag,
98+
PARAM.globalv.domag_z,
99+
nrxx,
100+
w1,
101+
chg.rho,
102+
rho_tmp.data(),
103+
rho_tmp2.data());
104+
}
105+
}
106+
}
107+
else
108+
{
109+
for (int ib = 0; ib < nbands; ++ib)
89110
{
90-
#ifdef _OPENMP
91-
#pragma omp parallel for
92-
#endif
93-
for (int ir = 0; ir < nrxx; ir++)
111+
const std::complex<double>* wfc_ib = wfc_tmp.c + ib * ng_npol;
112+
pw_wfc->recip2real(wfc_ib, rho_tmp.data(), ik);
113+
114+
const double w1 = wg_tmp(ikstot, ib) / pw_wfc->omega;
115+
116+
if (w1 != 0.0)
94117
{
95-
chg.rho[is][ir] += w1 * std::norm(rho_tmp[ir]);
118+
base_device::DEVICE_CPU* ctx = nullptr;
119+
elecstate::elecstate_pw_op<double, base_device::DEVICE_CPU>()(ctx, is, nrxx, w1, chg.rho, rho_tmp.data());
96120
}
97121
}
98122
}

0 commit comments

Comments
 (0)