Skip to content

Commit cedaae6

Browse files
committed
change the mpi_flag
1 parent e5f5a42 commit cedaae6

File tree

11 files changed

+74
-67
lines changed

11 files changed

+74
-67
lines changed

source/module_basis/module_pw/pw_basis.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,6 @@ class PW_Basis
436436

437437
std::string device = "cpu"; ///< cpu or gpu
438438
std::string precision = "double"; ///< single, double, mixing
439-
bool mpi_flag_ = false; ///< ture,is use mpi or not
440439
bool double_data_ = true; ///< if has double data
441440
bool float_data_ = false; ///< if has float data
442441
};

source/module_basis/module_pw/pw_basis_big.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,9 @@ class PW_Basis_Big : public PW_Basis_Sup
166166
ibox[0] = 2*n1+1;
167167
ibox[1] = 2*n2+1;
168168
ibox[2] = 2*n3+1;
169-
if (mpi_flag_)
170-
{
171169
#ifdef __MPI
172170
MPI_Allreduce(MPI_IN_PLACE, ibox, 3, MPI_INT, MPI_MAX , this->pool_world);
173171
#endif
174-
}
175172

176173
// Find the minimal FFT box size the factors into the primes (2,3,5,7).
177174
for (int i = 0; i < 3; i++)
@@ -352,12 +349,9 @@ class PW_Basis_Big : public PW_Basis_Sup
352349
}
353350
}
354351
}
355-
if (mpi_flag_)
356-
{
357352
#ifdef __MPI
358353
MPI_Allreduce(MPI_IN_PLACE, &this->gridecut_lat, 1, MPI_DOUBLE, MPI_MIN , this->pool_world);
359354
#endif
360-
}
361355
this->gridecut_lat -= 1e-6;
362356

363357
delete[] ibox;

source/module_basis/module_pw/pw_basis_sup.cpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,6 @@ void PW_Basis_Sup::distribution_method3(const ModulePW::PW_Basis* pw_rho)
113113
// calculate this->nstot and this->npwtot, liy, riy
114114
this->count_pw_st(st_length2D, st_bottom2D);
115115
}
116-
if (mpi_flag_)
117-
{
118116
#ifdef __MPI
119117

120118
MPI_Bcast(&this->npwtot, 1, MPI_INT, 0, this->pool_world);
@@ -124,14 +122,11 @@ void PW_Basis_Sup::distribution_method3(const ModulePW::PW_Basis* pw_rho)
124122
MPI_Bcast(&lix, 1, MPI_INT, 0, this->pool_world);
125123
MPI_Bcast(&rix, 1, MPI_INT, 0, this->pool_world);
126124
#endif
127-
}
128125
delete[] this->istot2ixy;
129126
this->istot2ixy = new int[this->nstot];
130127

131128
if (poolrank == 0)
132129
{
133-
if (mpi_flag_)
134-
{
135130
#ifdef __MPI
136131
// Parallel line
137132
// (2) Collect the x, y indexs, and length of the sticks.
@@ -153,8 +148,7 @@ void PW_Basis_Sup::distribution_method3(const ModulePW::PW_Basis* pw_rho)
153148
// We do not need startnsz_per after it.
154149
delete[] this->startnsz_per;
155150
this->startnsz_per = nullptr;
156-
#endif
157-
}else{
151+
#else
158152
// Serial line
159153
// get nst_per, npw_per, fftixy2ip, and istot2ixy
160154
this->nst_per[0] = this->nstot;
@@ -170,9 +164,8 @@ void PW_Basis_Sup::distribution_method3(const ModulePW::PW_Basis* pw_rho)
170164
}
171165
}
172166
}
167+
#endif
173168
}
174-
if (mpi_flag_)
175-
{
176169
#ifdef __MPI
177170

178171
MPI_Bcast(st_length2D, this->fftnxy, MPI_INT, 0, this->pool_world);
@@ -182,7 +175,6 @@ void PW_Basis_Sup::distribution_method3(const ModulePW::PW_Basis* pw_rho)
182175
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0, this->pool_world);
183176
MPI_Bcast(this->npw_per, this->poolnproc, MPI_INT, 0, this->pool_world);
184177
#endif
185-
}
186178
this->npw = this->npw_per[this->poolrank];
187179
this->nst = this->nst_per[this->poolrank];
188180
this->nstnz = this->nst * this->nz;

source/module_basis/module_pw/pw_distributeg_method1.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@ void PW_Basis::distribution_method1()
4646

4747
this->count_pw_st(st_length2D, st_bottom2D);
4848
}
49-
if (mpi_flag_)
50-
{
5149
#ifdef __MPI
5250
MPI_Bcast(&this->npwtot, 1, MPI_INT, 0, this->pool_world);
5351
MPI_Bcast(&this->nstot, 1, MPI_INT, 0, this->pool_world);
@@ -56,13 +54,11 @@ void PW_Basis::distribution_method1()
5654
MPI_Bcast(&lix, 1, MPI_INT, 0, this->pool_world);
5755
MPI_Bcast(&rix, 1, MPI_INT, 0, this->pool_world);
5856
#endif
59-
}
60-
delete[] this->istot2ixy; this->istot2ixy = new int[this->nstot];
57+
delete[] this->istot2ixy;
58+
this->istot2ixy = new int[this->nstot];
6159

6260
if(poolrank == 0)
6361
{
64-
if (mpi_flag_)
65-
{
6662
#ifdef __MPI
6763
// Parallel line
6864
// (2) Collect the x, y indexs, and length of the sticks.
@@ -84,8 +80,7 @@ void PW_Basis::distribution_method1()
8480
//We do not need startnsz_per after it.
8581
delete[] this->startnsz_per;
8682
this->startnsz_per=nullptr;
87-
#endif
88-
}else{
83+
#else
8984
// Serial line
9085
// get nst_per, npw_per, fftixy2ip, and istot2ixy
9186
this->nst_per[0] = this->nstot;
@@ -100,10 +95,8 @@ void PW_Basis::distribution_method1()
10095
st_move++;
10196
}
10297
}
98+
#endif
10399
}
104-
}
105-
if (mpi_flag_)
106-
{
107100
#ifdef __MPI
108101
MPI_Bcast(st_length2D, this->fftnxy, MPI_INT, 0, this->pool_world);
109102
MPI_Bcast(st_bottom2D, this->fftnxy, MPI_INT, 0, this->pool_world);
@@ -112,7 +105,6 @@ void PW_Basis::distribution_method1()
112105
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0 , this->pool_world);
113106
MPI_Bcast(this->npw_per, this->poolnproc, MPI_INT, 0 , this->pool_world);
114107
#endif
115-
}
116108
this->npw = this->npw_per[this->poolrank];
117109
this->nst = this->nst_per[this->poolrank];
118110
this->nstnz = this->nst * this->nz;

source/module_basis/module_pw/pw_distributeg_method2.cpp

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@ void PW_Basis::distribution_method2()
4646
*/
4747
this->count_pw_st(st_length2D, st_bottom2D);
4848
}
49-
if (mpi_flag_)
50-
{
5149
#ifdef __MPI
5250
MPI_Bcast(&this->npwtot, 1, MPI_INT, 0, this->pool_world);
5351
MPI_Bcast(&this->nstot, 1, MPI_INT, 0, this->pool_world);
@@ -56,13 +54,10 @@ void PW_Basis::distribution_method2()
5654
MPI_Bcast(&lix, 1, MPI_INT, 0, this->pool_world);
5755
MPI_Bcast(&rix, 1, MPI_INT, 0, this->pool_world);
5856
#endif
59-
}
6057
delete[] this->istot2ixy; this->istot2ixy = new int[this->nstot];
6158

6259
if(poolrank == 0)
6360
{
64-
if (mpi_flag_)
65-
{
6661
#ifdef __MPI
6762

6863
// Parallel line
@@ -77,8 +72,7 @@ void PW_Basis::distribution_method2()
7772
//We do not need startnsz_per after it.
7873
delete[] this->startnsz_per;
7974
this->startnsz_per=nullptr;
80-
#endif
81-
}else{
75+
#else
8276
// Serial line
8377
// get nst_per, npw_per, fftixy2ip, and istot2ixy
8478
this->nst_per[0] = this->nstot;
@@ -93,18 +87,15 @@ void PW_Basis::distribution_method2()
9387
st_move++;
9488
}
9589
}
96-
}
90+
#endif
9791
}
9892
#ifdef __MPI
99-
if (mpi_flag_)
100-
{
101-
MPI_Bcast(st_length2D, this->fftnxy, MPI_INT, 0, this->pool_world);
102-
MPI_Bcast(st_bottom2D, this->fftnxy, MPI_INT, 0, this->pool_world);
103-
MPI_Bcast(this->fftixy2ip, this->fftnxy, MPI_INT, 0, this->pool_world);
104-
MPI_Bcast(this->istot2ixy, this->nstot, MPI_INT, 0, this->pool_world);
105-
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0 , this->pool_world);
106-
MPI_Bcast(this->npw_per, this->poolnproc, MPI_INT, 0 , this->pool_world);
107-
}
93+
MPI_Bcast(st_length2D, this->fftnxy, MPI_INT, 0, this->pool_world);
94+
MPI_Bcast(st_bottom2D, this->fftnxy, MPI_INT, 0, this->pool_world);
95+
MPI_Bcast(this->fftixy2ip, this->fftnxy, MPI_INT, 0, this->pool_world);
96+
MPI_Bcast(this->istot2ixy, this->nstot, MPI_INT, 0, this->pool_world);
97+
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0, this->pool_world);
98+
MPI_Bcast(this->npw_per, this->poolnproc, MPI_INT, 0, this->pool_world);
10899
#endif
109100
this->npw = this->npw_per[this->poolrank];
110101
this->nst = this->nst_per[this->poolrank];

source/module_basis/module_pw/pw_gatherscatter.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
3636
return;
3737
}
3838
#ifdef __MPI
39-
if (mpi_flag_)
40-
{
4139
//change (nplane fftnxy) to (nplane,nstot)
4240
// Hence, we can send them at one time.
4341
#ifdef _OPENMP
@@ -85,7 +83,6 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
8583
}
8684
}
8785
}
88-
}
8986
#endif
9087
//ModuleBase::timer::tick(this->classname, "gatherp_scatters");
9188
return;
@@ -131,8 +128,6 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
131128
return;
132129
}
133130
#ifdef __MPI
134-
if (mpi_flag_)
135-
{
136131
// change (nz,ns) to (numz[ip],ns, poolnproc)
137132
// Hence, we can send them at one time.
138133
#ifdef _OPENMP
@@ -186,7 +181,6 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
186181
{
187182
outp[iz] = inp[iz];
188183
}
189-
}
190184
}
191185
#endif
192186
//ModuleBase::timer::tick(this->classname, "gathers_scatterp");

source/module_basis/module_pw/pw_init.cpp

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ void PW_Basis:: initmpi(
1313
this->poolnproc = poolnproc_in;
1414
this->poolrank = poolrank_in;
1515
this->pool_world = pool_world_in;
16-
mpi_flag_ = ((this->poolnproc>1) || !(this->pool_world == MPI_COMM_NULL));
1716
}
1817
#endif
1918
///
@@ -87,10 +86,7 @@ void PW_Basis:: initgrids(
8786
ibox[1] = 2*n2+1;
8887
ibox[2] = 2*n3+1;
8988
#ifdef __MPI
90-
if (mpi_flag_)
91-
{
92-
MPI_Allreduce(MPI_IN_PLACE, ibox, 3, MPI_INT, MPI_MAX , this->pool_world);
93-
}
89+
MPI_Allreduce(MPI_IN_PLACE, ibox, 3, MPI_INT, MPI_MAX , this->pool_world);
9490
#endif
9591

9692
// Find the minimal FFT box size the factors into the primes (2,3,5,7).
@@ -204,10 +200,7 @@ void PW_Basis:: initgrids(
204200
}
205201
}
206202
#ifdef __MPI
207-
if (mpi_flag_)
208-
{
209-
MPI_Allreduce(MPI_IN_PLACE, &this->gridecut_lat, 1, MPI_DOUBLE, MPI_MIN , this->pool_world);
210-
}
203+
MPI_Allreduce(MPI_IN_PLACE, &this->gridecut_lat, 1, MPI_DOUBLE, MPI_MIN , this->pool_world);
211204
#endif
212205
this->gridecut_lat -= 1e-6;
213206

source/module_basis/module_pw/test_gpu/pw_basis_C2C.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,13 @@ TEST_F(PWTEST, recip_to_real_C2C_double)
2424
bool xprime = false;
2525

2626
// init
27-
pwtest.initmpi(1,0,MPI_COMM_NULL);
27+
const int mypool = 0;
28+
const int key = 1;
29+
const int nproc_in_pool = 1;
30+
const int rank_in_pool = 0;
31+
MPI_Comm POOL_WORLD;
32+
MPI_Comm_split(MPI_COMM_WORLD,mypool,key,&POOL_WORLD);
33+
pwtest.initmpi(nproc_in_pool, rank_in_pool, POOL_WORLD);
2834
pwtest.initgrids(lat0, latvec, wfcecut);
2935
pwtest.initparameters(gamma_only, wfcecut, distribution_type, xprime);
3036
pwtest.setuptransform();
@@ -168,7 +174,14 @@ TEST_F(PWTEST, recip_to_real_C2C_float)
168174
int distribution_type = 1;
169175
bool xprime = false;
170176

171-
pwtest.initmpi(1,0,MPI_COMM_NULL);
177+
const int mypool = 0;
178+
const int key = 1;
179+
const int nproc_in_pool = 1;
180+
const int rank_in_pool = 0;
181+
MPI_Comm POOL_WORLD;
182+
MPI_Comm_split(MPI_COMM_WORLD,mypool,key,&POOL_WORLD);
183+
pwtest.initmpi(nproc_in_pool, rank_in_pool, POOL_WORLD);
184+
172185
pwtest.initgrids(lat0, latvec, wfcecut);
173186
pwtest.initparameters(gamma_only, wfcecut, distribution_type, xprime);
174187
pwtest.setuptransform();

source/module_basis/module_pw/test_gpu/pw_basis_C2R.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,14 @@ TEST_F(PWTEST, recip_to_real_double)
2727
bool xprime = false;
2828

2929
// init
30-
pwtest.initmpi(1,0,MPI_COMM_NULL);
30+
const int mypool = 0;
31+
const int key = 1;
32+
const int nproc_in_pool = 1;
33+
const int rank_in_pool = 0;
34+
MPI_Comm POOL_WORLD;
35+
MPI_Comm_split(MPI_COMM_WORLD,mypool,key,&POOL_WORLD);
36+
pwtest.initmpi(nproc_in_pool, rank_in_pool, POOL_WORLD);
37+
3138
pwtest.initgrids(lat0, latvec, wfcecut);
3239
pwtest.initparameters(gamma_only, wfcecut, distribution_type, xprime);
3340
pwtest.setuptransform();
@@ -171,7 +178,14 @@ TEST_F(PWTEST, recip_to_real_float)
171178
int distribution_type = 1;
172179
bool xprime = false;
173180

174-
pwtest.initmpi(1,0,MPI_COMM_NULL);
181+
const int mypool = 0;
182+
const int key = 1;
183+
const int nproc_in_pool = 1;
184+
const int rank_in_pool = 0;
185+
MPI_Comm POOL_WORLD;
186+
MPI_Comm_split(MPI_COMM_WORLD,mypool,key,&POOL_WORLD);
187+
pwtest.initmpi(nproc_in_pool, rank_in_pool, POOL_WORLD);
188+
175189
pwtest.initgrids(lat0, latvec, wfcecut);
176190
pwtest.initparameters(gamma_only, wfcecut, distribution_type, xprime);
177191
pwtest.setuptransform();

source/module_basis/module_pw/test_gpu/pw_basis_k_C2C.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,14 @@ TEST_F(PWTEST, pw_basis_k_C2C_double)
2525
bool xprime = false;
2626
//--------------------------------------------------
2727
// init //real parameter
28-
pwtest.initmpi(1,0,MPI_COMM_NULL);
28+
const int mypool = 0;
29+
const int key = 1;
30+
const int nproc_in_pool = 1;
31+
const int rank_in_pool = 0;
32+
MPI_Comm POOL_WORLD;
33+
MPI_Comm_split(MPI_COMM_WORLD,mypool,key,&POOL_WORLD);
34+
pwtest.initmpi(nproc_in_pool, rank_in_pool, POOL_WORLD);
35+
2936
pwtest.initgrids(lat0, latvec, 4 * wfcecut);
3037
pwtest.initparameters(gamma_only, wfcecut, nks, kvec_d, distribution_type, xprime);
3138
pwtest.setuptransform();
@@ -171,7 +178,14 @@ TEST_F(PWTEST, pw_basis_k_C2C_float)
171178
bool xprime = false;
172179
//--------------------------------------------------
173180
// init //real parameter
174-
pwtest.initmpi(1,0,MPI_COMM_NULL);
181+
const int mypool = 0;
182+
const int key = 1;
183+
const int nproc_in_pool = 1;
184+
const int rank_in_pool = 0;
185+
MPI_Comm POOL_WORLD;
186+
MPI_Comm_split(MPI_COMM_WORLD,mypool,key,&POOL_WORLD);
187+
pwtest.initmpi(nproc_in_pool, rank_in_pool, POOL_WORLD);
188+
175189
pwtest.initgrids(lat0, latvec, 4 * wfcecut);
176190
pwtest.initparameters(gamma_only, wfcecut, nks, kvec_d, distribution_type, xprime);
177191
pwtest.setuptransform();

0 commit comments

Comments
 (0)