Skip to content

Commit db2350b

Browse files
committed
< prof >
Rewrite gatterscatter function. Now it has higher parallel efficiency.
1 parent 767f3ce commit db2350b

File tree

8 files changed

+204
-112
lines changed

8 files changed

+204
-112
lines changed

source/module_base/global_function.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ void DONE(std::ofstream &ofs,const std::string &description, bool only_rank0 = f
116116
template<class T, class TI>
117117
inline void ZEROS(std::complex<T> *u,const TI n) // Peize Lin change int to TI at 2020.03.03
118118
{
119-
if(n <= 0) return;
119+
assert(n>=0);
120120
for (TI i=0;i<n;i++)
121121
{
122122
u[i] = std::complex<T>(0.0,0.0);

source/module_pw/fft.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ FFT::~FFT()
3737
#endif
3838
}
3939

40-
void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in, bool gamma_only_in, bool mpifft_in)
40+
void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in, int nproc_in, bool gamma_only_in, bool mpifft_in)
4141
{
4242
this->gamma_only = gamma_only_in;
4343
this->nx = nx_in;
@@ -47,14 +47,16 @@ void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in,
4747
this->nz = nz_in;
4848
this->ns = ns_in;
4949
this->nplane = nplane_in;
50+
this->nproc = nproc_in;
5051
this->mpifft = mpifft_in;
5152
this->nxy = this->nx * this-> ny;
5253
this->bignxy = this->nx * this->bigny;
5354
this->maxgrids = (this->nz * this->ns > this->bignxy * nplane) ? this->nz * this->ns : this->bignxy * nplane;
5455
if(!this->mpifft)
5556
{
5657
//It seems in-place fft is faster than out-of-place fft
57-
c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
58+
if(this->nproc == 1) c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
59+
else c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
5860
//c_gspace2 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
5961
if(this->gamma_only)
6062
{
@@ -69,8 +71,8 @@ void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in,
6971
}
7072
else
7173
{
72-
c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
73-
//c_rspace2 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
74+
if(this->nproc == 1) c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
75+
else c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
7476
}
7577
#ifdef __MIX_PRECISION
7678
cf_gspace = (std::complex<float> *)fftw_malloc(sizeof(fftwf_complex) * this->nz * this->ns);

source/module_pw/fft.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class FFT
2727

2828
FFT();
2929
~FFT();
30-
void initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in, bool gamma_only_in, bool mpifft_in = false);
30+
void initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in, int nproc_in, bool gamma_only_in, bool mpifft_in = false);
3131
void setupFFT();
3232
void cleanFFT();
3333

@@ -52,6 +52,7 @@ class FFT
5252
int ns; //number of sticks
5353
int nplane; //number of x-y planes
5454
int maxgrids; // max between nz * ns and bignxy * nplane
55+
int nproc; // number of proc.
5556
std::complex<double> * c_gspace, *c_gspace2; //complex number space for g, [ns * nz]
5657
std::complex<double> * c_rspace, *c_rspace2;//complex number space for r, [nplane * nx *ny]
5758
double *r_rspace; //real number space for r, [nplane * nx *ny]

source/module_pw/pw_basis.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,16 @@ PW_Basis::PW_Basis()
1414
ixy2ip = NULL;
1515
startnsz_per = NULL;
1616
nstnz_per = NULL;
17+
nst_per = NULL;
1718
gdirect = NULL;
1819
gcar = NULL;
1920
gg = NULL;
2021
startz = NULL;
2122
numz = NULL;
23+
this->numg = NULL;
24+
this->startg = NULL;
25+
this->startr = NULL;
26+
this->numr = NULL;
2227
poolnproc = 1;
2328
poolrank = 0;
2429
}
@@ -32,21 +37,62 @@ PW_Basis:: ~PW_Basis()
3237
if(ixy2ip != NULL) delete[] ixy2ip;
3338
if(startnsz_per != NULL) delete[] startnsz_per;
3439
if(nstnz_per != NULL) delete[] nstnz_per;
40+
if(nst_per != NULL) delete[] nst_per;
3541
if(gdirect != NULL) delete[] gdirect;
3642
if(gcar != NULL) delete[] gcar;
3743
if(gg != NULL) delete[] gg;
3844
if(startz != NULL) delete[] startz;
3945
if(numz != NULL) delete[] numz;
46+
if(numg != NULL) delete[] numg;
47+
if(numr != NULL) delete[] numr;
48+
if(startg != NULL) delete[] startg;
49+
if(startr != NULL) delete[] startr;
4050
}
4151

4252
void PW_Basis::setuptransform()
4353
{
4454
this->distribute_r();
4555
this->distribute_g();
46-
this->ft.initfft(this->nx,this->bigny,this->nz,this->nst,this->nplane,this->gamma_only);
56+
this->getstartgr();
57+
this->ft.initfft(this->nx,this->bigny,this->nz,this->nst,this->nplane,this->poolnproc,this->gamma_only);
4758
this->ft.setupFFT();
4859
}
4960

61+
void PW_Basis::getstartgr()
62+
{
63+
this->maxgrids = (this->nz * this->nst > this->bignxy * nplane) ? this->nz * this->nst : this->bignxy * nplane;
64+
65+
//---------------------------------------------
66+
// sum : starting plane of FFT box.
67+
//---------------------------------------------
68+
this->numg = new int[poolnproc];
69+
this->startg = new int[poolnproc];
70+
this->startr = new int[poolnproc];
71+
this->numr = new int[poolnproc];
72+
73+
// Each processor has a set of full sticks,
74+
// 'rank_use' processor send a piece(npps[ip]) of these sticks(nst_per[rank_use])
75+
// to all the other processors in this pool
76+
for (int ip = 0;ip < poolnproc; ++ip) this->numg[ip] = this->nst_per[poolrank] * this->numz[ip];
77+
78+
79+
// Each processor in a pool send a piece of each stick(nst_per[ip]) to
80+
// other processors in this pool
81+
// rank_use processor receive datas in npps[rank_p] planes.
82+
for (int ip = 0;ip < poolnproc; ++ip) this->numr[ip] = this->nst_per[ip] * this->numz[poolrank];
83+
84+
85+
// startg record the starting 'numg' position in each processor.
86+
this->startg[0] = 0;
87+
for (int ip = 1;ip < poolnproc; ++ip) this->startg[ip] = this->startg[ip-1] + this->numg[ip-1];
88+
89+
90+
// startr record the starting 'numr' position
91+
this->startr[0] = 0;
92+
for (int ip = 1;ip < poolnproc; ++ip) this->startr[ip] = this->startr[ip-1] + this->numr[ip-1];
93+
return;
94+
}
95+
5096
//
5197
// Collect planewaves on current core, and construct gg, gdirect, gcar according to ig2isz and is2ixy.
5298
// is2ixy contains the x-coordinate and y-coordinate of sticks on current core.

source/module_pw/pw_basis.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class PW_Basis
5555
// only on first proc.
5656
int *startnsz_per; // startnsz_per[ip]: starting is * nz stick in the ip^th proc.
5757
int *nstnz_per; // nz * nst(number of sticks) on each core.
58+
int *nst_per;// nst on each core
5859
// on all proc.
5960
int *ig2isz; // map ig to (is, iz).
6061
int *istot2ixy; // istot2ixy[is]: ix + iy * nx of is^th stick among all sticks.
@@ -69,6 +70,10 @@ class PW_Basis
6970
int nrxx; //num. of real space grids
7071
int *startz; //startz[ip]: starting z plane in the ip-th proc. in current POOL_WORLD
7172
int *numz; //numz[ip]: num. of z planes in the ip-th proc. in current POOL_WORLD
73+
int *numg; //numg[ip] : nst_per[poolrank] * numz[ip]
74+
int *numr; //numr[ip] : numz[poolrank] * nst_per[ip]
75+
int *startg; // startg[ip] = numg[ip-1] + startg[ip-1]
76+
int *startr; // startr[ip] = numr[ip-1] + startr[ip-1]
7277
int nplane; //num. of planes in current proc.
7378

7479
ModuleBase::Vector3<double> *gdirect; //(= *G1d) ; // ig = new Vector igc[ngmc]
@@ -86,6 +91,8 @@ class PW_Basis
8691
//distribute real-space grids to different processors
8792
void distribute_r();
8893

94+
void getstartgr();
95+
8996
//distribute plane waves to different processors
9097
void distribution_method1(); // x varies fast
9198
void distribution_method2(); // sticks sorted according to ixy
@@ -135,8 +142,7 @@ class PW_Basis
135142
int* st_i, // x or x + nx (if x < 0) of stick.
136143
int* st_j, // y or y + ny (if y < 0) of stick.
137144
int* st_length, // the stick on (x, y) consists of st_length[x*ny+y] planewaves.
138-
int* npw_per, // number of planewaves on each core.
139-
int* nst_per // number of sticks on each core.
145+
int* npw_per // number of planewaves on each core.
140146
);
141147
void get_istot2ixy(
142148
int* st_i, // x or x + nx (if x < 0) of stick.
@@ -180,6 +186,7 @@ class PW_Basis
180186
// FFT dimensions for wave functions.
181187
int nx, ny, nz, nxyz, nxy;
182188
int bigny, bignxyz, bignxy; // Gamma_only: ny = int(bigny/2)-1 , others: ny = bigny
189+
int maxgrids; // max between nz * ns and bignxy * nplane
183190
FFT ft;
184191

185192
void real2recip(double * in, std::complex<double> * out); //in:(nplane,nx*ny) ; out(nz, ns)
@@ -189,6 +196,8 @@ class PW_Basis
189196

190197
void gatherp_scatters(std::complex<double> *in, std::complex<double> *out); //gather planes and scatter sticks of all processors
191198
void gathers_scatterp(std::complex<double> *in, std::complex<double> *out); //gather sticks of and scatter planes of all processors
199+
void gathers_scatterp2(std::complex<double> *in, std::complex<double> *out); //gather sticks of and scatter planes of all processors
200+
void gatherp_scatters2(std::complex<double> *in, std::complex<double> *out); //gather sticks of and scatter planes of all processors
192201
void gatherp_scatters_gamma(std::complex<double> *in, std::complex<double> *out); //gather planes and scatter sticks of all processors, used when gamma_only
193202
void gathers_scatterp_gamma(std::complex<double> *in, std::complex<double> *out); //gather sticks of and scatter planes of all processors, used when gamma only
194203

source/module_pw/pw_distributeg_method1.cpp

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ void PW_Basis::distribution_method1()
3030
this->nstot = 0; // total number of sticks.
3131
int *st_bottom2D = NULL; // st_bottom2D[ixy], minimum z of stick on (x, y).
3232
int *st_length2D = NULL; // st_length2D[ixy], number of planewaves in stick on (x, y).
33-
33+
this->nst_per = new int[this->poolnproc]; // number of sticks on each core.
3434
if (poolrank == 0)
3535
{
3636
// (1) Count the total number of planewaves (tot_npw) and sticks (this->nstot).
@@ -73,11 +73,10 @@ void PW_Basis::distribution_method1()
7373

7474
// (3) Distribute sticks to cores.
7575
int *npw_per = new int[this->poolnproc]; // number of planewaves on each core.
76-
int *nst_per = new int[this->poolnproc]; // number of sticks on each core.
7776
this->nstnz_per = new int[this->poolnproc]; // nz * nst(number of sticks) on each core.
7877
this->startnsz_per = new int[this->poolnproc];
7978
ModuleBase::GlobalFunc::ZEROS(npw_per, poolnproc);
80-
ModuleBase::GlobalFunc::ZEROS(nst_per, poolnproc);
79+
ModuleBase::GlobalFunc::ZEROS(this->nst_per, poolnproc);
8180
ModuleBase::GlobalFunc::ZEROS(this->nstnz_per, poolnproc);
8281
ModuleBase::GlobalFunc::ZEROS(startnsz_per, poolnproc);
8382

@@ -86,7 +85,7 @@ void PW_Basis::distribution_method1()
8685
{
8786
this->ixy2ip[ixy] = -1; // meaning this stick has not been distributed or there is no stick on (x, y).
8887
}
89-
this->divide_sticks(st_i, st_j, st_length, npw_per, nst_per);
88+
this->divide_sticks(st_i, st_j, st_length, npw_per);
9089
delete[] st_length;
9190

9291
// for test -----------------------------------------------------------------------------
@@ -113,7 +112,6 @@ void PW_Basis::distribution_method1()
113112
MPI_Send(&nst_per[ip], 1, MPI_INT, ip, 0, POOL_WORLD);
114113
}
115114
delete[] npw_per;
116-
delete[] nst_per;
117115

118116
#else
119117
// Serial line
@@ -172,6 +170,7 @@ void PW_Basis::distribution_method1()
172170
MPI_Bcast(this->ixy2ip, this->nxy, MPI_INT, 0, POOL_WORLD);
173171
MPI_Bcast(this->istot2ixy, this->nstot, MPI_INT, 0, POOL_WORLD);
174172
MPI_Bcast(this->ixy2istot, this->nxy, MPI_INT, 0, POOL_WORLD);
173+
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0 , POOL_WORLD);
175174

176175
std::cout << "Bcast done\n";
177176
#endif
@@ -304,8 +303,7 @@ void PW_Basis::divide_sticks(
304303
int* st_i, // x or x + nx (if x < 0) of stick.
305304
int* st_j, // y or y + ny (if y < 0) of stick.
306305
int* st_length, // the stick on (x, y) consists of st_length[x*ny+y] planewaves.
307-
int* npw_per, // number of planewaves on each core.
308-
int* nst_per // number of sticks on each core.
306+
int* npw_per // number of planewaves on each core.
309307
)
310308
{
311309
int ipmin = 0; // The ip of core containing least number of planewaves.
@@ -322,25 +320,25 @@ void PW_Basis::divide_sticks(
322320

323321
if (npw_ip == 0)
324322
{
325-
if (non_zero_grid + nz < this->nrxx) // assert reciprocal planewaves is less than real space planewaves.
326-
{
323+
// if (non_zero_grid + nz < this->nrxx) // assert reciprocal planewaves is less than real space planewaves.
324+
// {
327325
ipmin = ip;
328326
break;
329-
}
327+
// }
330328
}
331329
else if (npw_ip < npwmin)
332330
{
333-
if (non_zero_grid + nz < this->nrxx) // assert reciprocal planewaves is less than real space planewaves.
334-
{
331+
// if (non_zero_grid + nz < this->nrxx) // assert reciprocal planewaves is less than real space planewaves.
332+
// {
335333
ipmin = ip;
336-
}
334+
// }
337335
}
338336
else if (npw_ip == npwmin && nst_ip < nstmin)
339337
{
340-
if (non_zero_grid + nz < this->nrxx)
341-
{
338+
// if (non_zero_grid + nz < this->nrxx)
339+
// {
342340
ipmin = ip;
343-
}
341+
// }
344342
}
345343
}
346344
nst_per[ipmin]++;

0 commit comments

Comments
 (0)