Skip to content

Commit aced178

Browse files
Fix: wrong band output when kpar > 1 (#5847)
* fix wrong band results when kpar>1 * fix compile * fix: bug for init_wfc file when nspin = 2 * add new integral test * format read_wfc_to_rho_test * [pre-commit.ci lite] apply automatic fixes --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent 806de4a commit aced178

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+368
-366
lines changed

source/module_cell/klist.cpp

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,34 +13,33 @@
1313
#include "module_cell/module_paw/paw_cell.h"
1414
#endif
1515

16-
K_Vectors::K_Vectors()
16+
void K_Vectors::cal_ik_global()
1717
{
18-
19-
nspin = 0; // default spin.
20-
kc_done = false;
21-
kd_done = false;
22-
nkstot_full = 0;
23-
nks = 0;
24-
nkstot = 0;
25-
k_nkstot = 0; // LiuXh add 20180619
26-
}
27-
28-
K_Vectors::~K_Vectors()
29-
{
30-
}
31-
32-
int K_Vectors::get_ik_global(const int& ik, const int& nkstot)
33-
{
34-
int nkp = nkstot / PARAM.inp.kpar;
35-
int rem = nkstot % PARAM.inp.kpar;
36-
if (GlobalV::MY_POOL < rem)
18+
const int my_pool = this->para_k.my_pool;
19+
this->ik2iktot.resize(this->nks);
20+
#ifdef __MPI
21+
if(this->nspin == 2)
3722
{
38-
return GlobalV::MY_POOL * nkp + GlobalV::MY_POOL + ik;
23+
for (int ik = 0; ik < this->nks / 2; ++ik)
24+
{
25+
this->ik2iktot[ik] = this->para_k.startk_pool[my_pool] + ik;
26+
this->ik2iktot[ik + this->nks / 2] = this->nkstot / 2 + this->para_k.startk_pool[my_pool] + ik;
27+
}
3928
}
4029
else
4130
{
42-
return GlobalV::MY_POOL * nkp + rem + ik;
31+
for (int ik = 0; ik < this->nks; ++ik)
32+
{
33+
this->ik2iktot[ik] = this->para_k.startk_pool[my_pool] + ik;
34+
}
4335
}
36+
#else
37+
for (int ik = 0; ik < this->nks; ++ik)
38+
{
39+
this->ik2iktot[ik] = ik;
40+
}
41+
#endif
42+
4443
}
4544

4645
void K_Vectors::set(const UnitCell& ucell,
@@ -162,6 +161,9 @@ void K_Vectors::set(const UnitCell& ucell,
162161
// set the k vectors for the up and down spin
163162
this->set_kup_and_kdw();
164163

164+
// get ik2iktot
165+
this->cal_ik_global();
166+
165167
this->print_klists(ofs);
166168

167169
// std::cout << " NUMBER OF K-POINTS : " << nkstot << std::endl;

source/module_cell/klist.h

Lines changed: 22 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ class K_Vectors
2626
/// dim: [iks_ibz][(isym, kvec_d)]
2727
std::vector<std::map<int, ModuleBase::Vector3<double>>> kstars;
2828

29-
K_Vectors();
30-
~K_Vectors();
29+
K_Vectors(){};
30+
~K_Vectors(){};
3131
K_Vectors& operator=(const K_Vectors&) = default;
3232
K_Vectors& operator=(K_Vectors&& rhs) = default;
3333

@@ -106,23 +106,6 @@ class K_Vectors
106106
*/
107107
void set_after_vc(const int& nspin, const ModuleBase::Matrix3& reciprocal_vec, const ModuleBase::Matrix3& latvec);
108108

109-
/**
110-
* @brief Gets the global index of a k-point.
111-
*
112-
* This function gets the global index of a k-point based on its local index and the process pool ID.
113-
* The global index is used when the k-points are distributed among multiple process pools.
114-
*
115-
* @param nkstot The total number of k-points.
116-
* @param ik The local index of the k-point.
117-
*
118-
* @return int Returns the global index of the k-point.
119-
*
120-
* @note The function calculates the global index by dividing the total number of k-points (nkstot) by the number of
121-
* process pools (KPAR), and adding the remainder if the process pool ID (MY_POOL) is less than the remainder.
122-
* @note The function is declared as inline for efficiency.
123-
*/
124-
static int get_ik_global(const int& ik, const int& nkstot);
125-
126109
int get_nks() const
127110
{
128111
return this->nks;
@@ -157,19 +140,20 @@ class K_Vectors
157140
{
158141
this->nkstot_full = value;
159142
}
143+
std::vector<int> ik2iktot; ///<[nks] map ik to the global index of k points
160144

161-
private:
162-
int nks; // number of symmetry-reduced k points in this pool(processor, up+dw)
163-
int nkstot; /// number of symmetry-reduced k points in full k mesh
164-
int nkstot_full; /// number of k points before symmetry reduction in full k mesh
145+
private:
146+
int nks = 0; ///< number of symmetry-reduced k points in this pool(processor, up+dw)
147+
int nkstot = 0; ///< number of symmetry-reduced k points in full k mesh
148+
int nkstot_full = 0; ///< number of k points before symmetry reduction in full k mesh
165149

166-
int nspin;
167-
bool kc_done;
168-
bool kd_done;
169-
double koffset[3]={0.0}; // used only in automatic k-points.
170-
std::string k_kword; // LiuXh add 20180619
171-
int k_nkstot; // LiuXh add 20180619
172-
bool is_mp = false; // Monkhorst-Pack
150+
int nspin = 0;
151+
bool kc_done = false;
152+
bool kd_done = false;
153+
double koffset[3] = {0.0}; // used only in automatic k-points.
154+
std::string k_kword; // LiuXh add 20180619
155+
int k_nkstot = 0; // LiuXh add 20180619
156+
bool is_mp = false; // Monkhorst-Pack
173157

174158
/**
175159
* @brief Resize the k-point related vectors according to the new k-point number.
@@ -288,8 +272,8 @@ class K_Vectors
288272
* be recalculated.
289273
*/
290274
void update_use_ibz(const int& nkstot_ibz,
291-
const std::vector<ModuleBase::Vector3<double>>& kvec_d_ibz,
292-
const std::vector<double>& wk_ibz);
275+
const std::vector<ModuleBase::Vector3<double>>& kvec_d_ibz,
276+
const std::vector<double>& wk_ibz);
293277

294278
/**
295279
* @brief Sets both the direct and Cartesian k-vectors.
@@ -394,5 +378,11 @@ class K_Vectors
394378
* @note The function uses the FmtCore::format function to format the output.
395379
*/
396380
void print_klists(std::ofstream& fn);
381+
382+
/**
383+
* @brief Gets the global index of a k-point.
384+
* @return this->ik2iktot[ik]
385+
*/
386+
void cal_ik_global();
397387
};
398388
#endif // KVECT_H

source/module_cell/parallel_kpoints.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,13 @@ class Parallel_Kpoints
6868
return *std::max_element(nks_pool.begin(), nks_pool.end());
6969
}
7070

71-
private:
72-
71+
public:
7372
int kpar = 0; // number of pools
7473
int my_pool = 0; // the pool index of the present processor
7574
int rank_in_pool = 0; // the rank in the present pool
7675
int nproc = 1; // number of processors
7776
int nspin = 1; // number of spins
78-
77+
private:
7978
std::vector<int> startpro_pool; // the first processor in each pool
8079
#ifdef __MPI
8180
void get_nks_pool(const int& nkstot);

source/module_elecstate/elecstate_print.cpp

Lines changed: 63 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "elecstate_getters.h"
33
#include "module_base/formatter.h"
44
#include "module_base/global_variable.h"
5+
#include "module_base/parallel_common.h"
56
#include "module_elecstate/potentials/H_Hartree_pw.h"
67
#include "module_elecstate/potentials/efield.h"
78
#include "module_elecstate/potentials/gatefield.h"
@@ -152,7 +153,9 @@ void print_scf_iterinfo(const std::string& ks_solver,
152153
void ElecState::print_eigenvalue(std::ofstream& ofs)
153154
{
154155
bool wrong = false;
155-
for (int ik = 0; ik < this->klist->get_nks(); ++ik)
156+
const int nks = this->klist->get_nks();
157+
const int nkstot = this->klist->get_nkstot();
158+
for (int ik = 0; ik < nks; ++ik)
156159
{
157160
for (int ib = 0; ib < this->ekb.nc; ++ib)
158161
{
@@ -164,76 +167,87 @@ void ElecState::print_eigenvalue(std::ofstream& ofs)
164167
}
165168
}
166169
}
170+
#ifdef __MPI
171+
MPI_Allreduce(MPI_IN_PLACE, &wrong, 1, MPI_C_BOOL, MPI_LOR, MPI_COMM_WORLD);
172+
#endif
167173
if (wrong)
168174
{
169175
ModuleBase::WARNING_QUIT("print_eigenvalue", "Eigenvalues are too large!");
170176
}
177+
std::stringstream ss;
178+
if(PARAM.inp.out_alllog)
179+
{
180+
ss << PARAM.globalv.global_out_dir << "running_" << PARAM.inp.calculation << "_" << GlobalV::MY_RANK + 1 << ".log";
181+
}
182+
else
183+
{
184+
ss << PARAM.globalv.global_out_dir << "running_" << PARAM.inp.calculation << ".log";
185+
}
186+
std::string filename = ss.str();
187+
std::vector<int> ngk_tot = this->klist->ngk;
171188

172-
if (GlobalV::MY_RANK != 0)
189+
#ifdef __MPI
190+
if(!PARAM.inp.out_alllog)
173191
{
174-
return;
192+
Parallel_Common::bcast_string(filename);
175193
}
194+
MPI_Allreduce(MPI_IN_PLACE, ngk_tot.data(), nks, MPI_INT, MPI_SUM, POOL_WORLD);
195+
#endif
176196

177197
ModuleBase::TITLE("ESolver_KS_PW", "print_eigenvalue");
178198

179199
ofs << "\n STATE ENERGY(eV) AND OCCUPATIONS ";
180-
for (int ik = 0; ik < this->klist->get_nks(); ik++)
200+
const int nk_fac = PARAM.inp.nspin == 2 ? 2 : 1;
201+
const int nks_np = nks / nk_fac;
202+
const int nkstot_np = nkstot / nk_fac;
203+
ofs << " NSPIN == " << PARAM.inp.nspin << std::endl;
204+
for (int is = 0; is < nk_fac; ++is)
181205
{
182-
ofs << std::setprecision(5);
183-
ofs << std::setiosflags(std::ios::showpoint);
184-
if (ik == 0)
206+
if (is == 0 && nk_fac == 2)
185207
{
186-
ofs << " NSPIN == " << PARAM.inp.nspin << std::endl;
187-
if (PARAM.inp.nspin == 2)
188-
{
189-
ofs << "SPIN UP : " << std::endl;
190-
}
208+
ofs << "SPIN UP : " << std::endl;
191209
}
192-
else if (ik == this->klist->get_nks() / 2)
210+
else if (is == 1 && nk_fac == 2)
193211
{
194-
if (PARAM.inp.nspin == 2)
195-
{
196-
ofs << "SPIN DOWN : " << std::endl;
197-
}
212+
ofs << "SPIN DOWN : " << std::endl;
198213
}
199214

200-
if (PARAM.inp.nspin == 2)
215+
for (int ip = 0; ip < GlobalV::KPAR; ++ip)
201216
{
202-
if (this->klist->isk[ik] == 0)
203-
{
204-
ofs << " " << ik + 1 << "/" << this->klist->get_nks() / 2
205-
<< " kpoint (Cartesian) = " << this->klist->kvec_c[ik].x << " " << this->klist->kvec_c[ik].y << " "
206-
<< this->klist->kvec_c[ik].z << " (" << this->klist->ngk[ik] << " pws)" << std::endl;
207-
208-
ofs << std::setprecision(6);
209-
}
210-
if (this->klist->isk[ik] == 1)
217+
#ifdef __MPI
218+
MPI_Barrier(MPI_COMM_WORLD);
219+
#endif
220+
bool ip_flag = PARAM.inp.out_alllog || (GlobalV::RANK_IN_POOL == 0 && GlobalV::MY_STOGROUP == 0);
221+
if (GlobalV::MY_POOL == ip && ip_flag)
211222
{
212-
ofs << " " << ik + 1 - this->klist->get_nks() / 2 << "/" << this->klist->get_nks() / 2
213-
<< " kpoint (Cartesian) = " << this->klist->kvec_c[ik].x << " " << this->klist->kvec_c[ik].y << " "
214-
<< this->klist->kvec_c[ik].z << " (" << this->klist->ngk[ik] << " pws)" << std::endl;
223+
const int start_ik = nks_np * is;
224+
const int end_ik = nks_np * (is + 1);
225+
for (int ik = start_ik; ik < end_ik; ++ik)
226+
{
227+
std::ofstream ofs_eig(filename.c_str(), std::ios::app);
228+
ofs_eig << std::setprecision(5);
229+
ofs_eig << std::setiosflags(std::ios::showpoint);
230+
ofs_eig << " " << this->klist->ik2iktot[ik] + 1 - is * nkstot_np << "/" << nkstot_np
231+
<< " kpoint (Cartesian) = " << this->klist->kvec_c[ik].x << " " << this->klist->kvec_c[ik].y
232+
<< " " << this->klist->kvec_c[ik].z << " (" << ngk_tot[ik] << " pws)" << std::endl;
215233

216-
ofs << std::setprecision(6);
234+
ofs_eig << std::setprecision(6);
235+
ofs_eig << std::setiosflags(std::ios::showpoint);
236+
for (int ib = 0; ib < this->ekb.nc; ib++)
237+
{
238+
ofs_eig << std::setw(8) << ib + 1 << std::setw(15) << this->ekb(ik, ib) * ModuleBase::Ry_to_eV
239+
<< std::setw(15) << this->wg(ik, ib) << std::endl;
240+
}
241+
ofs_eig << std::endl;
242+
ofs_eig.close();
243+
}
217244
}
218-
} // Pengfei Li added 14-9-9
219-
else
220-
{
221-
ofs << " " << ik + 1 << "/" << this->klist->get_nks()
222-
<< " kpoint (Cartesian) = " << this->klist->kvec_c[ik].x << " " << this->klist->kvec_c[ik].y << " "
223-
<< this->klist->kvec_c[ik].z << " (" << this->klist->ngk[ik] << " pws)" << std::endl;
224-
225-
ofs << std::setprecision(6);
226245
}
227-
228-
ofs << std::setprecision(6);
229-
ofs << std::setiosflags(std::ios::showpoint);
230-
for (int ib = 0; ib < this->ekb.nc; ib++)
231-
{
232-
ofs << std::setw(8) << ib + 1 << std::setw(15) << this->ekb(ik, ib) * ModuleBase::Ry_to_eV << std::setw(15)
233-
<< this->wg(ik, ib) << std::endl;
234-
}
235-
ofs << std::endl;
236-
} // end ik
246+
#ifdef __MPI
247+
MPI_Barrier(MPI_COMM_WORLD);
248+
#endif
249+
ofs.seekp(0, std::ios::end);
250+
}
237251
return;
238252
}
239253

source/module_elecstate/module_charge/charge_init.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ void Charge::init_rho(elecstate::efermi& eferm_iout,
248248
const K_Vectors* kv = reinterpret_cast<const K_Vectors*>(klist);
249249
const int nkstot = kv->get_nkstot();
250250
const std::vector<int>& isk = kv->isk;
251-
ModuleIO::read_wfc_to_rho(pw_wfc, symm, nkstot, isk, *this);
251+
ModuleIO::read_wfc_to_rho(pw_wfc, symm, kv->ik2iktot.data(), nkstot, isk, *this);
252252
}
253253
}
254254

source/module_elecstate/module_dm/test/test_cal_dm_R.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,6 @@
66
#include "module_hamilt_lcao/module_hcontainer/hcontainer.h"
77
#include "module_cell/klist.h"
88

9-
K_Vectors::K_Vectors()
10-
{
11-
}
12-
13-
K_Vectors::~K_Vectors()
14-
{
15-
}
169
/************************************************
1710
* unit test of DensityMatrix constructor
1811
***********************************************/

source/module_elecstate/module_dm/test/test_dm_R_init.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,6 @@
77
#include "module_hamilt_lcao/module_hcontainer/hcontainer.h"
88
#include "module_cell/klist.h"
99
#undef private
10-
K_Vectors::K_Vectors()
11-
{
12-
}
13-
14-
K_Vectors::~K_Vectors()
15-
{
16-
}
1710
/************************************************
1811
* unit test of DensityMatrix constructor
1912
***********************************************/

source/module_elecstate/module_dm/test/test_dm_constructor.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,6 @@
55
#include "module_elecstate/module_dm/density_matrix.h"
66
#include "module_hamilt_lcao/module_hcontainer/hcontainer.h"
77
#include "module_cell/klist.h"
8-
K_Vectors::K_Vectors()
9-
{
10-
}
11-
K_Vectors::~K_Vectors()
12-
{
13-
}
148
/************************************************
159
* unit test of DensityMatrix constructor
1610
***********************************************/

source/module_elecstate/module_dm/test/test_dm_io.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,6 @@ Magnetism::~Magnetism()
3333
}
3434

3535
#include "module_cell/klist.h"
36-
37-
K_Vectors::K_Vectors()
38-
{
39-
}
40-
41-
K_Vectors::~K_Vectors()
42-
{
43-
}
44-
4536
#include "module_cell/module_neighbor/sltk_grid_driver.h"
4637
// mock find_atom() function
4738
void Grid_Driver::Find_atom(const UnitCell& ucell,

0 commit comments

Comments
 (0)