Skip to content

Commit 3d8788c

Browse files
authored
Merge pull request #544 from Qianruipku/planewave
optimize efficiency
2 parents 85e2c05 + 7a61632 commit 3d8788c

File tree

14 files changed

+85
-156
lines changed

14 files changed

+85
-156
lines changed

source/module_pw/fft.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ FFT::~FFT()
3636
#endif
3737
}
3838

39-
void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in, int nproc_in, bool gamma_only_in, bool mpifft_in)
39+
void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int lix_in, int rix_in, int ns_in, int nplane_in, int nproc_in, bool gamma_only_in, bool mpifft_in)
4040
{
4141
this->gamma_only = gamma_only_in;
4242
this->nx = nx_in;
@@ -45,6 +45,8 @@ void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in,
4545
else this->ny = this->bigny;
4646
this->nz = nz_in;
4747
this->ns = ns_in;
48+
this->lix = lix_in;
49+
this->rix = rix_in;
4850
this->nplane = nplane_in;
4951
this->nproc = nproc_in;
5052
this->mpifft = mpifft_in;
@@ -330,7 +332,11 @@ void FFT::fftxyfor(std::complex<double>* & in, std::complex<double>* & out)
330332
{
331333
int npy = this->nplane * this-> ny;
332334
fftw_execute_dft( this->planxfor, (fftw_complex *)in, (fftw_complex *)out);
333-
for (int i=0; i<this->nx;++i)
335+
for (int i=0; i<=this->lix;++i)
336+
{
337+
fftw_execute_dft( this->planyfor, (fftw_complex*)&in[i*npy], (fftw_complex*)&out[i*npy] );
338+
}
339+
for (int i=this->rix; i<this->nx;++i)
334340
{
335341
fftw_execute_dft( this->planyfor, (fftw_complex*)&in[i*npy], (fftw_complex*)&out[i*npy] );
336342
}
@@ -341,7 +347,11 @@ void FFT::fftxybac(std::complex<double>* & in, std::complex<double>* & out)
341347
{
342348
int npy = this->nplane * this-> ny;
343349

344-
for (int i=0; i<this->nx;++i)
350+
for (int i=0; i<=this->lix;++i)
351+
{
352+
fftw_execute_dft( this->planybac, (fftw_complex*)&in[i*npy], (fftw_complex*)&out[i*npy] );
353+
}
354+
for (int i=this->rix; i<this->nx;++i)
345355
{
346356
fftw_execute_dft( this->planybac, (fftw_complex*)&in[i*npy], (fftw_complex*)&out[i*npy] );
347357
}

source/module_pw/fft.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ class FFT
2727

2828
FFT();
2929
~FFT();
30-
void initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in, int nproc_in, bool gamma_only_in, bool mpifft_in = false);
30+
void initfft(int nx_in, int bigny_in, int nz_in, int lix_in, int rix_in, int ns_in, int nplane_in,
31+
int nproc_in, bool gamma_only_in, bool mpifft_in = false);
3132
void setupFFT();
3233
void cleanFFT();
3334

@@ -55,6 +56,7 @@ class FFT
5556
int nxy;
5657
int bigny;
5758
int bignxy;
59+
int lix,rix;// lix: the left edge of the pw ball; rix: the right edge of the pw ball
5860
int ns; //number of sticks
5961
int nplane; //number of x-y planes
6062
int maxgrids; // max between nz * ns and bignxy * nplane

source/module_pw/pw_basis.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace ModulePW
88
PW_Basis::PW_Basis()
99
{
1010
ig2isz = NULL;
11-
istot2ixy = NULL;
11+
istot2bigixy = NULL;
1212
ixy2istot = NULL;
1313
is2ixy = NULL;
1414
ixy2ip = NULL;
@@ -31,7 +31,7 @@ PW_Basis::PW_Basis()
3131
PW_Basis:: ~PW_Basis()
3232
{
3333
if(ig2isz != NULL) delete[] ig2isz;
34-
if(istot2ixy != NULL) delete[] istot2ixy;
34+
if(istot2bigixy != NULL) delete[] istot2bigixy;
3535
if(ixy2istot != NULL) delete[] ixy2istot;
3636
if(is2ixy != NULL) delete[] is2ixy;
3737
if(ixy2ip != NULL) delete[] ixy2ip;
@@ -54,7 +54,7 @@ void PW_Basis::setuptransform()
5454
this->distribute_r();
5555
this->distribute_g();
5656
this->getstartgr();
57-
this->ft.initfft(this->nx,this->bigny,this->nz,this->nst,this->nplane,this->poolnproc,this->gamma_only);
57+
this->ft.initfft(this->nx,this->bigny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only);
5858
this->ft.setupFFT();
5959
}
6060

source/module_pw/pw_basis.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class PW_Basis
5858
int *nst_per;// nst on each core
5959
// on all proc.
6060
int *ig2isz; // map ig to (is, iz).
61-
int *istot2ixy; // istot2ixy[is]: ix + iy * nx of is^th stick among all sticks.
61+
int *istot2bigixy; // istot2bigixy[is]: iy + ix * bigny of is^th stick among all sticks.
6262
int *ixy2istot; // ixy2istot[ix + iy * nx]: is of stick on (ix, iy) among all sticks.
6363
int *is2ixy; // is2ixy[is]: ix + iy * bignx of is^th stick among sticks on current proc.
6464
int *ixy2ip; // ixy2ip[ix + iy * nx]: ip of proc which contains stick on (ix, iy).
@@ -144,14 +144,12 @@ class PW_Basis
144144
int* st_length, // the stick on (x, y) consists of st_length[x*ny+y] planewaves.
145145
int* npw_per // number of planewaves on each core.
146146
);
147-
void get_istot2ixy(
147+
void get_istot2bigixy(
148148
int* st_i, // x or x + nx (if x < 0) of stick.
149149
int* st_j // y or y + ny (if y < 0) of stick.
150150
);
151151
// for distributeg_method2
152-
void divide_sticks2(
153-
int* nst_per // number of sticks on each core.
154-
);
152+
void divide_sticks2();
155153
void create_maps(
156154
int* st_length2D, // the number of planewaves that belong to the stick located on (x, y), stored in 2d x-y plane.
157155
int* npw_per // number of planewaves on each core.
@@ -186,6 +184,7 @@ class PW_Basis
186184
// FFT dimensions for wave functions.
187185
int nx, ny, nz, nxyz, nxy;
188186
int bigny, bignxyz, bignxy; // Gamma_only: ny = int(bigny/2)-1 , others: ny = bigny
187+
int lix,rix;// lix: the left edge of the pw ball; rix: the right edge of the pw ball
189188
int maxgrids; // max between nz * ns and bignxy * nplane
190189
FFT ft;
191190

@@ -196,8 +195,8 @@ class PW_Basis
196195

197196
void gatherp_scatters(std::complex<double> *in, std::complex<double> *out); //gather planes and scatter sticks of all processors
198197
void gathers_scatterp(std::complex<double> *in, std::complex<double> *out); //gather sticks of and scatter planes of all processors
199-
void gathers_scatterp2(std::complex<double> *in, std::complex<double> *out); //gather sticks of and scatter planes of all processors
200-
void gatherp_scatters2(std::complex<double> *in, std::complex<double> *out); //gather sticks of and scatter planes of all processors
198+
// void gathers_scatterp2(std::complex<double> *in, std::complex<double> *out); //gather sticks of and scatter planes of all processors
199+
// void gatherp_scatters2(std::complex<double> *in, std::complex<double> *out); //gather sticks of and scatter planes of all processors
201200
void gatherp_scatters_gamma(std::complex<double> *in, std::complex<double> *out); //gather planes and scatter sticks of all processors, used when gamma_only
202201
void gathers_scatterp_gamma(std::complex<double> *in, std::complex<double> *out); //gather sticks of and scatter planes of all processors, used when gamma only
203202

source/module_pw/pw_distributeg.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ void PW_Basis::count_pw_st(
5454
iy_start = 0;
5555
iy_end = this->ny - 1;
5656
}
57+
this->lix = this->rix = 0;
5758
for (int ix = -ibox[0]; ix <= ibox[0]; ++ix)
5859
{
5960
for (int iy = iy_start; iy <= iy_end; ++iy)
@@ -81,6 +82,8 @@ void PW_Basis::count_pw_st(
8182
if (length == 0) st_bottom2D[index] = iz; // length == 0 means this point is the bottom of stick (x, y).
8283
++tot_npw;
8384
++length;
85+
if(ix < this->rix) this->rix = ix;
86+
if(ix > this->lix) this->lix = ix;
8487
}
8588
}
8689
if (length > 0)
@@ -90,6 +93,8 @@ void PW_Basis::count_pw_st(
9093
}
9194
}
9295
}
96+
if(rix <= 0) rix += this->nx;
97+
std::cout<<"lix "<<lix<<" ; rix "<<rix<<std::endl;
9398
return;
9499
}
95100

source/module_pw/pw_distributeg_method1.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace ModulePW
2020
//
2121
// Secondly, distribute these sticks to coreors.
2222
//Known: G, GT, GGT, ny, nx, nz, poolnproc, poolrank, ggecut
23-
//output: ig2isz[ig], istot2ixy[is], ixy2istot[nxy], is2ixy[is], ixy2ip[ixy], startnsz_per[ip], nst_per[ip], nst
23+
//output: ig2isz[ig], istot2bigixy[is], ixy2istot[nxy], is2ixy[is], ixy2ip[ixy], startnsz_per[ip], nst_per[ip], nst
2424
//
2525
void PW_Basis::distribution_method1()
2626
{
@@ -95,7 +95,7 @@ void PW_Basis::distribution_method1()
9595
std::cout << std::endl;
9696
// --------------------------------------------------------------------------------------
9797

98-
this->get_istot2ixy(st_i, st_j);
98+
this->get_istot2bigixy(st_i, st_j);
9999
delete[] st_i;
100100
delete[] st_j;
101101
// for test -----------------------------------------------------------------------------
@@ -124,15 +124,15 @@ void PW_Basis::distribution_method1()
124124
this->startnsz_per[0] = 0;
125125

126126
this->ixy2istot = new int[nxy];
127-
this->istot2ixy = new int[this->nstot];
127+
this->istot2bigixy = new int[this->nstot];
128128
this->ixy2ip = new int[nxy]; // ip of core which contains stick on (x, y).
129129
int st_move = 0;
130130
for (int ixy = 0; ixy < nxy; ++ixy)
131131
{
132132
if (st_length2D[ixy] > 0)
133133
{
134134
this->ixy2istot[ixy] = st_move;
135-
this->istot2ixy[st_move] = ixy;
135+
this->istot2bigixy[st_move] = ixy / ny * bigny + ixy % ny;
136136
this->ixy2ip[ixy] = 0;
137137
st_move++;
138138
}
@@ -156,19 +156,21 @@ void PW_Basis::distribution_method1()
156156
#ifdef __MPI
157157
MPI_Bcast(&tot_npw, 1, MPI_INT, 0, POOL_WORLD);
158158
MPI_Bcast(&this->nstot, 1, MPI_INT, 0, POOL_WORLD);
159+
MPI_Bcast(&lix, 1, MPI_INT, 0, POOL_WORLD);
160+
MPI_Bcast(&rix, 1, MPI_INT, 0, POOL_WORLD);
159161
if (this->poolrank != 0)
160162
{
161163
st_bottom2D = new int[this->nxy]; // minimum z of stick.
162164
st_length2D = new int[this->nxy]; // number of planewaves in stick.
163165
this->ixy2ip = new int[this->nxy]; // ip of core which contains stick on (x, y).
164166
this->ixy2istot = new int[this->nxy];
165-
this->istot2ixy = new int[this->nstot];
167+
this->istot2bigixy = new int[this->nstot];
166168
}
167169

168170
MPI_Bcast(st_length2D, this->nxy, MPI_INT, 0, POOL_WORLD);
169171
MPI_Bcast(st_bottom2D, this->nxy, MPI_INT, 0, POOL_WORLD);
170172
MPI_Bcast(this->ixy2ip, this->nxy, MPI_INT, 0, POOL_WORLD);
171-
MPI_Bcast(this->istot2ixy, this->nstot, MPI_INT, 0, POOL_WORLD);
173+
MPI_Bcast(this->istot2bigixy, this->nstot, MPI_INT, 0, POOL_WORLD);
172174
MPI_Bcast(this->ixy2istot, this->nxy, MPI_INT, 0, POOL_WORLD);
173175
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0 , POOL_WORLD);
174176

@@ -265,7 +267,6 @@ void PW_Basis::collect_st(
265267
}
266268
}
267269
assert(is == this->nstot);
268-
269270
std::cout<<"collect sticks done\n";
270271

271272
// As we will distribute the longest sticks preferentially in Step(3), we rearrange st_* in the order of length decreasing.
@@ -362,33 +363,33 @@ void PW_Basis::divide_sticks(
362363
//
363364
// (3-2) Rearrange sticks in the order of the ip of core increasing, in each core, sticks are sorted in the order of ixy increasing.
364365
// (st_start + st_move) is the new index of sticks.
365-
// Then get istot2ixy (istot2ixy[is]: iy + ix * ny of is^th stick among all sticks) on the first core
366+
// Then get istot2bigixy (istot2bigixy[is]: iy + ix * ny of is^th stick among all sticks) on the first core
366367
// and ixy2istot (ixy2istot[iy + ix * ny]: is of stick on (iy, ix) among all sticks).
367368
// known: this->nstot, st_i, st_j, this->startnsz_per
368-
// output: istot2ixy, ixy2istot
369+
// output: istot2bigixy, ixy2istot
369370
//
370-
void PW_Basis::get_istot2ixy(
371+
void PW_Basis::get_istot2bigixy(
371372
int* st_i, // x or x + nx (if x < 0) of stick.
372373
int* st_j // y or y + ny (if y < 0) of stick.
373374
)
374375
{
375376
assert(this->poolrank == 0);
376377
this->ixy2istot = new int[this->nx * this->ny];
377-
this->istot2ixy = new int[this->nstot];
378+
this->istot2bigixy = new int[this->nstot];
378379
int* st_move = new int[this->poolnproc]; // st_move[ip]: this is the st_move^th stick on ip^th core.
379380
for (int ixy = 0; ixy < this->nx * this->ny; ++ixy)
380381
{
381382
this->ixy2istot[ixy] = -1;
382383
}
383-
ModuleBase::GlobalFunc::ZEROS(this->istot2ixy, this->nstot);
384+
ModuleBase::GlobalFunc::ZEROS(this->istot2bigixy, this->nstot);
384385
ModuleBase::GlobalFunc::ZEROS(st_move, this->poolnproc);
385386

386387
for (int ixy = 0; ixy < this->nxy; ++ixy)
387388
{
388389
int ip = this->ixy2ip[ixy];
389390
if (ip != -1)
390391
{
391-
this->istot2ixy[this->startnsz_per[ip] / this->nz + st_move[ip]] = ixy;
392+
this->istot2bigixy[this->startnsz_per[ip] / this->nz + st_move[ip]] = (ixy / ny)*bigny + ixy % ny;
392393
this->ixy2istot[ixy] = this->startnsz_per[ip] / this->nz + st_move[ip];
393394
st_move[ip]++;
394395
}

source/module_pw/pw_distributeg_method2.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace ModulePW
1919
// ip 0 1 ...
2020
//
2121
//Known: G, GT, GGT, ny, nx, nz, poolnproc, poolrank, ggecut
22-
//output: ig2isz[ig], istot2ixy[is], ixy2istot[nxy], is2ixy[is], ixy2ip[ixy], startnsz_per[ip], nst_per[ip], nst
22+
//output: ig2isz[ig], istot2bigixy[is], ixy2istot[nxy], is2ixy[is], ixy2ip[ixy], startnsz_per[ip], nst_per[ip], nst
2323
//
2424
void PW_Basis::distribution_method2()
2525
{
@@ -31,6 +31,7 @@ void PW_Basis::distribution_method2()
3131
int *st_bottom2D = NULL; // st_bottom2D[ixy], minimum z of stick on (x, y).
3232
int *st_length2D = NULL; // st_length2D[ixy], number of planewaves in stick on (x, y).
3333

34+
this->nst_per = new int[this->poolnproc]; // number of sticks on each core.
3435
if (poolrank == 0)
3536
{
3637
// (1) Count the total number of planewaves (tot_npw) and sticks (this->nstot).
@@ -60,9 +61,8 @@ void PW_Basis::distribution_method2()
6061

6162
// (2) Devide the sticks to each core, sticks are in the order of ixy increasing.
6263

63-
int *nst_per = new int[this->poolnproc]; // number of sticks on each core.
6464
ModuleBase::GlobalFunc::ZEROS(nst_per, this->poolnproc);
65-
this->divide_sticks2(nst_per);
65+
this->divide_sticks2();
6666
// for test ----------------------------------------------------------------------------
6767
std::cout << "nst_per ";
6868
for (int ip = 0; ip < this->poolnproc; ++ip) std::cout << nst_per[ip] << std::setw(4);
@@ -92,7 +92,6 @@ void PW_Basis::distribution_method2()
9292
}
9393
#endif
9494
delete[] npw_per;
95-
delete[] nst_per;
9695
}
9796
else
9897
{
@@ -106,19 +105,22 @@ void PW_Basis::distribution_method2()
106105
#ifdef __MPI
107106
MPI_Bcast(&tot_npw, 1, MPI_INT, 0, POOL_WORLD);
108107
MPI_Bcast(&this->nstot, 1, MPI_INT, 0, POOL_WORLD);
108+
MPI_Bcast(&lix, 1, MPI_INT, 0, POOL_WORLD);
109+
MPI_Bcast(&rix, 1, MPI_INT, 0, POOL_WORLD);
110+
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0 , POOL_WORLD);
109111
if (this->poolrank != 0)
110112
{
111113
st_bottom2D = new int[this->nxy]; // minimum z of stick.
112114
st_length2D = new int[this->nxy]; // number of planewaves in stick.
113115
this->ixy2ip = new int[this->nxy]; // ip of core which contains stick on (x, y).
114116
this->ixy2istot = new int[this->nxy];
115-
this->istot2ixy = new int[this->nstot];
117+
this->istot2bigixy = new int[this->nstot];
116118
}
117119

118120
MPI_Bcast(st_length2D, this->nxy, MPI_INT, 0, POOL_WORLD);
119121
MPI_Bcast(st_bottom2D, this->nxy, MPI_INT, 0, POOL_WORLD);
120122
MPI_Bcast(this->ixy2ip, this->nxy, MPI_INT, 0, POOL_WORLD);
121-
MPI_Bcast(this->istot2ixy, this->nstot, MPI_INT, 0, POOL_WORLD);
123+
MPI_Bcast(this->istot2bigixy, this->nstot, MPI_INT, 0, POOL_WORLD);
122124
MPI_Bcast(this->ixy2istot, this->nxy, MPI_INT, 0, POOL_WORLD);
123125

124126
std::cout << "Bcast done\n";
@@ -142,9 +144,7 @@ void PW_Basis::distribution_method2()
142144
// known: this->nstot, this->poolnproc
143145
// output: nst_per, this->nstnz_per, this->startnsz_per
144146
//
145-
void PW_Basis::divide_sticks2(
146-
int* nst_per // number of sticks on each core.
147-
)
147+
void PW_Basis::divide_sticks2()
148148
{
149149
this->nstnz_per = new int[this->poolnproc]; // nz * nst(number of sticks) on each core.
150150
this->startnsz_per = new int[this->poolnproc];
@@ -165,7 +165,7 @@ void PW_Basis::divide_sticks2(
165165
//
166166
// (3) Create the maps from ixy to ip, istot, and from istot to ixy, and construt npw_per simultaneously.
167167
// known: st_length2D
168-
// output: this->ixy2ip, this->ixy2istot, this->istot2ixy, npw_per
168+
// output: this->ixy2ip, this->ixy2istot, this->istot2bigixy, npw_per
169169
//
170170
void PW_Basis::create_maps(
171171
int* st_length2D, // the number of planewaves that belong to the stick located on (x, y), stored in 2d x-y plane.
@@ -174,17 +174,17 @@ void PW_Basis::create_maps(
174174
{
175175
this->ixy2ip = new int[this->nxy];
176176
this->ixy2istot = new int[this->nxy];
177-
this->istot2ixy = new int[this->nstot];
177+
this->istot2bigixy = new int[this->nstot];
178178

179-
ModuleBase::GlobalFunc::ZEROS(this->istot2ixy, this->nstot);
179+
ModuleBase::GlobalFunc::ZEROS(this->istot2bigixy, this->nstot);
180180
int ip = 0;
181181
int st_move = 0; // the number of sticks that have been found.
182182
for (int ixy = 0; ixy < this->nxy; ++ixy)
183183
{
184184
if (st_length2D[ixy] > 0)
185185
{
186186
this->ixy2istot[ixy] = st_move;
187-
this->istot2ixy[st_move] = ixy;
187+
this->istot2bigixy[st_move] = ixy / ny * bigny + ixy % ny;
188188
this->ixy2ip[ixy] = ip;
189189
npw_per[ip] += st_length2D[ixy];
190190
st_move++;

0 commit comments

Comments
 (0)