@@ -27,9 +27,9 @@ FFT::~FFT()
2727 this ->cleanFFT ();
2828 if (c_gspace!=NULL ) fftw_free (c_gspace);
2929 if (c_rspace!=NULL ) fftw_free (c_rspace);
30- if (c_gspace2!=NULL ) fftw_free (c_gspace2);
31- if (c_rspace2!=NULL ) fftw_free (c_rspace2);
3230 if (r_rspace!=NULL ) fftw_free (r_rspace);
31+ // if(c_gspace2!=NULL) fftw_free(c_gspace2);
32+ // if(c_rspace2!=NULL) fftw_free(c_rspace2);
3333#ifdef __MIX_PRECISION
3434 if (cf_gspace!=NULL ) fftw_free (cf_gspace);
3535 if (cf_rspace!=NULL ) fftw_free (cf_rspace);
@@ -50,19 +50,27 @@ void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in,
5050 this ->mpifft = mpifft_in;
5151 this ->nxy = this ->nx * this -> ny;
5252 this ->bignxy = this ->nx * this ->bigny ;
53+ this ->maxgrids = (this ->nz * this ->ns > this ->bignxy * nplane) ? this ->nz * this ->ns : this ->bignxy * nplane;
5354 if (!this ->mpifft )
5455 {
55- // out-of-place fft is faster than in-place fft
56+ // It seems in-place fft is faster than out-of-place fft
5657 c_gspace = (std::complex <double > *) fftw_malloc (sizeof (fftw_complex) * this ->nz * this ->ns );
57- c_gspace2 = (std::complex <double > *) fftw_malloc (sizeof (fftw_complex) * this ->nz * this ->ns );
58- c_rspace = (std::complex <double > *) fftw_malloc (sizeof (fftw_complex) * this ->bignxy * nplane);
58+ // c_gspace2 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
5959 if (this ->gamma_only )
6060 {
61+ c_rspace = (std::complex <double > *) fftw_malloc (sizeof (fftw_complex) * this ->bignxy * nplane);
6162 r_rspace = (double *) fftw_malloc (sizeof (double ) * this ->bignxy * nplane);
63+
64+ // r2c in place : It seems in-place r2c/c2r is much slower than out-of-place
65+ // int padnxyp = this->ny * 2 * this->nx * this->nplane;
66+ // c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * padnxyp);
67+ // r_rspace = (double *) c_rspace;
68+
6269 }
6370 else
6471 {
65- c_rspace2 = (std::complex <double > *) fftw_malloc (sizeof (fftw_complex) * this ->bignxy * nplane);
72+ c_rspace = (std::complex <double > *) fftw_malloc (sizeof (fftw_complex) * this ->bignxy * nplane);
73+ // c_rspace2 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
6674 }
6775#ifdef __MIX_PRECISION
6876 cf_gspace = (std::complex <float > *)fftw_malloc (sizeof (fftwf_complex) * this ->nz * this ->ns );
@@ -111,13 +119,16 @@ void FFT :: initplan()
111119 // fftw_complex *in, const int *inembed, int istride, int idist,
112120 // fftw_complex *out, const int *onembed, int ostride, int odist, int sign, unsigned flags);
113121
114- this ->plan1for = fftw_plan_many_dft ( 1 , &this ->nz , this ->ns ,
122+ this ->planzfor = fftw_plan_many_dft ( 1 , &this ->nz , this ->ns ,
115123 (fftw_complex*) c_gspace, &this ->nz , 1 , this ->nz ,
116- (fftw_complex*) c_gspace2 , &this ->nz , 1 , this ->nz , FFTW_FORWARD, FFTW_MEASURE);
124+ (fftw_complex*) c_gspace , &this ->nz , 1 , this ->nz , FFTW_FORWARD, FFTW_MEASURE);
117125
118- this ->plan1bac = fftw_plan_many_dft ( 1 , &this ->nz , this ->ns ,
126+ this ->planzbac = fftw_plan_many_dft ( 1 , &this ->nz , this ->ns ,
119127 (fftw_complex*) c_gspace, &this ->nz , 1 , this ->nz ,
120- (fftw_complex*) c_gspace2, &this ->nz , 1 , this ->nz , FFTW_BACKWARD, FFTW_MEASURE);
128+ (fftw_complex*) c_gspace, &this ->nz , 1 , this ->nz , FFTW_BACKWARD, FFTW_MEASURE);
129+
130+ // this->planzfor = fftw_plan_dft_1d(this->nz,(fftw_complex*) c_gspace,(fftw_complex*) c_gspace, FFTW_FORWARD, FFTW_MEASURE);
131+ // this->planzbac = fftw_plan_dft_1d(this->nz,(fftw_complex*) c_gspace,(fftw_complex*) c_gspace,FFTW_BACKWARD, FFTW_MEASURE);
121132
122133 // ---------------------------------------------------------
123134 // 2 D
@@ -127,24 +138,63 @@ void FFT :: initplan()
127138 int *embed = NULL ;
128139 if (this ->gamma_only )
129140 {
130- this ->plan2r2c = fftw_plan_many_dft_r2c ( 2 , nrank, this ->nplane ,
131- r_rspace, embed, this ->nplane , 1 ,
132- (fftw_complex*) c_rspace, embed, this ->nplane , 1 , FFTW_MEASURE);
141+ // int padnpy = this->nplane * this->ny * 2;
142+ // int rankc[2] = {this->nx, this->padnpy};
143+ // int rankd[2] = {this->nx, this->padnpy*2};
144+ // // It seems 1D+1D is much faster than 2D FFT!
145+ // this->plan2r2c = fftw_plan_many_dft_r2c( 2, nrank, this->nplane,
146+ // r_rspace, rankd, this->nplane, 1,
147+ // (fftw_complex*) c_rspace, rankc, this->nplane, 1, FFTW_MEASURE);
148+
149+ // this->plan2c2r = fftw_plan_many_dft_c2r( 2, nrank, this->nplane,
150+ // (fftw_complex*) c_rspace, rankc, this->nplane, 1,
151+ // r_rspace, rankd, this->nplane, 1, FFTW_MEASURE);
152+
153+ int npy = this ->nplane * this ->ny ;
154+ int bignpy = this ->nplane * this ->bigny ;
155+ this ->planxfor = fftw_plan_many_dft ( 1 , &this ->nx , npy, (fftw_complex *)c_rspace, embed, bignpy, 1 ,
156+ (fftw_complex *)c_rspace, embed, bignpy, 1 , FFTW_FORWARD, FFTW_MEASURE );
157+ this ->planxbac = fftw_plan_many_dft ( 1 , &this ->nx , npy, (fftw_complex *)c_rspace, embed, bignpy, 1 ,
158+ (fftw_complex *)c_rspace, embed, bignpy, 1 , FFTW_BACKWARD, FFTW_MEASURE );
159+ this ->planyr2c = fftw_plan_many_dft_r2c ( 1 , &this ->bigny , this ->nplane , r_rspace , embed, this ->nplane , 1 ,
160+ (fftw_complex*)c_rspace, embed, this ->nplane , 1 , FFTW_MEASURE );
161+ this ->planyc2r = fftw_plan_many_dft_c2r ( 1 , &this ->bigny , this ->nplane , (fftw_complex*)c_rspace , embed, this ->nplane , 1 ,
162+ r_rspace, embed, this ->nplane , 1 , FFTW_MEASURE );
163+
164+ // int padnpy = npy * 2;
165+ // this->planxfor = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)c_rspace, embed, padnpy, 1,
166+ // (fftw_complex *)c_rspace, embed, padnpy, 1, FFTW_FORWARD, FFTW_MEASURE );
167+ // this->planxbac = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)c_rspace, embed, padnpy, 1,
168+ // (fftw_complex *)c_rspace, embed, padnpy, 1, FFTW_BACKWARD, FFTW_MEASURE );
169+ // this->planyr2c = fftw_plan_many_dft_r2c( 1, &this->bigny, this->nplane, r_rspace , embed, this->nplane*2, 1,
170+ // (fftw_complex*)c_rspace, embed, this->nplane, 1, FFTW_MEASURE );
171+ // this->planyc2r = fftw_plan_many_dft_c2r( 1, &this->bigny, this->nplane, (fftw_complex*)c_rspace , embed, this->nplane, 1,
172+ // r_rspace, embed, this->nplane*2, 1, FFTW_MEASURE );
133173
134- this ->plan2c2r = fftw_plan_many_dft_c2r ( 2 , nrank, this ->nplane ,
135- (fftw_complex*) c_rspace, embed, this ->nplane , 1 ,
136- r_rspace, embed, this ->nplane , 1 , FFTW_MEASURE);
137174 }
138175 else
139176 {
140- this ->plan2for = fftw_plan_many_dft ( 2 , nrank, this ->nplane ,
141- (fftw_complex*) c_rspace, embed, this ->nplane , 1 ,
142- (fftw_complex*) c_rspace2, embed, this ->nplane , 1 , FFTW_FORWARD, FFTW_MEASURE);
177+ // It seems 1D+1D is much faster than 2D FFT!
178+ // this->plan2for = fftw_plan_many_dft( 2, nrank, this->nplane,
179+ // (fftw_complex*) c_rspace, embed, this->nplane, 1,
180+ // (fftw_complex*) c_rspace, embed, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE);
143181
144- this ->plan2bac = fftw_plan_many_dft ( 2 , nrank, this ->nplane ,
145- (fftw_complex*) c_rspace, embed, this ->nplane , 1 ,
146- (fftw_complex*) c_rspace2, embed, this ->nplane , 1 , FFTW_BACKWARD, FFTW_MEASURE);
182+ // this->plan2bac = fftw_plan_many_dft( 2, nrank, this->nplane,
183+ // (fftw_complex*) c_rspace, embed, this->nplane, 1,
184+ // (fftw_complex*) c_rspace, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE);
185+ int npy = this ->nplane * this ->ny ;
186+ this ->planxfor = fftw_plan_many_dft ( 1 , &this ->nx , npy, (fftw_complex *)c_rspace, embed, npy, 1 ,
187+ (fftw_complex *)c_rspace, embed, npy, 1 , FFTW_FORWARD, FFTW_MEASURE );
188+ this ->planxbac = fftw_plan_many_dft ( 1 , &this ->nx , npy, (fftw_complex *)c_rspace, embed, npy, 1 ,
189+ (fftw_complex *)c_rspace, embed, npy, 1 , FFTW_BACKWARD, FFTW_MEASURE );
190+ this ->planyfor = fftw_plan_many_dft ( 1 , &this ->ny , this ->nplane , (fftw_complex*)c_rspace , embed, this ->nplane , 1 ,
191+ (fftw_complex*)c_rspace, embed, this ->nplane , 1 , FFTW_FORWARD, FFTW_MEASURE );
192+ this ->planybac = fftw_plan_many_dft ( 1 , &this ->ny , this ->nplane , (fftw_complex*)c_rspace , embed, this ->nplane , 1 ,
193+ (fftw_complex*)c_rspace, embed, this ->nplane , 1 , FFTW_BACKWARD, FFTW_MEASURE );
147194 }
195+
196+
197+
148198 destroyp = false ;
149199}
150200
@@ -213,17 +263,19 @@ void FFT :: initplanf_mpi()
213263void FFT:: cleanFFT()
214264{
215265 if (destroyp==true ) return ;
216- fftw_destroy_plan (plan1for);
217- fftw_destroy_plan (plan1bac);
266+ fftw_destroy_plan (planzfor);
267+ fftw_destroy_plan (planzbac);
268+ fftw_destroy_plan (planxfor);
269+ fftw_destroy_plan (planxbac);
218270 if (this ->gamma_only )
219271 {
220- fftw_destroy_plan (plan2r2c );
221- fftw_destroy_plan (plan2c2r );
272+ fftw_destroy_plan (planyr2c );
273+ fftw_destroy_plan (planyc2r );
222274 }
223275 else
224276 {
225- fftw_destroy_plan (plan2for );
226- fftw_destroy_plan (plan2bac );
277+ fftw_destroy_plan (planyfor );
278+ fftw_destroy_plan (planybac );
227279 }
228280 destroyp = true ;
229281
@@ -250,17 +302,71 @@ void FFT:: cleanFFT()
250302void FFT::executefftw (std::string instr)
251303{
252304 if (instr == " 1for" )
253- fftw_execute (this ->plan1for );
254- else if (instr == " 2for" )
255- fftw_execute (this ->plan2for );
305+ {
306+ // for(int i = 0 ; i < this->ns ; ++i)
307+ // {
308+ // fftw_execute_dft(this->planzfor,(fftw_complex *)&c_gspace[i*nz],(fftw_complex *)&c_gspace[i*nz]);
309+ // }
310+ fftw_execute_dft (this ->planzfor ,(fftw_complex *)c_gspace,(fftw_complex *)c_gspace);
311+ // fftw_execute(this->planzfor);
312+ }
256313 else if (instr == " 1bac" )
257- fftw_execute (this ->plan1bac );
314+ {
315+ // for(int i = 0 ; i < this->ns ; ++i)
316+ // {
317+ // fftw_execute_dft(this->planzbac,(fftw_complex *)&c_gspace[i*nz],(fftw_complex *)&c_gspace[i*nz]);
318+ // }
319+ fftw_execute_dft (this ->planzbac ,(fftw_complex *)c_gspace,(fftw_complex *)c_gspace);
320+ // fftw_execute(this->planzbac);
321+ }
322+ else if (instr == " 2for" )
323+ {
324+ int npy = this ->nplane * this -> ny;
325+ fftw_execute_dft ( this ->planxfor , (fftw_complex *)c_rspace, (fftw_complex *)c_rspace);
326+ for (int i=0 ; i<this ->nx ;++i)
327+ {
328+ fftw_execute_dft ( this ->planyfor , (fftw_complex*)&c_rspace[i*npy], (fftw_complex*)&c_rspace[i*npy] );
329+ }
330+ // fftw_execute(this->plan2for);
331+ }
258332 else if (instr == " 2bac" )
259- fftw_execute (this ->plan2bac );
333+ {
334+ // fftw_execute(this->plan2bac);
335+ int npy = this ->nplane * this -> ny;
336+
337+ for (int i=0 ; i<this ->nx ;++i)
338+ {
339+ fftw_execute_dft ( this ->planybac , (fftw_complex*)&c_rspace[i*npy], (fftw_complex*)&c_rspace[i*npy] );
340+ }
341+ fftw_execute_dft ( this ->planxbac , (fftw_complex *)c_rspace, (fftw_complex *)c_rspace);
342+
343+ }
260344 else if (instr == " 2r2c" )
261- fftw_execute (this ->plan2r2c );
345+ {
346+ // fftw_execute(this->plan2r2c);
347+ // int npy = this->nplane * this-> ny;
348+ int bignpy = this ->nplane * this -> bigny;
349+ // int padnpy = this->nplane * this-> ny * 2;
350+ for (int i=0 ; i<this ->nx ;++i)
351+ {
352+ fftw_execute_dft_r2c ( this ->planyr2c , &r_rspace[i*bignpy], (fftw_complex*)&c_rspace[i*bignpy] );
353+ // fftw_execute_dft_r2c( this->planyfor, &r_rspace[4*i*padnpy], (fftw_complex*)&c_rspace[i*padnpy] );
354+ }
355+ fftw_execute_dft ( this ->planxfor , (fftw_complex *)c_rspace, (fftw_complex *)c_rspace);
356+ }
262357 else if (instr == " 2c2r" )
263- fftw_execute (this ->plan2c2r );
358+ {
359+ // fftw_execute(this->plan2c2r);
360+ // int npy = this->nplane * this-> ny;
361+ int bignpy = this ->nplane * this -> bigny;
362+ // int padnpy = this->nplane * this-> ny * 2;
363+ fftw_execute_dft ( this ->planxbac , (fftw_complex *)c_rspace, (fftw_complex *)c_rspace);
364+ for (int i=0 ; i<this ->nx ;++i)
365+ {
366+ fftw_execute_dft_c2r ( this ->planyc2r , (fftw_complex*)&c_rspace[i*bignpy], &r_rspace[i*bignpy] );
367+ // fftw_execute_dft_c2r( this->planybac, (fftw_complex*)&c_rspace[i*padnpy], &r_rspace[4*i*padnpy] );
368+ }
369+ }
264370 else
265371 {
266372 ModuleBase::WARNING_QUIT (" FFT" , " Wrong input for excutefftw" );
0 commit comments