Skip to content

Commit 1cd2739

Browse files
authored
Merge pull request #535 from Qianruipku/planewave
< prof >
2 parents e76341d + 767f3ce commit 1cd2739

File tree

10 files changed

+416
-55
lines changed

10 files changed

+416
-55
lines changed

source/module_base/global_function.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,7 @@ void DONE(std::ofstream &ofs,const std::string &description, bool only_rank0 = f
116116
template<class T, class TI>
117117
inline void ZEROS(std::complex<T> *u,const TI n) // Peize Lin change int to TI at 2020.03.03
118118
{
119-
assert(n>=0);
120-
assert(u!=0);
119+
if(n <= 0) return;
121120
for (TI i=0;i<n;i++)
122121
{
123122
u[i] = std::complex<T>(0.0,0.0);

source/module_pw/fft.cpp

Lines changed: 141 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ FFT::~FFT()
2727
this->cleanFFT();
2828
if(c_gspace!=NULL) fftw_free(c_gspace);
2929
if(c_rspace!=NULL) fftw_free(c_rspace);
30-
if(c_gspace2!=NULL) fftw_free(c_gspace2);
31-
if(c_rspace2!=NULL) fftw_free(c_rspace2);
3230
if(r_rspace!=NULL) fftw_free(r_rspace);
31+
// if(c_gspace2!=NULL) fftw_free(c_gspace2);
32+
// if(c_rspace2!=NULL) fftw_free(c_rspace2);
3333
#ifdef __MIX_PRECISION
3434
if(cf_gspace!=NULL) fftw_free(cf_gspace);
3535
if(cf_rspace!=NULL) fftw_free(cf_rspace);
@@ -50,19 +50,27 @@ void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in,
5050
this->mpifft = mpifft_in;
5151
this->nxy = this->nx * this-> ny;
5252
this->bignxy = this->nx * this->bigny;
53+
this->maxgrids = (this->nz * this->ns > this->bignxy * nplane) ? this->nz * this->ns : this->bignxy * nplane;
5354
if(!this->mpifft)
5455
{
55-
//out-of-place fft is faster than in-place fft
56+
//It seems in-place fft is faster than out-of-place fft
5657
c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
57-
c_gspace2 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
58-
c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
58+
//c_gspace2 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
5959
if(this->gamma_only)
6060
{
61+
c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
6162
r_rspace = (double *) fftw_malloc(sizeof(double) * this->bignxy * nplane);
63+
64+
//r2c in place : It seems in-place r2c/c2r is much slower than out-of-place
65+
// int padnxyp = this->ny * 2 * this->nx * this->nplane;
66+
// c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * padnxyp);
67+
// r_rspace = (double *) c_rspace;
68+
6269
}
6370
else
6471
{
65-
c_rspace2 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
72+
c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
73+
//c_rspace2 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
6674
}
6775
#ifdef __MIX_PRECISION
6876
cf_gspace = (std::complex<float> *)fftw_malloc(sizeof(fftwf_complex) * this->nz * this->ns);
@@ -111,13 +119,16 @@ void FFT :: initplan()
111119
// fftw_complex *in, const int *inembed, int istride, int idist,
112120
// fftw_complex *out, const int *onembed, int ostride, int odist, int sign, unsigned flags);
113121

114-
this->plan1for = fftw_plan_many_dft( 1, &this->nz, this->ns,
122+
this->planzfor = fftw_plan_many_dft( 1, &this->nz, this->ns,
115123
(fftw_complex*) c_gspace, &this->nz, 1, this->nz,
116-
(fftw_complex*) c_gspace2, &this->nz, 1, this->nz, FFTW_FORWARD, FFTW_MEASURE);
124+
(fftw_complex*) c_gspace, &this->nz, 1, this->nz, FFTW_FORWARD, FFTW_MEASURE);
117125

118-
this->plan1bac = fftw_plan_many_dft( 1, &this->nz, this->ns,
126+
this->planzbac = fftw_plan_many_dft( 1, &this->nz, this->ns,
119127
(fftw_complex*) c_gspace, &this->nz, 1, this->nz,
120-
(fftw_complex*) c_gspace2, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
128+
(fftw_complex*) c_gspace, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
129+
130+
// this->planzfor = fftw_plan_dft_1d(this->nz,(fftw_complex*) c_gspace,(fftw_complex*) c_gspace, FFTW_FORWARD, FFTW_MEASURE);
131+
// this->planzbac = fftw_plan_dft_1d(this->nz,(fftw_complex*) c_gspace,(fftw_complex*) c_gspace,FFTW_BACKWARD, FFTW_MEASURE);
121132

122133
//---------------------------------------------------------
123134
// 2 D
@@ -127,24 +138,63 @@ void FFT :: initplan()
127138
int *embed = NULL;
128139
if(this->gamma_only)
129140
{
130-
this->plan2r2c = fftw_plan_many_dft_r2c( 2, nrank, this->nplane,
131-
r_rspace, embed, this->nplane, 1,
132-
(fftw_complex*) c_rspace, embed, this->nplane, 1, FFTW_MEASURE);
141+
// int padnpy = this->nplane * this->ny * 2;
142+
// int rankc[2] = {this->nx, this->padnpy};
143+
// int rankd[2] = {this->nx, this->padnpy*2};
144+
// // It seems 1D+1D is much faster than 2D FFT!
145+
// this->plan2r2c = fftw_plan_many_dft_r2c( 2, nrank, this->nplane,
146+
// r_rspace, rankd, this->nplane, 1,
147+
// (fftw_complex*) c_rspace, rankc, this->nplane, 1, FFTW_MEASURE);
148+
149+
// this->plan2c2r = fftw_plan_many_dft_c2r( 2, nrank, this->nplane,
150+
// (fftw_complex*) c_rspace, rankc, this->nplane, 1,
151+
// r_rspace, rankd, this->nplane, 1, FFTW_MEASURE);
152+
153+
int npy = this->nplane * this->ny;
154+
int bignpy = this->nplane * this->bigny;
155+
this->planxfor = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)c_rspace, embed, bignpy, 1,
156+
(fftw_complex *)c_rspace, embed, bignpy, 1, FFTW_FORWARD, FFTW_MEASURE );
157+
this->planxbac = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)c_rspace, embed, bignpy, 1,
158+
(fftw_complex *)c_rspace, embed, bignpy, 1, FFTW_BACKWARD, FFTW_MEASURE );
159+
this->planyr2c = fftw_plan_many_dft_r2c( 1, &this->bigny, this->nplane, r_rspace , embed, this->nplane, 1,
160+
(fftw_complex*)c_rspace, embed, this->nplane, 1, FFTW_MEASURE );
161+
this->planyc2r = fftw_plan_many_dft_c2r( 1, &this->bigny, this->nplane, (fftw_complex*)c_rspace , embed, this->nplane, 1,
162+
r_rspace, embed, this->nplane, 1, FFTW_MEASURE );
163+
164+
// int padnpy = npy * 2;
165+
// this->planxfor = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)c_rspace, embed, padnpy, 1,
166+
// (fftw_complex *)c_rspace, embed, padnpy, 1, FFTW_FORWARD, FFTW_MEASURE );
167+
// this->planxbac = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)c_rspace, embed, padnpy, 1,
168+
// (fftw_complex *)c_rspace, embed, padnpy, 1, FFTW_BACKWARD, FFTW_MEASURE );
169+
// this->planyr2c = fftw_plan_many_dft_r2c( 1, &this->bigny, this->nplane, r_rspace , embed, this->nplane*2, 1,
170+
// (fftw_complex*)c_rspace, embed, this->nplane, 1, FFTW_MEASURE );
171+
// this->planyc2r = fftw_plan_many_dft_c2r( 1, &this->bigny, this->nplane, (fftw_complex*)c_rspace , embed, this->nplane, 1,
172+
// r_rspace, embed, this->nplane*2, 1, FFTW_MEASURE );
133173

134-
this->plan2c2r = fftw_plan_many_dft_c2r( 2, nrank, this->nplane,
135-
(fftw_complex*) c_rspace, embed, this->nplane, 1,
136-
r_rspace, embed, this->nplane, 1, FFTW_MEASURE);
137174
}
138175
else
139176
{
140-
this->plan2for = fftw_plan_many_dft( 2, nrank, this->nplane,
141-
(fftw_complex*) c_rspace, embed, this->nplane, 1,
142-
(fftw_complex*) c_rspace2, embed, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE);
177+
// It seems 1D+1D is much faster than 2D FFT!
178+
// this->plan2for = fftw_plan_many_dft( 2, nrank, this->nplane,
179+
// (fftw_complex*) c_rspace, embed, this->nplane, 1,
180+
// (fftw_complex*) c_rspace, embed, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE);
143181

144-
this->plan2bac = fftw_plan_many_dft( 2, nrank, this->nplane,
145-
(fftw_complex*) c_rspace, embed, this->nplane, 1,
146-
(fftw_complex*) c_rspace2, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE);
182+
// this->plan2bac = fftw_plan_many_dft( 2, nrank, this->nplane,
183+
// (fftw_complex*) c_rspace, embed, this->nplane, 1,
184+
// (fftw_complex*) c_rspace, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE);
185+
int npy = this->nplane * this->ny;
186+
this->planxfor = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)c_rspace, embed, npy, 1,
187+
(fftw_complex *)c_rspace, embed, npy, 1, FFTW_FORWARD, FFTW_MEASURE );
188+
this->planxbac = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)c_rspace, embed, npy, 1,
189+
(fftw_complex *)c_rspace, embed, npy, 1, FFTW_BACKWARD, FFTW_MEASURE );
190+
this->planyfor = fftw_plan_many_dft( 1, &this->ny, this->nplane, (fftw_complex*)c_rspace , embed, this->nplane, 1,
191+
(fftw_complex*)c_rspace, embed, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE );
192+
this->planybac = fftw_plan_many_dft( 1, &this->ny, this->nplane, (fftw_complex*)c_rspace , embed, this->nplane, 1,
193+
(fftw_complex*)c_rspace, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE );
147194
}
195+
196+
197+
148198
destroyp = false;
149199
}
150200

@@ -213,17 +263,19 @@ void FFT :: initplanf_mpi()
213263
void FFT:: cleanFFT()
214264
{
215265
if(destroyp==true) return;
216-
fftw_destroy_plan(plan1for);
217-
fftw_destroy_plan(plan1bac);
266+
fftw_destroy_plan(planzfor);
267+
fftw_destroy_plan(planzbac);
268+
fftw_destroy_plan(planxfor);
269+
fftw_destroy_plan(planxbac);
218270
if(this->gamma_only)
219271
{
220-
fftw_destroy_plan(plan2r2c);
221-
fftw_destroy_plan(plan2c2r);
272+
fftw_destroy_plan(planyr2c);
273+
fftw_destroy_plan(planyc2r);
222274
}
223275
else
224276
{
225-
fftw_destroy_plan(plan2for);
226-
fftw_destroy_plan(plan2bac);
277+
fftw_destroy_plan(planyfor);
278+
fftw_destroy_plan(planybac);
227279
}
228280
destroyp = true;
229281

@@ -250,17 +302,71 @@ void FFT:: cleanFFT()
250302
void FFT::executefftw(std::string instr)
251303
{
252304
if(instr == "1for")
253-
fftw_execute(this->plan1for);
254-
else if(instr == "2for")
255-
fftw_execute(this->plan2for);
305+
{
306+
// for(int i = 0 ; i < this->ns ; ++i)
307+
// {
308+
// fftw_execute_dft(this->planzfor,(fftw_complex *)&c_gspace[i*nz],(fftw_complex *)&c_gspace[i*nz]);
309+
// }
310+
fftw_execute_dft(this->planzfor,(fftw_complex *)c_gspace,(fftw_complex *)c_gspace);
311+
// fftw_execute(this->planzfor);
312+
}
256313
else if(instr == "1bac")
257-
fftw_execute(this->plan1bac);
314+
{
315+
// for(int i = 0 ; i < this->ns ; ++i)
316+
// {
317+
// fftw_execute_dft(this->planzbac,(fftw_complex *)&c_gspace[i*nz],(fftw_complex *)&c_gspace[i*nz]);
318+
// }
319+
fftw_execute_dft(this->planzbac,(fftw_complex *)c_gspace,(fftw_complex *)c_gspace);
320+
// fftw_execute(this->planzbac);
321+
}
322+
else if(instr == "2for")
323+
{
324+
int npy = this->nplane * this-> ny;
325+
fftw_execute_dft( this->planxfor, (fftw_complex *)c_rspace, (fftw_complex *)c_rspace);
326+
for (int i=0; i<this->nx;++i)
327+
{
328+
fftw_execute_dft( this->planyfor, (fftw_complex*)&c_rspace[i*npy], (fftw_complex*)&c_rspace[i*npy] );
329+
}
330+
// fftw_execute(this->plan2for);
331+
}
258332
else if(instr == "2bac")
259-
fftw_execute(this->plan2bac);
333+
{
334+
// fftw_execute(this->plan2bac);
335+
int npy = this->nplane * this-> ny;
336+
337+
for (int i=0; i<this->nx;++i)
338+
{
339+
fftw_execute_dft( this->planybac, (fftw_complex*)&c_rspace[i*npy], (fftw_complex*)&c_rspace[i*npy] );
340+
}
341+
fftw_execute_dft( this->planxbac, (fftw_complex *)c_rspace, (fftw_complex *)c_rspace);
342+
343+
}
260344
else if(instr == "2r2c")
261-
fftw_execute(this->plan2r2c);
345+
{
346+
// fftw_execute(this->plan2r2c);
347+
//int npy = this->nplane * this-> ny;
348+
int bignpy = this->nplane * this-> bigny;
349+
// int padnpy = this->nplane * this-> ny * 2;
350+
for (int i=0; i<this->nx;++i)
351+
{
352+
fftw_execute_dft_r2c( this->planyr2c, &r_rspace[i*bignpy], (fftw_complex*)&c_rspace[i*bignpy] );
353+
// fftw_execute_dft_r2c( this->planyfor, &r_rspace[4*i*padnpy], (fftw_complex*)&c_rspace[i*padnpy] );
354+
}
355+
fftw_execute_dft( this->planxfor, (fftw_complex *)c_rspace, (fftw_complex *)c_rspace);
356+
}
262357
else if(instr == "2c2r")
263-
fftw_execute(this->plan2c2r);
358+
{
359+
// fftw_execute(this->plan2c2r);
360+
//int npy = this->nplane * this-> ny;
361+
int bignpy = this->nplane * this-> bigny;
362+
// int padnpy = this->nplane * this-> ny * 2;
363+
fftw_execute_dft( this->planxbac, (fftw_complex *)c_rspace, (fftw_complex *)c_rspace);
364+
for (int i=0; i<this->nx;++i)
365+
{
366+
fftw_execute_dft_c2r( this->planyc2r, (fftw_complex*)&c_rspace[i*bignpy], &r_rspace[i*bignpy] );
367+
// fftw_execute_dft_c2r( this->planybac, (fftw_complex*)&c_rspace[i*padnpy], &r_rspace[4*i*padnpy] );
368+
}
369+
}
264370
else
265371
{
266372
ModuleBase::WARNING_QUIT("FFT", "Wrong input for excutefftw");

source/module_pw/fft.h

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class FFT
5151
int bignxy;
5252
int ns; //number of sticks
5353
int nplane; //number of x-y planes
54+
int maxgrids; // max between nz * ns and bignxy * nplane
5455
std::complex<double> * c_gspace, *c_gspace2; //complex number space for g, [ns * nz]
5556
std::complex<double> * c_rspace, *c_rspace2;//complex number space for r, [nplane * nx *ny]
5657
double *r_rspace; //real number space for r, [nplane * nx *ny]
@@ -65,12 +66,20 @@ class FFT
6566
bool gamma_only;
6667
bool destroyp;
6768
bool mpifft; // if use mpi fft, only used when define __FFTW3_MPI
68-
fftw_plan plan2r2c;
69-
fftw_plan plan2c2r;
70-
fftw_plan plan1for;
71-
fftw_plan plan1bac;
72-
fftw_plan plan2for;
73-
fftw_plan plan2bac;
69+
// fftw_plan plan2r2c;
70+
// fftw_plan plan2c2r;
71+
// fftw_plan plan1for;
72+
// fftw_plan plan1bac;
73+
// fftw_plan plan2for;
74+
// fftw_plan plan2bac;
75+
fftw_plan planzfor;
76+
fftw_plan planzbac;
77+
fftw_plan planxfor;
78+
fftw_plan planxbac;
79+
fftw_plan planyfor;
80+
fftw_plan planybac;
81+
fftw_plan planyr2c;
82+
fftw_plan planyc2r;
7483
#ifdef __MIX_PRECISION
7584
bool destroypf;
7685
fftwf_plan planf2r2c;

source/module_pw/pw_gatherscatter.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@ void PW_Basis:: gatherp_scatters(complex<double> *in, complex<double> *out)
1515
for(int is = 0 ; is < this->nst ; ++is)
1616
{
1717
int ixy = this->is2ixy[is];
18+
int bigixy = (ixy / ny)*bigny + ixy % ny;
1819
for(int iz = 0 ; iz < this->nz ; ++iz)
1920
{
20-
out[is*nz+iz] = in[ixy*nz+iz];
21+
out[is*nz+iz] = in[bigixy*nz+iz];
2122
}
2223
}
2324
return;
@@ -32,7 +33,8 @@ void PW_Basis:: gatherp_scatters(complex<double> *in, complex<double> *out)
3233
if(this->ixy2ip[ixy] == -1) continue;
3334
int istot = 0;
3435
if(this->poolrank == 0) istot = this->ixy2istot[ixy];
35-
MPI_Gatherv(&in[ixy*this->nplane], this->nplane, mpicomplex, &tmp[istot*this->nz],
36+
int bigixy = (ixy / ny)*bigny + ixy % ny;
37+
MPI_Gatherv(&in[bigixy*this->nplane], this->nplane, mpicomplex, &tmp[istot*this->nz],
3638
this->numz,this->startz,mpicomplex,0,POOL_WORLD);
3739
}
3840

@@ -98,9 +100,10 @@ void PW_Basis:: gathers_scatterp(complex<double> *in, complex<double> *out)
98100
for(int is = 0 ; is < this->nst ; ++is)
99101
{
100102
int ixy = is2ixy[is];
103+
int bigixy = (ixy / ny)*bigny + ixy % ny;
101104
for(int iz = 0 ; iz < this->nz ; ++iz)
102105
{
103-
out[ixy*nz+iz] = in[is*nz+iz];
106+
out[bigixy*nz+iz] = in[is*nz+iz];
104107
}
105108
}
106109
return;
@@ -118,7 +121,8 @@ void PW_Basis:: gathers_scatterp(complex<double> *in, complex<double> *out)
118121
for(int istot = 0 ; istot < this->nstot ; ++istot)
119122
{
120123
int ixy = this->istot2ixy[istot];
121-
MPI_Scatterv(&tmp[istot*this->nz], this->numz,this->startz, mpicomplex, &out[ixy*this->nplane],
124+
int bigixy = (ixy / ny)*bigny + ixy % ny;
125+
MPI_Scatterv(&tmp[istot*this->nz], this->numz,this->startz, mpicomplex, &out[bigixy*this->nplane],
122126
this->nplane,mpicomplex,0,POOL_WORLD);
123127
}
124128
if(tmp!=NULL) delete[] tmp;

0 commit comments

Comments
 (0)