Skip to content

Commit 5f8f70e

Browse files
Fix: Tests
1 parent 8493b5e commit 5f8f70e

File tree

5 files changed

+159
-152
lines changed

5 files changed

+159
-152
lines changed

source/module_cell/k_vector_utils.cpp

Lines changed: 130 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99

1010
#include <module_base/formatter.h>
1111
#include <module_parameter/parameter.h>
12+
#include <module_base/parallel_common.h>
13+
#include <module_base/parallel_reduce.h>
1214

1315
namespace KVectorUtils
1416
{
15-
void k_vec_d2c(K_Vectors& kv, const ModuleBase::Matrix3& reciprocal_vec)
17+
void kvec_d2c(K_Vectors& kv, const ModuleBase::Matrix3& reciprocal_vec)
1618
{
1719
// throw std::runtime_error("k_vec_d2c: This function is not implemented in the new codebase. Please use the new implementation.");
1820
if (kv.kvec_d.size() != kv.kvec_c.size())
@@ -56,7 +58,7 @@ void k_vec_d2c(K_Vectors& kv, const ModuleBase::Matrix3& reciprocal_vec)
5658
}
5759
}
5860
}
59-
void k_vec_c2d(K_Vectors& kv, const ModuleBase::Matrix3& latvec)
61+
void kvec_c2d(K_Vectors& kv, const ModuleBase::Matrix3& latvec)
6062
{
6163
if (kv.kvec_d.size() != kv.kvec_c.size())
6264
{
@@ -108,14 +110,14 @@ void set_both_kvec(K_Vectors& kv, const ModuleBase::Matrix3& G, const ModuleBase
108110
// set cartesian k vectors.
109111
if (!kv.kc_done && kv.kd_done)
110112
{
111-
KVectorUtils::k_vec_d2c(kv, G);
113+
KVectorUtils::kvec_d2c(kv, G);
112114
kv.kc_done = true;
113115
}
114116

115117
// set direct k vectors
116118
else if (kv.kc_done && !kv.kd_done)
117119
{
118-
KVectorUtils::k_vec_c2d(kv, R);
120+
KVectorUtils::kvec_c2d(kv, R);
119121
kv.kd_done = true;
120122
}
121123
std::string table;
@@ -150,7 +152,7 @@ void set_after_vc(K_Vectors& kv, const int& nspin_in, const ModuleBase::Matrix3&
150152
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "nspin", kv.get_nspin());
151153

152154
// set cartesian k vectors.
153-
KVectorUtils::k_vec_d2c(kv, reciprocal_vec);
155+
KVectorUtils::kvec_d2c(kv, reciprocal_vec);
154156

155157
std::string table;
156158
table += "K-POINTS DIRECT COORDINATES\n";
@@ -213,4 +215,127 @@ void print_klists(const K_Vectors& kv, std::ofstream& ofs)
213215
GlobalV::ofs_running << "\n" << table << std::endl;
214216
return;
215217
}
218+
219+
#ifdef __MPI
220+
void kvec_mpi_k(K_Vectors& kv)
221+
{
222+
ModuleBase::TITLE("KVectorUtils", "kvec_mpi_k");
223+
224+
Parallel_Common::bcast_bool(kv.kc_done);
225+
226+
Parallel_Common::bcast_bool(kv.kd_done);
227+
228+
Parallel_Common::bcast_int(kv.nspin);
229+
230+
Parallel_Common::bcast_int(kv.nkstot);
231+
232+
Parallel_Common::bcast_int(kv.nkstot_full);
233+
234+
Parallel_Common::bcast_int(kv.nmp, 3);
235+
236+
kv.kl_segids.resize(kv.nkstot);
237+
Parallel_Common::bcast_int(kv.kl_segids.data(), kv.nkstot);
238+
239+
Parallel_Common::bcast_double(kv.koffset, 3);
240+
241+
kv.nks = kv.para_k.nks_pool[GlobalV::MY_POOL];
242+
243+
GlobalV::ofs_running << std::endl;
244+
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "k-point number in this process", kv.nks);
245+
int nks_minimum = kv.nks;
246+
247+
Parallel_Reduce::gather_min_int_all(GlobalV::NPROC, nks_minimum);
248+
249+
if (nks_minimum == 0)
250+
{
251+
ModuleBase::WARNING_QUIT("K_Vectors::mpi_k()", " nks == 0, some processor have no k point!");
252+
}
253+
else
254+
{
255+
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "minimum distributed K point number", nks_minimum);
256+
}
257+
258+
std::vector<int> isk_aux(kv.nkstot);
259+
std::vector<double> wk_aux(kv.nkstot);
260+
std::vector<double> kvec_c_aux(kv.nkstot * 3);
261+
std::vector<double> kvec_d_aux(kv.nkstot * 3);
262+
263+
// collect and process in rank 0
264+
if (GlobalV::MY_RANK == 0)
265+
{
266+
for (int ik = 0; ik < kv.nkstot; ik++)
267+
{
268+
isk_aux[ik] = kv.isk[ik];
269+
wk_aux[ik] = kv.wk[ik];
270+
kvec_c_aux[3 * ik] = kv.kvec_c[ik].x;
271+
kvec_c_aux[3 * ik + 1] = kv.kvec_c[ik].y;
272+
kvec_c_aux[3 * ik + 2] = kv.kvec_c[ik].z;
273+
kvec_d_aux[3 * ik] = kv.kvec_d[ik].x;
274+
kvec_d_aux[3 * ik + 1] = kv.kvec_d[ik].y;
275+
kvec_d_aux[3 * ik + 2] = kv.kvec_d[ik].z;
276+
}
277+
}
278+
279+
// broadcast k point data to all processors
280+
Parallel_Common::bcast_int(isk_aux.data(), kv.nkstot);
281+
282+
Parallel_Common::bcast_double(wk_aux.data(), kv.nkstot);
283+
Parallel_Common::bcast_double(kvec_c_aux.data(), kv.nkstot * 3);
284+
Parallel_Common::bcast_double(kvec_d_aux.data(), kv.nkstot * 3);
285+
286+
// process k point data in each processor
287+
kv.renew(kv.nks * kv.nspin);
288+
289+
// distribute
290+
int k_index = 0;
291+
292+
for (int i = 0; i < kv.nks; i++)
293+
{
294+
// 3 is because each k point has three value:kx, ky, kz
295+
k_index = i + kv.para_k.startk_pool[GlobalV::MY_POOL];
296+
kv.kvec_c[i].x = kvec_c_aux[k_index * 3];
297+
kv.kvec_c[i].y = kvec_c_aux[k_index * 3 + 1];
298+
kv.kvec_c[i].z = kvec_c_aux[k_index * 3 + 2];
299+
kv.kvec_d[i].x = kvec_d_aux[k_index * 3];
300+
kv.kvec_d[i].y = kvec_d_aux[k_index * 3 + 1];
301+
kv.kvec_d[i].z = kvec_d_aux[k_index * 3 + 2];
302+
kv.wk[i] = wk_aux[k_index];
303+
kv.isk[i] = isk_aux[k_index];
304+
}
305+
306+
#ifdef __EXX
307+
if (ModuleSymmetry::Symmetry::symm_flag == 1)
308+
{//bcast kstars
309+
kv.kstars.resize(kv.nkstot);
310+
for (int ikibz = 0;ikibz < kv.nkstot;++ikibz)
311+
{
312+
int starsize = kv.kstars[ikibz].size();
313+
Parallel_Common::bcast_int(starsize);
314+
GlobalV::ofs_running << "starsize: " << starsize << std::endl;
315+
auto ks = kv.kstars[ikibz].begin();
316+
for (int ik = 0;ik < starsize;++ik)
317+
{
318+
int isym = 0;
319+
ModuleBase::Vector3<double> ks_vec(0, 0, 0);
320+
if (GlobalV::MY_RANK == 0)
321+
{
322+
isym = ks->first;
323+
ks_vec = ks->second;
324+
++ks;
325+
}
326+
Parallel_Common::bcast_int(isym);
327+
Parallel_Common::bcast_double(ks_vec.x);
328+
Parallel_Common::bcast_double(ks_vec.y);
329+
Parallel_Common::bcast_double(ks_vec.z);
330+
GlobalV::ofs_running << "isym: " << isym << " ks_vec: " << ks_vec.x << " " << ks_vec.y << " " << ks_vec.z << std::endl;
331+
if (GlobalV::MY_RANK != 0)
332+
{
333+
kv.kstars[ikibz].insert(std::make_pair(isym, ks_vec));
334+
}
335+
}
336+
}
337+
}
338+
#endif
339+
} // END SUBROUTINE
340+
#endif
216341
} // namespace KVectorUtils

source/module_cell/k_vector_utils.h

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ class K_Vectors;
1212

1313
namespace KVectorUtils
1414
{
15-
void k_vec_d2c(K_Vectors& kv, const ModuleBase::Matrix3& reciprocal_vec);
15+
void kvec_d2c(K_Vectors& kv, const ModuleBase::Matrix3& reciprocal_vec);
1616

17-
void k_vec_c2d(K_Vectors& kv, const ModuleBase::Matrix3& latvec);
17+
void kvec_c2d(K_Vectors& kv, const ModuleBase::Matrix3& latvec);
1818

1919
/**
2020
* @brief Sets both the direct and Cartesian k-vectors.
@@ -77,6 +77,29 @@ void set_after_vc(K_Vectors& kv, const int& nspin, const ModuleBase::Matrix3& G)
7777
* @note The function uses the FmtCore::format function to format the output.
7878
*/
7979
void print_klists(const K_Vectors& kv, std::ofstream& ofs);
80+
81+
// step 3 : mpi kpoints information.
82+
83+
/**
84+
* @brief Distributes k-points among MPI processes.
85+
*
86+
* This function distributes the k-points among the MPI processes. Each process gets a subset of the k-points to
87+
* work on. The function also broadcasts various variables related to the k-points to all processes.
88+
*
89+
* @param kv The K_Vectors object containing the k-point information.
90+
*
91+
* @return void
92+
*
93+
* @note This function is only compiled and used if MPI is enabled.
94+
* @note The function assumes that the number of k-points (nkstot) is greater than 0.
95+
* @note The function broadcasts the flags kc_done and kd_done, the number of spins (nspin), the total number of
96+
* k-points (nkstot), the full number of k-points (nkstot_full), the Monkhorst-Pack grid (nmp), the k-point offsets
97+
* (koffset), and the segment IDs of the k-points (kl_segids).
98+
* @note The function also broadcasts the indices of the k-points (isk), their weights (wk), and their Cartesian and
99+
* direct coordinates (kvec_c and kvec_d).
100+
* @note If a process has no k-points to work on, the function will quit with an error message.
101+
*/
102+
void kvec_mpi_k(K_Vectors& kv);
80103
} // namespace KVectorUtils
81104

82105
#endif // K_VECTOR_UTILS_H

source/module_cell/klist.cpp

Lines changed: 1 addition & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ void K_Vectors::set(const UnitCell& ucell,
154154
nspin_in); // assign k points to several process pools
155155
#ifdef __MPI
156156
// distribute K point data to the corresponding process
157-
this->mpi_k(); // 2008-4-29
157+
KVectorUtils::kvec_mpi_k(*this);
158158
#endif
159159

160160
// set the k vectors for the up and down spin
@@ -1027,129 +1027,6 @@ void K_Vectors::normalize_wk(const int& degspin)
10271027
return;
10281028
}
10291029

1030-
#ifdef __MPI
1031-
void K_Vectors::mpi_k()
1032-
{
1033-
ModuleBase::TITLE("K_Vectors", "mpi_k");
1034-
1035-
Parallel_Common::bcast_bool(kc_done);
1036-
1037-
Parallel_Common::bcast_bool(kd_done);
1038-
1039-
Parallel_Common::bcast_int(nspin);
1040-
1041-
Parallel_Common::bcast_int(nkstot);
1042-
1043-
Parallel_Common::bcast_int(nkstot_full);
1044-
1045-
Parallel_Common::bcast_int(nmp, 3);
1046-
1047-
kl_segids.resize(nkstot);
1048-
Parallel_Common::bcast_int(kl_segids.data(), nkstot);
1049-
1050-
Parallel_Common::bcast_double(koffset, 3);
1051-
1052-
this->nks = this->para_k.nks_pool[GlobalV::MY_POOL];
1053-
1054-
GlobalV::ofs_running << std::endl;
1055-
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "k-point number in this process", nks);
1056-
int nks_minimum = this->nks;
1057-
1058-
Parallel_Reduce::gather_min_int_all(GlobalV::NPROC, nks_minimum);
1059-
1060-
if (nks_minimum == 0)
1061-
{
1062-
ModuleBase::WARNING_QUIT("K_Vectors::mpi_k()", " nks == 0, some processor have no k point!");
1063-
}
1064-
else
1065-
{
1066-
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "minimum distributed K point number", nks_minimum);
1067-
}
1068-
1069-
std::vector<int> isk_aux(nkstot);
1070-
std::vector<double> wk_aux(nkstot);
1071-
std::vector<double> kvec_c_aux(nkstot * 3);
1072-
std::vector<double> kvec_d_aux(nkstot * 3);
1073-
1074-
// collect and process in rank 0
1075-
if (GlobalV::MY_RANK == 0)
1076-
{
1077-
for (int ik = 0; ik < nkstot; ik++)
1078-
{
1079-
isk_aux[ik] = isk[ik];
1080-
wk_aux[ik] = wk[ik];
1081-
kvec_c_aux[3 * ik] = kvec_c[ik].x;
1082-
kvec_c_aux[3 * ik + 1] = kvec_c[ik].y;
1083-
kvec_c_aux[3 * ik + 2] = kvec_c[ik].z;
1084-
kvec_d_aux[3 * ik] = kvec_d[ik].x;
1085-
kvec_d_aux[3 * ik + 1] = kvec_d[ik].y;
1086-
kvec_d_aux[3 * ik + 2] = kvec_d[ik].z;
1087-
}
1088-
}
1089-
1090-
// broadcast k point data to all processors
1091-
Parallel_Common::bcast_int(isk_aux.data(), nkstot);
1092-
1093-
Parallel_Common::bcast_double(wk_aux.data(), nkstot);
1094-
Parallel_Common::bcast_double(kvec_c_aux.data(), nkstot * 3);
1095-
Parallel_Common::bcast_double(kvec_d_aux.data(), nkstot * 3);
1096-
1097-
// process k point data in each processor
1098-
this->renew(this->nks * this->nspin);
1099-
1100-
// distribute
1101-
int k_index = 0;
1102-
1103-
for (int i = 0; i < nks; i++)
1104-
{
1105-
// 3 is because each k point has three value:kx, ky, kz
1106-
k_index = i + this->para_k.startk_pool[GlobalV::MY_POOL];
1107-
kvec_c[i].x = kvec_c_aux[k_index * 3];
1108-
kvec_c[i].y = kvec_c_aux[k_index * 3 + 1];
1109-
kvec_c[i].z = kvec_c_aux[k_index * 3 + 2];
1110-
kvec_d[i].x = kvec_d_aux[k_index * 3];
1111-
kvec_d[i].y = kvec_d_aux[k_index * 3 + 1];
1112-
kvec_d[i].z = kvec_d_aux[k_index * 3 + 2];
1113-
wk[i] = wk_aux[k_index];
1114-
isk[i] = isk_aux[k_index];
1115-
}
1116-
1117-
#ifdef __EXX
1118-
if (ModuleSymmetry::Symmetry::symm_flag == 1)
1119-
{//bcast kstars
1120-
this->kstars.resize(nkstot);
1121-
for (int ikibz = 0;ikibz < nkstot;++ikibz)
1122-
{
1123-
int starsize = this->kstars[ikibz].size();
1124-
Parallel_Common::bcast_int(starsize);
1125-
GlobalV::ofs_running << "starsize: " << starsize << std::endl;
1126-
auto ks = this->kstars[ikibz].begin();
1127-
for (int ik = 0;ik < starsize;++ik)
1128-
{
1129-
int isym = 0;
1130-
ModuleBase::Vector3<double> ks_vec(0, 0, 0);
1131-
if (GlobalV::MY_RANK == 0)
1132-
{
1133-
isym = ks->first;
1134-
ks_vec = ks->second;
1135-
++ks;
1136-
}
1137-
Parallel_Common::bcast_int(isym);
1138-
Parallel_Common::bcast_double(ks_vec.x);
1139-
Parallel_Common::bcast_double(ks_vec.y);
1140-
Parallel_Common::bcast_double(ks_vec.z);
1141-
GlobalV::ofs_running << "isym: " << isym << " ks_vec: " << ks_vec.x << " " << ks_vec.y << " " << ks_vec.z << std::endl;
1142-
if (GlobalV::MY_RANK != 0)
1143-
{
1144-
kstars[ikibz].insert(std::make_pair(isym, ks_vec));
1145-
}
1146-
}
1147-
}
1148-
}
1149-
#endif
1150-
} // END SUBROUTINE
1151-
#endif
1152-
11531030
//----------------------------------------------------------
11541031
// This routine sets the k vectors for the up and down spin
11551032
//----------------------------------------------------------

source/module_cell/klist.h

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -299,26 +299,7 @@ class K_Vectors
299299
*/
300300
void normalize_wk(const int& degspin);
301301

302-
// step 3 : mpi kpoints information.
303302

304-
/**
305-
* @brief Distributes k-points among MPI processes.
306-
*
307-
* This function distributes the k-points among the MPI processes. Each process gets a subset of the k-points to
308-
* work on. The function also broadcasts various variables related to the k-points to all processes.
309-
*
310-
* @return void
311-
*
312-
* @note This function is only compiled and used if MPI is enabled.
313-
* @note The function assumes that the number of k-points (nkstot) is greater than 0.
314-
* @note The function broadcasts the flags kc_done and kd_done, the number of spins (nspin), the total number of
315-
* k-points (nkstot), the full number of k-points (nkstot_full), the Monkhorst-Pack grid (nmp), the k-point offsets
316-
* (koffset), and the segment IDs of the k-points (kl_segids).
317-
* @note The function also broadcasts the indices of the k-points (isk), their weights (wk), and their Cartesian and
318-
* direct coordinates (kvec_c and kvec_d).
319-
* @note If a process has no k-points to work on, the function will quit with an error message.
320-
*/
321-
void mpi_k();
322303

323304
// step 4 : *2 kpoints.
324305

@@ -349,5 +330,6 @@ class K_Vectors
349330
*/
350331
void cal_ik_global();
351332

333+
friend void KVectorUtils::kvec_mpi_k(K_Vectors& kvec);
352334
};
353335
#endif // KVECT_H

0 commit comments

Comments
 (0)