Skip to content

Commit 3931789

Browse files
committed
Fix incorrect pchg when if_separate_k is false due to lack of reduce across diff pools
1 parent d50bce4 commit 3931789

File tree

4 files changed

+183
-140
lines changed

4 files changed

+183
-140
lines changed

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -728,9 +728,6 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
728728
this->pw_rhod->ny,
729729
this->pw_rhod->nz,
730730
this->pw_rhod->nxyz,
731-
this->kv.get_nks(),
732-
this->kv.isk,
733-
this->kv.wk,
734731
this->chr.ngmc,
735732
&ucell,
736733
this->psi,
@@ -739,7 +736,11 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
739736
this->ctx,
740737
this->Pgrid,
741738
PARAM.globalv.global_out_dir,
742-
PARAM.inp.if_separate_k);
739+
PARAM.inp.if_separate_k,
740+
this->kv,
741+
GlobalV::KPAR,
742+
GlobalV::MY_POOL,
743+
&this->chr);
743744
}
744745

745746
// tmp 2025-05-17, mohan note
@@ -970,7 +971,6 @@ void ESolver_KS_PW<T, Device>::after_all_runners(UnitCell& ucell)
970971
this->pw_rhod->nxyz,
971972
&ucell,
972973
this->psi,
973-
this->pw_rhod,
974974
this->pw_wfc,
975975
this->ctx,
976976
this->Pgrid,

source/module_io/get_pchg_pw.h

Lines changed: 143 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "module_basis/module_pw/pw_basis_k.h"
99
#include "module_cell/unitcell.h"
1010
#include "module_elecstate/elecstate.h"
11+
#include "module_elecstate/module_charge/charge.h"
1112
#include "module_elecstate/module_charge/symmetry_rho.h"
1213
#include "module_hamilt_pw/hamilt_pwdft/parallel_grid.h"
1314
#include "module_psi/psi.h"
@@ -25,9 +26,6 @@ void get_pchg_pw(const std::vector<int>& out_pchg,
2526
const int ny,
2627
const int nz,
2728
const int nxyz,
28-
const int nks,
29-
const std::vector<int>& isk,
30-
const std::vector<double>& wk,
3129
const int ngmc,
3230
UnitCell* ucell,
3331
const psi::Psi<std::complex<double>>* psi,
@@ -36,165 +34,183 @@ void get_pchg_pw(const std::vector<int>& out_pchg,
3634
const Device* ctx,
3735
const Parallel_Grid& pgrid,
3836
const std::string& global_out_dir,
39-
const bool if_separate_k)
37+
const bool if_separate_k,
38+
const K_Vectors& kv,
39+
const int kpar,
40+
const int my_pool,
41+
const Charge* chg) // Charge class is needed for the charge density reduce
4042
{
41-
// bands_picked is a vector of 0s and 1s, where 1 means the band is picked to output
42-
std::vector<int> bands_picked(nbands, 0);
43+
// Get necessary parameters from kv
44+
const int nks = kv.get_nks(); // current process pool k-point count
45+
const int nkstot = kv.get_nkstot(); // total k-point count
4346

44-
// Check if length of out_pchg is valid
45-
if (static_cast<int>(out_pchg.size()) > nbands)
47+
// Loop over k-parallelism
48+
for (int ip = 0; ip < kpar; ++ip)
4649
{
47-
ModuleBase::WARNING_QUIT("ModuleIO::get_pchg_pw",
48-
"The number of bands specified by `out_pchg` in the "
49-
"INPUT file exceeds `nbands`!");
50-
}
51-
52-
// Check if all elements in bands_picked are 0 or 1
53-
for (int value: out_pchg)
54-
{
55-
if (value != 0 && value != 1)
50+
if (my_pool != ip)
5651
{
57-
ModuleBase::WARNING_QUIT("ModuleIO::get_pchg_pw",
58-
"The elements of `out_pchg` must be either 0 or 1. "
59-
"Invalid values found!");
52+
continue;
6053
}
61-
}
6254

63-
// Fill bands_picked with values from out_pchg
64-
// Remaining bands are already set to 0
65-
int length = std::min(static_cast<int>(out_pchg.size()), nbands);
66-
for (int i = 0; i < length; ++i)
67-
{
68-
// out_pchg rely on function parse_expression
69-
bands_picked[i] = static_cast<int>(out_pchg[i]);
70-
}
71-
72-
std::vector<std::complex<double>> wfcr(nxyz);
73-
std::vector<std::vector<double>> rho_band(nspin, std::vector<double>(nxyz));
55+
// bands_picked is a vector of 0s and 1s, where 1 means the band is picked to output
56+
std::vector<int> bands_picked(nbands, 0);
7457

75-
for (int ib = 0; ib < nbands; ++ib)
76-
{
77-
// Skip the loop iteration if bands_picked[ib] is 0
78-
if (!bands_picked[ib])
58+
// Check if length of out_pchg is valid
59+
if (static_cast<int>(out_pchg.size()) > nbands)
7960
{
80-
continue;
61+
ModuleBase::WARNING_QUIT("ModuleIO::get_pchg_pw",
62+
"The number of bands specified by `out_pchg` in the "
63+
"INPUT file exceeds `nbands`!");
8164
}
8265

83-
for (int is = 0; is < nspin; ++is)
66+
// Check if all elements in bands_picked are 0 or 1
67+
for (int value: out_pchg)
8468
{
85-
std::fill(rho_band[is].begin(), rho_band[is].end(), 0.0);
69+
if (value != 0 && value != 1)
70+
{
71+
ModuleBase::WARNING_QUIT("ModuleIO::get_pchg_pw",
72+
"The elements of `out_pchg` must be either 0 or 1. "
73+
"Invalid values found!");
74+
}
8675
}
8776

88-
if (if_separate_k)
77+
// Fill bands_picked with values from out_pchg
78+
// Remaining bands are already set to 0
79+
int length = std::min(static_cast<int>(out_pchg.size()), nbands);
80+
for (int i = 0; i < length; ++i)
8981
{
90-
for (int ik = 0; ik < nks; ++ik)
91-
{
92-
const int spin_index = isk[ik];
93-
std::cout << " Calculating band-decomposed charge density for band " << ib + 1 << ", k-point "
94-
<< ik % (nks / nspin) + 1 << ", spin " << spin_index + 1 << std::endl;
82+
// out_pchg rely on function parse_expression
83+
bands_picked[i] = static_cast<int>(out_pchg[i]);
84+
}
9585

96-
psi->fix_k(ik);
97-
pw_wfc->recip_to_real(ctx, &psi[0](ib, 0), wfcr.data(), ik);
86+
std::vector<std::complex<double>> wfcr(nxyz);
87+
std::vector<std::vector<double>> rho_band(nspin, std::vector<double>(nxyz));
9888

99-
// To ensure the normalization of charge density in multi-k calculation (if if_separate_k is true)
100-
double wg_sum_k = 0;
101-
for (int ik_tmp = 0; ik_tmp < nks / nspin; ++ik_tmp)
102-
{
103-
wg_sum_k += wk[ik_tmp];
104-
}
89+
for (int ib = 0; ib < nbands; ++ib)
90+
{
91+
// Skip the loop iteration if bands_picked[ib] is 0
92+
if (!bands_picked[ib])
93+
{
94+
continue;
95+
}
10596

106-
double w1 = static_cast<double>(wg_sum_k / ucell->omega);
97+
for (int is = 0; is < nspin; ++is)
98+
{
99+
std::fill(rho_band[is].begin(), rho_band[is].end(), 0.0);
100+
}
107101

108-
for (int i = 0; i < nxyz; ++i)
102+
if (if_separate_k)
103+
{
104+
for (int ik = 0; ik < nks; ++ik)
109105
{
110-
rho_band[spin_index][i] = std::norm(wfcr[i]) * w1;
106+
const int ikstot = kv.ik2iktot[ik]; // global k-point index
107+
const int spin_index = kv.isk[ik]; // spin index
108+
const int k_number = ikstot % (nkstot / nspin) + 1; // k-point number, starting from 1
109+
110+
psi->fix_k(ik);
111+
pw_wfc->recip_to_real(ctx, &psi[0](ib, 0), wfcr.data(), ik);
112+
113+
// To ensure the normalization of charge density in multi-k calculation (if if_separate_k is true)
114+
double wg_sum_k = 0.0;
115+
if (nspin == 1)
116+
{
117+
wg_sum_k = 2.0;
118+
}
119+
else if (nspin == 2)
120+
{
121+
wg_sum_k = 1.0;
122+
}
123+
else
124+
{
125+
ModuleBase::WARNING_QUIT("ModuleIO::get_pchg_pw",
126+
"Real space partial charge output currently do not support "
127+
"noncollinear polarized calculation (nspin = 4)!");
128+
}
129+
130+
double w1 = static_cast<double>(wg_sum_k / ucell->omega);
131+
132+
for (int i = 0; i < nxyz; ++i)
133+
{
134+
rho_band[spin_index][i] = std::norm(wfcr[i]) * w1;
135+
}
136+
137+
std::stringstream ssc;
138+
ssc << global_out_dir << "BAND" << ib + 1 << "_K" << k_number << "_SPIN" << spin_index + 1
139+
<< "_CHG.cube";
140+
141+
ModuleIO::write_vdata_palgrid(pgrid,
142+
rho_band[spin_index].data(),
143+
spin_index,
144+
nspin,
145+
0,
146+
ssc.str(),
147+
0.0,
148+
ucell,
149+
11,
150+
1,
151+
true); // reduce_all_pool is true
111152
}
112-
113-
std::cout << " Writing cube files...";
114-
115-
std::stringstream ssc;
116-
ssc << global_out_dir << "BAND" << ib + 1 << "_K" << ik % (nks / nspin) + 1 << "_SPIN" << spin_index + 1
117-
<< "_CHG.cube";
118-
119-
ModuleIO::write_vdata_palgrid(pgrid,
120-
rho_band[spin_index].data(),
121-
spin_index,
122-
nspin,
123-
0,
124-
ssc.str(),
125-
0.0,
126-
ucell);
127-
128-
std::cout << " Complete!" << std::endl;
129153
}
130-
}
131-
else
132-
{
133-
for (int ik = 0; ik < nks; ++ik)
154+
else
134155
{
135-
const int spin_index = isk[ik];
136-
std::cout << " Calculating band-decomposed charge density for band " << ib + 1 << ", k-point "
137-
<< ik % (nks / nspin) + 1 << ", spin " << spin_index + 1 << std::endl;
156+
for (int ik = 0; ik < nks; ++ik)
157+
{
158+
const int ikstot = kv.ik2iktot[ik]; // global k-point index
159+
const int spin_index = kv.isk[ik]; // spin index
160+
const int k_number = ikstot % (nkstot / nspin) + 1; // k-point number, starting from 1
138161

139-
psi->fix_k(ik);
140-
pw_wfc->recip_to_real(ctx, &psi[0](ib, 0), wfcr.data(), ik);
162+
psi->fix_k(ik);
163+
pw_wfc->recip_to_real(ctx, &psi[0](ib, 0), wfcr.data(), ik);
141164

142-
double w1 = static_cast<double>(wk[ik] / ucell->omega);
165+
double w1 = static_cast<double>(kv.wk[ik] / ucell->omega);
143166

144-
for (int i = 0; i < nxyz; ++i)
145-
{
146-
rho_band[spin_index][i] += std::norm(wfcr[i]) * w1;
167+
for (int i = 0; i < nxyz; ++i)
168+
{
169+
rho_band[spin_index][i] += std::norm(wfcr[i]) * w1;
170+
}
147171
}
148-
}
149172

150-
// Symmetrize the charge density, otherwise the results are incorrect if the symmetry is on
151-
std::cout << " Symmetrizing band-decomposed charge density..." << std::endl;
152-
Symmetry_rho srho;
153-
for (int is = 0; is < nspin; ++is)
154-
{
155-
// Use vector instead of raw pointers
156-
std::vector<double*> rho_save_pointers(nspin);
157-
for (int s = 0; s < nspin; ++s)
173+
// Reduce the charge density across all pools if kpar > 1
174+
if (kpar > 1 && chg != nullptr)
158175
{
159-
rho_save_pointers[s] = rho_band[s].data();
176+
for (int is = 0; is < nspin; ++is)
177+
{
178+
chg->reduce_diff_pools(rho_band[is].data());
179+
}
160180
}
161181

162-
std::vector<std::vector<std::complex<double>>> rhog(nspin, std::vector<std::complex<double>>(ngmc));
163-
164-
// Convert vector of vectors to vector of pointers
165-
std::vector<std::complex<double>*> rhog_pointers(nspin);
166-
for (int s = 0; s < nspin; ++s)
182+
// Symmetrize the charge density, otherwise the results are incorrect if the symmetry is on
183+
std::cout << " Symmetrizing band-decomposed charge density..." << std::endl;
184+
Symmetry_rho srho;
185+
for (int is = 0; is < nspin; ++is)
167186
{
168-
rhog_pointers[s] = rhog[s].data();
187+
// Use vector instead of raw pointers
188+
std::vector<double*> rho_save_pointers(nspin);
189+
for (int s = 0; s < nspin; ++s)
190+
{
191+
rho_save_pointers[s] = rho_band[s].data();
192+
}
193+
194+
std::vector<std::vector<std::complex<double>>> rhog(nspin, std::vector<std::complex<double>>(ngmc));
195+
196+
// Convert vector of vectors to vector of pointers
197+
std::vector<std::complex<double>*> rhog_pointers(nspin);
198+
for (int s = 0; s < nspin; ++s)
199+
{
200+
rhog_pointers[s] = rhog[s].data();
201+
}
202+
203+
srho.begin(is, rho_save_pointers.data(), rhog_pointers.data(), ngmc, nullptr, pw_rhod, ucell->symm);
169204
}
170205

171-
srho.begin(is,
172-
rho_save_pointers.data(),
173-
rhog_pointers.data(),
174-
ngmc,
175-
nullptr,
176-
pw_rhod,
177-
ucell->symm);
178-
}
179-
180-
std::cout << " Writing cube files...";
206+
for (int is = 0; is < nspin; ++is)
207+
{
208+
std::stringstream ssc;
209+
ssc << global_out_dir << "BAND" << ib + 1 << "_SPIN" << is + 1 << "_CHG.cube";
181210

182-
for (int is = 0; is < nspin; ++is)
183-
{
184-
std::stringstream ssc;
185-
ssc << global_out_dir << "BAND" << ib + 1 << "_SPIN" << is + 1 << "_CHG.cube";
186-
187-
ModuleIO::write_vdata_palgrid(pgrid,
188-
rho_band[is].data(),
189-
is,
190-
nspin,
191-
0,
192-
ssc.str(),
193-
0.0,
194-
ucell);
211+
ModuleIO::write_vdata_palgrid(pgrid, rho_band[is].data(), is, nspin, 0, ssc.str(), 0.0, ucell);
212+
}
195213
}
196-
197-
std::cout << " Complete!" << std::endl;
198214
}
199215
}
200216
}

0 commit comments

Comments
 (0)