Skip to content

Commit 501580a

Browse files
committed
add the mpi_flag_
1 parent 82bf06f commit 501580a

File tree

11 files changed

+111
-124
lines changed

11 files changed

+111
-124
lines changed

source/module_basis/module_pw/pw_basis.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class PW_Basis
6161
PW_Basis(std::string device_, std::string precision_);
6262
virtual ~PW_Basis();
6363
//Init mpi parameters
64+
void set_mpi(const bool mpi_flag_in);
6465
#ifdef __MPI
6566
void initmpi(
6667
const int poolnproc_in, // Number of processors in this pool
@@ -436,6 +437,7 @@ class PW_Basis
436437

437438
std::string device = "cpu"; ///< cpu or gpu
438439
std::string precision = "double"; ///< single, double, mixing
440+
bool mpi_flag_ = true; ///< ture,is use mpi or not
439441
bool double_data_ = true; ///< if has double data
440442
bool float_data_ = false; ///< if has float data
441443
};

source/module_basis/module_pw/pw_basis_big.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,12 @@ class PW_Basis_Big : public PW_Basis_Sup
166166
ibox[0] = 2*n1+1;
167167
ibox[1] = 2*n2+1;
168168
ibox[2] = 2*n3+1;
169+
if (mpi_flag_)
170+
{
169171
#ifdef __MPI
170172
MPI_Allreduce(MPI_IN_PLACE, ibox, 3, MPI_INT, MPI_MAX , this->pool_world);
171173
#endif
172-
174+
}
173175

174176
// Find the minimal FFT box size the factors into the primes (2,3,5,7).
175177
for (int i = 0; i < 3; i++)
@@ -350,9 +352,12 @@ class PW_Basis_Big : public PW_Basis_Sup
350352
}
351353
}
352354
}
355+
if (mpi_flag_)
356+
{
353357
#ifdef __MPI
354358
MPI_Allreduce(MPI_IN_PLACE, &this->gridecut_lat, 1, MPI_DOUBLE, MPI_MIN , this->pool_world);
355359
#endif
360+
}
356361
this->gridecut_lat -= 1e-6;
357362

358363
delete[] ibox;

source/module_basis/module_pw/pw_basis_k_big.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ class PW_Basis_K_Big: public PW_Basis_K
5656
for(int ip = 0 ; ip < this->poolnproc ; ++ip)
5757
{
5858
this->numz[ip] = npbz*this->bz;
59-
if(ip < modbz) { this->numz[ip]+=this->bz;
60-
}
59+
if(ip < modbz) { this->numz[ip]+=this->bz;}
6160
if(ip < this->poolnproc - 1) this->startz[ip+1] = this->startz[ip] + numz[ip];
6261
if(ip == this->poolrank)
6362
{

source/module_basis/module_pw/pw_basis_sup.cpp

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -113,19 +113,25 @@ void PW_Basis_Sup::distribution_method3(const ModulePW::PW_Basis* pw_rho)
113113
// calculate this->nstot and this->npwtot, liy, riy
114114
this->count_pw_st(st_length2D, st_bottom2D);
115115
}
116+
if (mpi_flag_)
117+
{
116118
#ifdef __MPI
117-
MPI_Bcast(&this->npwtot, 1, MPI_INT, 0, this->pool_world);
118-
MPI_Bcast(&this->nstot, 1, MPI_INT, 0, this->pool_world);
119-
MPI_Bcast(&liy, 1, MPI_INT, 0, this->pool_world);
120-
MPI_Bcast(&riy, 1, MPI_INT, 0, this->pool_world);
121-
MPI_Bcast(&lix, 1, MPI_INT, 0, this->pool_world);
122-
MPI_Bcast(&rix, 1, MPI_INT, 0, this->pool_world);
119+
120+
MPI_Bcast(&this->npwtot, 1, MPI_INT, 0, this->pool_world);
121+
MPI_Bcast(&this->nstot, 1, MPI_INT, 0, this->pool_world);
122+
MPI_Bcast(&liy, 1, MPI_INT, 0, this->pool_world);
123+
MPI_Bcast(&riy, 1, MPI_INT, 0, this->pool_world);
124+
MPI_Bcast(&lix, 1, MPI_INT, 0, this->pool_world);
125+
MPI_Bcast(&rix, 1, MPI_INT, 0, this->pool_world);
123126
#endif
127+
}
124128
delete[] this->istot2ixy;
125129
this->istot2ixy = new int[this->nstot];
126130

127131
if (poolrank == 0)
128132
{
133+
if (mpi_flag_)
134+
{
129135
#ifdef __MPI
130136
// Parallel line
131137
// (2) Collect the x, y indexs, and length of the sticks.
@@ -147,7 +153,8 @@ void PW_Basis_Sup::distribution_method3(const ModulePW::PW_Basis* pw_rho)
147153
// We do not need startnsz_per after it.
148154
delete[] this->startnsz_per;
149155
this->startnsz_per = nullptr;
150-
#else
156+
#endif
157+
}else{
151158
// Serial line
152159
// get nst_per, npw_per, fftixy2ip, and istot2ixy
153160
this->nst_per[0] = this->nstot;
@@ -162,17 +169,20 @@ void PW_Basis_Sup::distribution_method3(const ModulePW::PW_Basis* pw_rho)
162169
st_move++;
163170
}
164171
}
165-
#endif
166172
}
167-
173+
}
174+
if (mpi_flag_)
175+
{
168176
#ifdef __MPI
169-
MPI_Bcast(st_length2D, this->fftnxy, MPI_INT, 0, this->pool_world);
170-
MPI_Bcast(st_bottom2D, this->fftnxy, MPI_INT, 0, this->pool_world);
171-
MPI_Bcast(this->fftixy2ip, this->fftnxy, MPI_INT, 0, this->pool_world);
172-
MPI_Bcast(this->istot2ixy, this->nstot, MPI_INT, 0, this->pool_world);
173-
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0, this->pool_world);
174-
MPI_Bcast(this->npw_per, this->poolnproc, MPI_INT, 0, this->pool_world);
177+
178+
MPI_Bcast(st_length2D, this->fftnxy, MPI_INT, 0, this->pool_world);
179+
MPI_Bcast(st_bottom2D, this->fftnxy, MPI_INT, 0, this->pool_world);
180+
MPI_Bcast(this->fftixy2ip, this->fftnxy, MPI_INT, 0, this->pool_world);
181+
MPI_Bcast(this->istot2ixy, this->nstot, MPI_INT, 0, this->pool_world);
182+
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0, this->pool_world);
183+
MPI_Bcast(this->npw_per, this->poolnproc, MPI_INT, 0, this->pool_world);
175184
#endif
185+
}
176186
this->npw = this->npw_per[this->poolrank];
177187
this->nst = this->nst_per[this->poolrank];
178188
this->nstnz = this->nst * this->nz;

source/module_basis/module_pw/pw_distributeg_method1.cpp

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,23 @@ void PW_Basis::distribution_method1()
4646

4747
this->count_pw_st(st_length2D, st_bottom2D);
4848
}
49+
if (mpi_flag_)
50+
{
4951
#ifdef __MPI
50-
MPI_Bcast(&this->npwtot, 1, MPI_INT, 0, this->pool_world);
51-
MPI_Bcast(&this->nstot, 1, MPI_INT, 0, this->pool_world);
52-
MPI_Bcast(&liy, 1, MPI_INT, 0, this->pool_world);
53-
MPI_Bcast(&riy, 1, MPI_INT, 0, this->pool_world);
54-
MPI_Bcast(&lix, 1, MPI_INT, 0, this->pool_world);
55-
MPI_Bcast(&rix, 1, MPI_INT, 0, this->pool_world);
52+
MPI_Bcast(&this->npwtot, 1, MPI_INT, 0, this->pool_world);
53+
MPI_Bcast(&this->nstot, 1, MPI_INT, 0, this->pool_world);
54+
MPI_Bcast(&liy, 1, MPI_INT, 0, this->pool_world);
55+
MPI_Bcast(&riy, 1, MPI_INT, 0, this->pool_world);
56+
MPI_Bcast(&lix, 1, MPI_INT, 0, this->pool_world);
57+
MPI_Bcast(&rix, 1, MPI_INT, 0, this->pool_world);
5658
#endif
59+
}
5760
delete[] this->istot2ixy; this->istot2ixy = new int[this->nstot];
5861

5962
if(poolrank == 0)
6063
{
64+
if (mpi_flag_)
65+
{
6166
#ifdef __MPI
6267
// Parallel line
6368
// (2) Collect the x, y indexs, and length of the sticks.
@@ -78,8 +83,9 @@ void PW_Basis::distribution_method1()
7883
delete[] st_j;
7984
//We do not need startnsz_per after it.
8085
delete[] this->startnsz_per;
81-
this->startnsz_per=nullptr;
82-
#else
86+
this->startnsz_per=nullptr;
87+
#endif
88+
}else{
8389
// Serial line
8490
// get nst_per, npw_per, fftixy2ip, and istot2ixy
8591
this->nst_per[0] = this->nstot;
@@ -94,17 +100,19 @@ void PW_Basis::distribution_method1()
94100
st_move++;
95101
}
96102
}
97-
#endif
98103
}
99-
104+
}
105+
if (mpi_flag_)
106+
{
100107
#ifdef __MPI
101-
MPI_Bcast(st_length2D, this->fftnxy, MPI_INT, 0, this->pool_world);
102-
MPI_Bcast(st_bottom2D, this->fftnxy, MPI_INT, 0, this->pool_world);
103-
MPI_Bcast(this->fftixy2ip, this->fftnxy, MPI_INT, 0, this->pool_world);
104-
MPI_Bcast(this->istot2ixy, this->nstot, MPI_INT, 0, this->pool_world);
105-
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0 , this->pool_world);
106-
MPI_Bcast(this->npw_per, this->poolnproc, MPI_INT, 0 , this->pool_world);
108+
MPI_Bcast(st_length2D, this->fftnxy, MPI_INT, 0, this->pool_world);
109+
MPI_Bcast(st_bottom2D, this->fftnxy, MPI_INT, 0, this->pool_world);
110+
MPI_Bcast(this->fftixy2ip, this->fftnxy, MPI_INT, 0, this->pool_world);
111+
MPI_Bcast(this->istot2ixy, this->nstot, MPI_INT, 0, this->pool_world);
112+
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0 , this->pool_world);
113+
MPI_Bcast(this->npw_per, this->poolnproc, MPI_INT, 0 , this->pool_world);
107114
#endif
115+
}
108116
this->npw = this->npw_per[this->poolrank];
109117
this->nst = this->nst_per[this->poolrank];
110118
this->nstnz = this->nst * this->nz;

source/module_basis/module_pw/pw_distributeg_method2.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ void PW_Basis::distribution_method2()
4646
*/
4747
this->count_pw_st(st_length2D, st_bottom2D);
4848
}
49+
if (mpi_flag_)
50+
{
4951
#ifdef __MPI
5052
MPI_Bcast(&this->npwtot, 1, MPI_INT, 0, this->pool_world);
5153
MPI_Bcast(&this->nstot, 1, MPI_INT, 0, this->pool_world);
@@ -54,11 +56,15 @@ void PW_Basis::distribution_method2()
5456
MPI_Bcast(&lix, 1, MPI_INT, 0, this->pool_world);
5557
MPI_Bcast(&rix, 1, MPI_INT, 0, this->pool_world);
5658
#endif
59+
}
5760
delete[] this->istot2ixy; this->istot2ixy = new int[this->nstot];
5861

5962
if(poolrank == 0)
6063
{
64+
if (mpi_flag_)
65+
{
6166
#ifdef __MPI
67+
6268
// Parallel line
6369
// (2) Devide the sticks to each core, sticks are in the order of ixy increasing.
6470
// get nst_per and startnsz_per
@@ -71,7 +77,8 @@ void PW_Basis::distribution_method2()
7177
//We do not need startnsz_per after it.
7278
delete[] this->startnsz_per;
7379
this->startnsz_per=nullptr;
74-
#else
80+
#endif
81+
}else{
7582
// Serial line
7683
// get nst_per, npw_per, fftixy2ip, and istot2ixy
7784
this->nst_per[0] = this->nstot;
@@ -86,15 +93,18 @@ void PW_Basis::distribution_method2()
8693
st_move++;
8794
}
8895
}
89-
#endif
96+
}
9097
}
9198
#ifdef __MPI
92-
MPI_Bcast(st_length2D, this->fftnxy, MPI_INT, 0, this->pool_world);
93-
MPI_Bcast(st_bottom2D, this->fftnxy, MPI_INT, 0, this->pool_world);
94-
MPI_Bcast(this->fftixy2ip, this->fftnxy, MPI_INT, 0, this->pool_world);
95-
MPI_Bcast(this->istot2ixy, this->nstot, MPI_INT, 0, this->pool_world);
96-
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0 , this->pool_world);
97-
MPI_Bcast(this->npw_per, this->poolnproc, MPI_INT, 0 , this->pool_world);
99+
if (mpi_flag_)
100+
{
101+
MPI_Bcast(st_length2D, this->fftnxy, MPI_INT, 0, this->pool_world);
102+
MPI_Bcast(st_bottom2D, this->fftnxy, MPI_INT, 0, this->pool_world);
103+
MPI_Bcast(this->fftixy2ip, this->fftnxy, MPI_INT, 0, this->pool_world);
104+
MPI_Bcast(this->istot2ixy, this->nstot, MPI_INT, 0, this->pool_world);
105+
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0 , this->pool_world);
106+
MPI_Bcast(this->npw_per, this->poolnproc, MPI_INT, 0 , this->pool_world);
107+
}
98108
#endif
99109
this->npw = this->npw_per[this->poolrank];
100110
this->nst = this->nst_per[this->poolrank];

source/module_basis/module_pw/pw_gatherscatter.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
3636
return;
3737
}
3838
#ifdef __MPI
39+
if (mpi_flag_)
40+
{
3941
//change (nplane fftnxy) to (nplane,nstot)
4042
// Hence, we can send them at one time.
4143
#ifdef _OPENMP
@@ -83,7 +85,7 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
8385
}
8486
}
8587
}
86-
88+
}
8789
#endif
8890
//ModuleBase::timer::tick(this->classname, "gatherp_scatters");
8991
return;
@@ -129,6 +131,8 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
129131
return;
130132
}
131133
#ifdef __MPI
134+
if (mpi_flag_)
135+
{
132136
// change (nz,ns) to (numz[ip],ns, poolnproc)
133137
// Hence, we can send them at one time.
134138
#ifdef _OPENMP
@@ -183,7 +187,7 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
183187
outp[iz] = inp[iz];
184188
}
185189
}
186-
190+
}
187191
#endif
188192
//ModuleBase::timer::tick(this->classname, "gathers_scatterp");
189193
return;

source/module_basis/module_pw/pw_init.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,25 @@ namespace ModulePW
66
#ifdef __MPI
77
void PW_Basis:: initmpi(
88
const int poolnproc_in,
9-
const int poolrank_in,
10-
MPI_Comm pool_world_in
9+
const int poolrank_in,
10+
MPI_Comm pool_world_in
1111
)
1212
{
13-
this->poolnproc = poolnproc_in;
14-
this->poolrank = poolrank_in;
15-
this->pool_world = pool_world_in;
13+
if (mpi_flag_)
14+
{
15+
this->poolnproc = poolnproc_in;
16+
this->poolrank = poolrank_in;
17+
this->pool_world = pool_world_in;
18+
}else
19+
{
20+
ModuleBase::WARNING_QUIT("PW_Basis","to use MPI_ FFT, please set the mpi_flag as true");
21+
}
1622
}
1723
#endif
18-
24+
void PW_Basis::set_mpi(const bool mpi_flag_in)
25+
{
26+
this->mpi_flag_ = mpi_flag_in;
27+
}
1928
///
2029
/// Init the grids for FFT
2130
/// Input: lattice vectors of the cell, Energy cut off for G^2/2
@@ -87,7 +96,10 @@ void PW_Basis:: initgrids(
8796
ibox[1] = 2*n2+1;
8897
ibox[2] = 2*n3+1;
8998
#ifdef __MPI
90-
MPI_Allreduce(MPI_IN_PLACE, ibox, 3, MPI_INT, MPI_MAX , this->pool_world);
99+
if (mpi_flag_)
100+
{
101+
MPI_Allreduce(MPI_IN_PLACE, ibox, 3, MPI_INT, MPI_MAX , this->pool_world);
102+
}
91103
#endif
92104

93105
// Find the minimal FFT box size the factors into the primes (2,3,5,7).
@@ -201,7 +213,10 @@ void PW_Basis:: initgrids(
201213
}
202214
}
203215
#ifdef __MPI
204-
MPI_Allreduce(MPI_IN_PLACE, &this->gridecut_lat, 1, MPI_DOUBLE, MPI_MIN , this->pool_world);
216+
if (mpi_flag_)
217+
{
218+
MPI_Allreduce(MPI_IN_PLACE, &this->gridecut_lat, 1, MPI_DOUBLE, MPI_MIN , this->pool_world);
219+
}
205220
#endif
206221
this->gridecut_lat -= 1e-6;
207222

0 commit comments

Comments
 (0)