11#include " fft.h"
2- #include " ../module_base/tool_quit.h"
32namespace ModulePW
43{
54
@@ -12,8 +11,7 @@ FFT::FFT()
1211 mpifft = false ;
1312 destroyp = true ;
1413 gamma_only = false ;
15- c_rspace = c_gspace = NULL ;
16- c_rspace2 = c_gspace2 = NULL ;
14+ aux2 = aux1 = NULL ;
1715 r_rspace = NULL ;
1816#ifdef __MIX_PRECISION
1917 destroypf = true ;
@@ -25,19 +23,16 @@ FFT::FFT()
2523FFT::~FFT ()
2624{
2725 this ->cleanFFT ();
28- if (c_gspace!=NULL ) fftw_free (c_gspace);
29- if (c_rspace!=NULL ) fftw_free (c_rspace);
30- if (c_gspace2!=NULL ) fftw_free (c_gspace2);
31- if (c_rspace2!=NULL ) fftw_free (c_rspace2);
32- if (r_rspace!=NULL ) fftw_free (r_rspace);
26+ if (aux1!=NULL ) fftw_free (aux1);
27+ if (aux2!=NULL ) fftw_free (aux2);
3328#ifdef __MIX_PRECISION
3429 if (cf_gspace!=NULL ) fftw_free (cf_gspace);
3530 if (cf_rspace!=NULL ) fftw_free (cf_rspace);
3631 if (rf_rspace!=NULL ) fftw_free (rf_rspace);
3732#endif
3833}
3934
40- void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in, bool gamma_only_in, bool mpifft_in)
35+ void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int liy_in, int riy_in, int ns_in, int nplane_in, int nproc_in , bool gamma_only_in, bool mpifft_in)
4136{
4237 this ->gamma_only = gamma_only_in;
4338 this ->nx = nx_in;
@@ -46,24 +41,19 @@ void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in,
4641 else this ->ny = this ->bigny ;
4742 this ->nz = nz_in;
4843 this ->ns = ns_in;
44+ this ->liy = liy_in;
45+ this ->riy = riy_in;
4946 this ->nplane = nplane_in;
47+ this ->nproc = nproc_in;
5048 this ->mpifft = mpifft_in;
5149 this ->nxy = this ->nx * this -> ny;
5250 this ->bignxy = this ->nx * this ->bigny ;
51+ this ->maxgrids = (this ->nz * this ->ns > this ->bignxy * nplane) ? this ->nz * this ->ns : this ->bignxy * nplane;
5352 if (!this ->mpifft )
5453 {
55- // out-of-place fft is faster than in-place fft
56- c_gspace = (std::complex <double > *) fftw_malloc (sizeof (fftw_complex) * this ->nz * this ->ns );
57- c_gspace2 = (std::complex <double > *) fftw_malloc (sizeof (fftw_complex) * this ->nz * this ->ns );
58- c_rspace = (std::complex <double > *) fftw_malloc (sizeof (fftw_complex) * this ->bignxy * nplane);
59- if (this ->gamma_only )
60- {
61- r_rspace = (double *) fftw_malloc (sizeof (double ) * this ->bignxy * nplane);
62- }
63- else
64- {
65- c_rspace2 = (std::complex <double > *) fftw_malloc (sizeof (fftw_complex) * this ->bignxy * nplane);
66- }
54+ aux1 = (std::complex <double > *) fftw_malloc (sizeof (fftw_complex) * maxgrids);
55+ aux2 = (std::complex <double > *) fftw_malloc (sizeof (fftw_complex) * maxgrids);
56+ r_rspace = (double *) aux1;
6757#ifdef __MIX_PRECISION
6858 cf_gspace = (std::complex <float > *)fftw_malloc (sizeof (fftwf_complex) * this ->nz * this ->ns );
6959 cf_rspace = (std::complex <float > *)fftw_malloc (sizeof (fftwf_complex) * this ->bignxy * nplane);
@@ -111,40 +101,85 @@ void FFT :: initplan()
111101 // fftw_complex *in, const int *inembed, int istride, int idist,
112102 // fftw_complex *out, const int *onembed, int ostride, int odist, int sign, unsigned flags);
113103
114- this ->plan1for = fftw_plan_many_dft ( 1 , &this ->nz , this ->ns ,
115- (fftw_complex*) c_gspace, &this ->nz , 1 , this ->nz ,
116- (fftw_complex*) c_gspace2, &this ->nz , 1 , this ->nz , FFTW_FORWARD, FFTW_MEASURE);
104+ // It is better to use out-of-place fft for stride = 1.
105+ this ->planzfor = fftw_plan_many_dft ( 1 , &this ->nz , this ->ns ,
106+ (fftw_complex*) aux1, &this ->nz , 1 , this ->nz ,
107+ (fftw_complex*) aux2, &this ->nz , 1 , this ->nz , FFTW_FORWARD, FFTW_MEASURE);
117108
118- this ->plan1bac = fftw_plan_many_dft ( 1 , &this ->nz , this ->ns ,
119- (fftw_complex*) c_gspace, &this ->nz , 1 , this ->nz ,
120- (fftw_complex*) c_gspace2, &this ->nz , 1 , this ->nz , FFTW_BACKWARD, FFTW_MEASURE);
109+ this ->planzbac = fftw_plan_many_dft ( 1 , &this ->nz , this ->ns ,
110+ (fftw_complex*) aux1, &this ->nz , 1 , this ->nz ,
111+ (fftw_complex*) aux2, &this ->nz , 1 , this ->nz , FFTW_BACKWARD, FFTW_MEASURE);
112+
113+ // this->planzfor = fftw_plan_dft_1d(this->nz,(fftw_complex*) aux1,(fftw_complex*) aux1, FFTW_FORWARD, FFTW_MEASURE);
114+ // this->planzbac = fftw_plan_dft_1d(this->nz,(fftw_complex*) aux1,(fftw_complex*) aux1,FFTW_BACKWARD, FFTW_MEASURE);
121115
122116 // ---------------------------------------------------------
123117 // 2 D
124118 // ---------------------------------------------------------
125119
126- int nrank[2 ] = {this ->nx ,this ->bigny };
120+ // int nrank[2] = {this->nx,this->bigny};
127121 int *embed = NULL ;
122+ // It seems 1D+1D is much faster than 2D FFT!
128123 if (this ->gamma_only )
129124 {
130- this ->plan2r2c = fftw_plan_many_dft_r2c ( 2 , nrank, this ->nplane ,
131- r_rspace, embed, this ->nplane , 1 ,
132- (fftw_complex*) c_rspace, embed, this ->nplane , 1 , FFTW_MEASURE);
125+ // int padnpy = this->nplane * this->ny * 2;
126+ // int rankc[2] = {this->nx, this->padnpy};
127+ // int rankd[2] = {this->nx, this->padnpy*2};
128+ // // It seems 1D+1D is much faster than 2D FFT!
129+ // this->plan2r2c = fftw_plan_many_dft_r2c( 2, nrank, this->nplane,
130+ // r_rspace, rankd, this->nplane, 1,
131+ // (fftw_complex*) aux2, rankc, this->nplane, 1, FFTW_MEASURE);
132+
133+ // this->plan2c2r = fftw_plan_many_dft_c2r( 2, nrank, this->nplane,
134+ // (fftw_complex*) aux2, rankc, this->nplane, 1,
135+ // r_rspace, rankd, this->nplane, 1, FFTW_MEASURE);
136+
137+ // int npy = this->nplane * this->ny;
138+ int bignpy = this ->nplane * this ->bigny ;
139+ this ->planxfor = fftw_plan_many_dft ( 1 , &this ->nx , this ->nplane , (fftw_complex *)aux2, embed, bignpy, 1 ,
140+ (fftw_complex *)aux2, embed, bignpy, 1 , FFTW_FORWARD, FFTW_MEASURE );
141+ this ->planxbac = fftw_plan_many_dft ( 1 , &this ->nx , this ->nplane , (fftw_complex *)aux2, embed, bignpy, 1 ,
142+ (fftw_complex *)aux2, embed, bignpy, 1 , FFTW_BACKWARD, FFTW_MEASURE );
143+ this ->planyr2c = fftw_plan_many_dft_r2c ( 1 , &this ->bigny , this ->nplane , r_rspace , embed, this ->nplane , 1 ,
144+ (fftw_complex*)aux1, embed, this ->nplane , 1 , FFTW_MEASURE );
145+ this ->planyc2r = fftw_plan_many_dft_c2r ( 1 , &this ->bigny , this ->nplane , (fftw_complex*)aux1 , embed, this ->nplane , 1 ,
146+ r_rspace, embed, this ->nplane , 1 , FFTW_MEASURE );
147+
148+ // int padnpy = npy * 2;
149+ // this->planxfor = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)aux2, embed, padnpy, 1,
150+ // (fftw_complex *)aux2, embed, padnpy, 1, FFTW_FORWARD, FFTW_MEASURE );
151+ // this->planxbac = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)aux2, embed, padnpy, 1,
152+ // (fftw_complex *)aux2, embed, padnpy, 1, FFTW_BACKWARD, FFTW_MEASURE );
153+ // this->planyr2c = fftw_plan_many_dft_r2c( 1, &this->bigny, this->nplane, r_rspace , embed, this->nplane*2, 1,
154+ // (fftw_complex*)aux2, embed, this->nplane, 1, FFTW_MEASURE );
155+ // this->planyc2r = fftw_plan_many_dft_c2r( 1, &this->bigny, this->nplane, (fftw_complex*)aux2 , embed, this->nplane, 1,
156+ // r_rspace, embed, this->nplane*2, 1, FFTW_MEASURE );
133157
134- this ->plan2c2r = fftw_plan_many_dft_c2r ( 2 , nrank, this ->nplane ,
135- (fftw_complex*) c_rspace, embed, this ->nplane , 1 ,
136- r_rspace, embed, this ->nplane , 1 , FFTW_MEASURE);
137158 }
138159 else
139160 {
140- this ->plan2for = fftw_plan_many_dft ( 2 , nrank, this ->nplane ,
141- (fftw_complex*) c_rspace , embed, this ->nplane , 1 ,
142- (fftw_complex*) c_rspace2 , embed, this ->nplane , 1 , FFTW_FORWARD, FFTW_MEASURE);
161+ // this->plan2for = fftw_plan_many_dft( 2, nrank, this->nplane,
162+ // (fftw_complex*) aux2 , embed, this->nplane, 1,
163+ // (fftw_complex*) aux2 , embed, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE);
143164
144- this ->plan2bac = fftw_plan_many_dft ( 2 , nrank, this ->nplane ,
145- (fftw_complex*) c_rspace, embed, this ->nplane , 1 ,
146- (fftw_complex*) c_rspace2, embed, this ->nplane , 1 , FFTW_BACKWARD, FFTW_MEASURE);
165+ // this->plan2bac = fftw_plan_many_dft( 2, nrank, this->nplane,
166+ // (fftw_complex*) aux2, embed, this->nplane, 1,
167+ // (fftw_complex*) aux2, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE);
168+
169+ // It is better to use in-place for stride > 1
170+ int bignpy = this ->nplane * this ->bigny ;
171+ this ->planxfor = fftw_plan_many_dft ( 1 , &this ->nx , this ->nplane , (fftw_complex *)aux2, embed, bignpy, 1 ,
172+ (fftw_complex *)aux2, embed, bignpy, 1 , FFTW_FORWARD, FFTW_MEASURE );
173+ this ->planxbac = fftw_plan_many_dft ( 1 , &this ->nx , this ->nplane , (fftw_complex *)aux2, embed, bignpy, 1 ,
174+ (fftw_complex *)aux2, embed, bignpy, 1 , FFTW_BACKWARD, FFTW_MEASURE );
175+ this ->planyfor = fftw_plan_many_dft ( 1 , &this ->bigny , this ->nplane , (fftw_complex*)aux2 , embed, this ->nplane , 1 ,
176+ (fftw_complex*)aux2, embed, this ->nplane , 1 , FFTW_FORWARD, FFTW_MEASURE );
177+ this ->planybac = fftw_plan_many_dft ( 1 , &this ->bigny , this ->nplane , (fftw_complex*)aux2 , embed, this ->nplane , 1 ,
178+ (fftw_complex*)aux2, embed, this ->nplane , 1 , FFTW_BACKWARD, FFTW_MEASURE );
147179 }
180+
181+
182+
148183 destroyp = false ;
149184}
150185
@@ -160,12 +195,12 @@ void FFT :: initplanf()
160195 // fftw_complex *out, const int *onembed, int ostride, int odist, int sign, unsigned flags);
161196
162197 this ->planf1for = fftwf_plan_many_dft ( 1 , &this ->nz , this ->ns ,
163- (fftwf_complex*)c_gspace , &this ->nz , 1 , this ->nz ,
164- (fftwf_complex*)c_gspace , &this ->nz , 1 , this ->nz , FFTW_FORWARD, FFTW_MEASURE);
198+ (fftwf_complex*)aux1 , &this ->nz , 1 , this ->nz ,
199+ (fftwf_complex*)aux1 , &this ->nz , 1 , this ->nz , FFTW_FORWARD, FFTW_MEASURE);
165200
166201 this ->planf1bac = fftwf_plan_many_dft ( 1 , &this ->nz , this ->ns ,
167- (fftwf_complex*)c_gspace , &this ->nz , 1 , this ->nz ,
168- (fftwf_complex*)c_gspace , &this ->nz , 1 , this ->nz , FFTW_BACKWARD, FFTW_MEASURE);
202+ (fftwf_complex*)aux1 , &this ->nz , 1 , this ->nz ,
203+ (fftwf_complex*)aux1 , &this ->nz , 1 , this ->nz , FFTW_BACKWARD, FFTW_MEASURE);
169204
170205
171206 // ---------------------------------------------------------
@@ -178,21 +213,21 @@ void FFT :: initplanf()
178213 {
179214 this ->planf2r2c = fftwf_plan_many_dft_r2c ( 2 , nrank, this ->nplane ,
180215 r_rspace, nrank, this ->nplane , 1 ,
181- (fftwf_complex*)c_rspace , nrank, this ->nplane , 1 , FFTW_MEASURE);
216+ (fftwf_complex*)aux2 , nrank, this ->nplane , 1 , FFTW_MEASURE);
182217
183218 this ->planf2c2r = fftwf_plan_many_dft_c2r ( 2 , nrank, this ->nplane ,
184- (fftwf_complex*)c_rspace , nrank, this ->nplane , 1 ,
219+ (fftwf_complex*)aux2 , nrank, this ->nplane , 1 ,
185220 r_rspace, nrank, this ->nplane , 1 , FFTW_MEASURE);
186221 }
187222 else
188223 {
189224 this ->planf2for = fftwf_plan_many_dft ( 2 , nrank, this ->nplane ,
190- (fftwf_complex*)c_rspace , nrank, this ->nplane , 1 ,
191- (fftwf_complex*)c_rspace , nrank, this ->nplane , 1 , FFTW_FORWARD, FFTW_MEASURE);
225+ (fftwf_complex*)aux2 , nrank, this ->nplane , 1 ,
226+ (fftwf_complex*)aux2 , nrank, this ->nplane , 1 , FFTW_FORWARD, FFTW_MEASURE);
192227
193228 this ->planf2bac = fftwf_plan_many_dft ( 2 , nrank, this ->nplane ,
194- (fftwf_complex*)c_rspace , nrank, this ->nplane , 1 ,
195- (fftwf_complex*)c_rspace , nrank, this ->nplane , 1 , FFTW_BACKWARD, FFTW_MEASURE);
229+ (fftwf_complex*)aux2 , nrank, this ->nplane , 1 ,
230+ (fftwf_complex*)aux2 , nrank, this ->nplane , 1 , FFTW_BACKWARD, FFTW_MEASURE);
196231 }
197232 destroypf = false ;
198233}
@@ -213,17 +248,19 @@ void FFT :: initplanf_mpi()
213248void FFT:: cleanFFT()
214249{
215250 if (destroyp==true ) return ;
216- fftw_destroy_plan (plan1for);
217- fftw_destroy_plan (plan1bac);
251+ fftw_destroy_plan (planzfor);
252+ fftw_destroy_plan (planzbac);
253+ fftw_destroy_plan (planxfor);
254+ fftw_destroy_plan (planxbac);
218255 if (this ->gamma_only )
219256 {
220- fftw_destroy_plan (plan2r2c );
221- fftw_destroy_plan (plan2c2r );
257+ fftw_destroy_plan (planyr2c );
258+ fftw_destroy_plan (planyc2r );
222259 }
223260 else
224261 {
225- fftw_destroy_plan (plan2for );
226- fftw_destroy_plan (plan2bac );
262+ fftw_destroy_plan (planyfor );
263+ fftw_destroy_plan (planybac );
227264 }
228265 destroyp = true ;
229266
@@ -247,26 +284,107 @@ void FFT:: cleanFFT()
247284 return ;
248285}
249286
250- void FFT::executefftw (std::string instr )
287+ void FFT::fftzfor (std::complex < double >* & in, std:: complex < double >* & out )
251288{
252- if (instr == " 1for" )
253- fftw_execute (this ->plan1for );
254- else if (instr == " 2for" )
255- fftw_execute (this ->plan2for );
256- else if (instr == " 1bac" )
257- fftw_execute (this ->plan1bac );
258- else if (instr == " 2bac" )
259- fftw_execute (this ->plan2bac );
260- else if (instr == " 2r2c" )
261- fftw_execute (this ->plan2r2c );
262- else if (instr == " 2c2r" )
263- fftw_execute (this ->plan2c2r );
264- else
289+ // for(int i = 0 ; i < this->ns ; ++i)
290+ // {
291+ // fftw_execute_dft(this->planzfor,(fftw_complex *)&in[i*nz],(fftw_complex *)&out[i*nz]);
292+ // }
293+ fftw_execute_dft (this ->planzfor ,(fftw_complex *)in,(fftw_complex *)out);
294+ return ;
295+ }
296+
297+ void FFT::fftzbac (std::complex <double >* & in, std::complex <double >* & out)
298+ {
299+ // for(int i = 0 ; i < this->ns ; ++i)
300+ // {
301+ // fftw_execute_dft(this->planzbac,(fftw_complex *)&aux1[i*nz],(fftw_complex *)&aux1[i*nz]);
302+ // }
303+ fftw_execute_dft (this ->planzbac ,(fftw_complex *)in, (fftw_complex *)out);
304+ return ;
305+ }
306+
307+ void FFT::fftxyfor (std::complex <double >* & in, std::complex <double >* & out)
308+ {
309+ int bignpy = this ->nplane * this -> bigny;
310+ for (int i=0 ; i<this ->nx ;++i)
311+ {
312+ fftw_execute_dft ( this ->planyfor , (fftw_complex *)&in[i*bignpy], (fftw_complex *)&out[i*bignpy]);
313+ }
314+
315+ for (int i=0 ; i<=this ->liy ;++i)
316+ {
317+ fftw_execute_dft ( this ->planxfor , (fftw_complex *)&in[i*nplane], (fftw_complex *)&out[i*nplane]);
318+ }
319+ for (int i=this ->riy ; i<this ->ny ;++i)
320+ {
321+ fftw_execute_dft ( this ->planxfor , (fftw_complex *)&in[i*nplane], (fftw_complex *)&out[i*nplane]);
322+ }
323+ return ;
324+ }
325+
326+ void FFT::fftxybac (std::complex <double >* & in, std::complex <double >* & out)
327+ {
328+ int bignpy = this ->nplane * this -> bigny;
329+ // x-direction
330+ for (int i=0 ; i<=this ->liy ;++i)
331+ {
332+ fftw_execute_dft ( this ->planxbac , (fftw_complex *)&in[i*nplane], (fftw_complex *)&out[i*nplane]);
333+ }
334+ for (int i=this ->riy ; i<this ->ny ;++i)
265335 {
266- ModuleBase::WARNING_QUIT ( " FFT " , " Wrong input for excutefftw " );
336+ fftw_execute_dft ( this -> planxbac , (fftw_complex *)&in[i*nplane], (fftw_complex *)&out[i*nplane] );
267337 }
338+
339+ // //y-direction
340+ for (int i=0 ; i<this ->nx ;++i)
341+ {
342+ fftw_execute_dft ( this ->planybac , (fftw_complex*)&in[i*bignpy], (fftw_complex*)&out[i*bignpy] );
343+ }
344+ return ;
268345}
269346
347+ void FFT::fftxyr2c (double * &in, std::complex <double >* & out)
348+ {
349+ int bignpy = this ->nplane * this -> bigny;
350+
351+ for (int i=0 ; i<this ->nx ;++i)
352+ {
353+ fftw_execute_dft_r2c ( this ->planyr2c , &in[i*bignpy*2 ], (fftw_complex*)&out[i*bignpy] );
354+ }
355+
356+ for (int i=0 ; i<=this ->liy ;++i)
357+ {
358+ fftw_execute_dft ( this ->planxfor , (fftw_complex *)&out[i*nplane], (fftw_complex *)&out[i*nplane]);
359+ }
360+ for (int i=this ->riy ; i<this ->ny ;++i)
361+ {
362+ fftw_execute_dft ( this ->planxfor , (fftw_complex *)&out[i*nplane], (fftw_complex *)&out[i*nplane]);
363+ }
364+ return ;
365+ }
366+
367+
368+ void FFT::fftxyc2r (std::complex <double >* & in, double * & out)
369+ {
370+ int bignpy = this ->nplane * this -> bigny;
371+ for (int i=0 ; i<=this ->liy ;++i)
372+ {
373+ fftw_execute_dft ( this ->planxbac , (fftw_complex *)&in[i*nplane], (fftw_complex *)&in[i*nplane]);
374+ }
375+ for (int i=this ->riy ; i<this ->ny ;++i)
376+ {
377+ fftw_execute_dft ( this ->planxbac , (fftw_complex *)&in[i*nplane], (fftw_complex *)&in[i*nplane]);
378+ }
379+
380+ for (int i=0 ; i<this ->nx ;++i)
381+ {
382+ fftw_execute_dft_c2r ( this ->planyc2r , (fftw_complex*)&in[i*bignpy], &out[i*bignpy*2 ] );
383+ }
384+ return ;
385+ }
386+
387+
270388#ifdef __MIX_PRECISION
271389void executefftwf (std::string instr)
272390{
@@ -282,10 +400,6 @@ void executefftwf(std::string instr)
282400 fftwf_execute (this ->planf2r2c );
283401 else if (instr == " 2c2r" )
284402 fftwf_execute (this ->planf2c2r );
285- else
286- {
287- ModuleBase::WARNING_QUIT (" FFT" , " Wrong input for excutefftwf" );
288- }
289403}
290404#endif
291405}
0 commit comments