Skip to content

Commit 9763057

Browse files
authored
Merge pull request #548 from Qianruipku/planewave
< prof >
2 parents 3d8788c + bfebade commit 9763057

File tree

5 files changed

+112
-138
lines changed

5 files changed

+112
-138
lines changed

source/module_pw/fft.cpp

Lines changed: 58 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#include "fft.h"
2+
#include <iostream>
3+
24
namespace ModulePW
35
{
46

@@ -11,8 +13,7 @@ FFT::FFT()
1113
mpifft = false;
1214
destroyp = true;
1315
gamma_only = false;
14-
c_rspace = c_gspace = NULL;
15-
c_rspace2 = c_gspace2 = NULL;
16+
aux2 = aux1 = NULL;
1617
r_rspace = NULL;
1718
#ifdef __MIX_PRECISION
1819
destroypf = true;
@@ -24,11 +25,8 @@ FFT::FFT()
2425
FFT::~FFT()
2526
{
2627
this->cleanFFT();
27-
if(c_gspace!=NULL) fftw_free(c_gspace);
28-
if(c_rspace!=NULL) fftw_free(c_rspace);
29-
if(r_rspace!=NULL) fftw_free(r_rspace);
30-
// if(c_gspace2!=NULL) fftw_free(c_gspace2);
31-
// if(c_rspace2!=NULL) fftw_free(c_rspace2);
28+
if(aux1!=NULL) fftw_free(aux1);
29+
if(aux2!=NULL) fftw_free(aux2);
3230
#ifdef __MIX_PRECISION
3331
if(cf_gspace!=NULL) fftw_free(cf_gspace);
3432
if(cf_rspace!=NULL) fftw_free(cf_rspace);
@@ -55,29 +53,9 @@ void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int lix_in, int rix_in, i
5553
this->maxgrids = (this->nz * this->ns > this->bignxy * nplane) ? this->nz * this->ns : this->bignxy * nplane;
5654
if(!this->mpifft)
5755
{
58-
//It seems in-place fft is faster than out-of-place fft
59-
// if(this->nproc == 1) c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
60-
// else c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
61-
c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
62-
//c_gspace2 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
63-
if(this->gamma_only)
64-
{
65-
c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
66-
r_rspace = (double *) fftw_malloc(sizeof(double) * this->bignxy * nplane);
67-
68-
//r2c in place : It seems in-place r2c/c2r is much slower than out-of-place
69-
// int padnxyp = this->ny * 2 * this->nx * this->nplane;
70-
// c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * padnxyp);
71-
// r_rspace = (double *) c_rspace;
72-
73-
}
74-
else
75-
{
76-
// if(this->nproc == 1) c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
77-
// else c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
78-
c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
79-
}
80-
c_gspace2 = c_rspace;
56+
aux1 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
57+
aux2 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
58+
r_rspace = (double *) aux1;
8159
#ifdef __MIX_PRECISION
8260
cf_gspace = (std::complex<float> *)fftw_malloc(sizeof(fftwf_complex) * this->nz * this->ns);
8361
cf_rspace = (std::complex<float> *)fftw_malloc(sizeof(fftwf_complex) * this->bignxy * nplane);
@@ -127,22 +105,23 @@ void FFT :: initplan()
127105

128106
//It is better to use out-of-place fft for stride = 1.
129107
this->planzfor = fftw_plan_many_dft( 1, &this->nz, this->ns,
130-
(fftw_complex*) c_gspace, &this->nz, 1, this->nz,
131-
(fftw_complex*) c_gspace2, &this->nz, 1, this->nz, FFTW_FORWARD, FFTW_MEASURE);
108+
(fftw_complex*) aux1, &this->nz, 1, this->nz,
109+
(fftw_complex*) aux2, &this->nz, 1, this->nz, FFTW_FORWARD, FFTW_MEASURE);
132110

133111
this->planzbac = fftw_plan_many_dft( 1, &this->nz, this->ns,
134-
(fftw_complex*) c_gspace, &this->nz, 1, this->nz,
135-
(fftw_complex*) c_gspace2, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
112+
(fftw_complex*) aux1, &this->nz, 1, this->nz,
113+
(fftw_complex*) aux2, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
136114

137-
// this->planzfor = fftw_plan_dft_1d(this->nz,(fftw_complex*) c_gspace,(fftw_complex*) c_gspace, FFTW_FORWARD, FFTW_MEASURE);
138-
// this->planzbac = fftw_plan_dft_1d(this->nz,(fftw_complex*) c_gspace,(fftw_complex*) c_gspace,FFTW_BACKWARD, FFTW_MEASURE);
115+
// this->planzfor = fftw_plan_dft_1d(this->nz,(fftw_complex*) aux1,(fftw_complex*) aux1, FFTW_FORWARD, FFTW_MEASURE);
116+
// this->planzbac = fftw_plan_dft_1d(this->nz,(fftw_complex*) aux1,(fftw_complex*) aux1,FFTW_BACKWARD, FFTW_MEASURE);
139117

140118
//---------------------------------------------------------
141119
// 2 D
142120
//---------------------------------------------------------
143121

144122
//int nrank[2] = {this->nx,this->bigny};
145123
int *embed = NULL;
124+
// It seems 1D+1D is much faster than 2D FFT!
146125
if(this->gamma_only)
147126
{
148127
// int padnpy = this->nplane * this->ny * 2;
@@ -151,55 +130,54 @@ void FFT :: initplan()
151130
// // It seems 1D+1D is much faster than 2D FFT!
152131
// this->plan2r2c = fftw_plan_many_dft_r2c( 2, nrank, this->nplane,
153132
// r_rspace, rankd, this->nplane, 1,
154-
// (fftw_complex*) c_rspace, rankc, this->nplane, 1, FFTW_MEASURE);
133+
// (fftw_complex*) aux2, rankc, this->nplane, 1, FFTW_MEASURE);
155134

156135
// this->plan2c2r = fftw_plan_many_dft_c2r( 2, nrank, this->nplane,
157-
// (fftw_complex*) c_rspace, rankc, this->nplane, 1,
136+
// (fftw_complex*) aux2, rankc, this->nplane, 1,
158137
// r_rspace, rankd, this->nplane, 1, FFTW_MEASURE);
159138

160139
int npy = this->nplane * this->ny;
161140
int bignpy = this->nplane * this->bigny;
162-
this->planxfor = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)c_rspace, embed, bignpy, 1,
163-
(fftw_complex *)c_rspace, embed, bignpy, 1, FFTW_FORWARD, FFTW_MEASURE );
164-
this->planxbac = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)c_rspace, embed, bignpy, 1,
165-
(fftw_complex *)c_rspace, embed, bignpy, 1, FFTW_BACKWARD, FFTW_MEASURE );
141+
this->planxfor = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)aux2, embed, bignpy, 1,
142+
(fftw_complex *)aux2, embed, bignpy, 1, FFTW_FORWARD, FFTW_MEASURE );
143+
this->planxbac = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)aux2, embed, bignpy, 1,
144+
(fftw_complex *)aux2, embed, bignpy, 1, FFTW_BACKWARD, FFTW_MEASURE );
166145
this->planyr2c = fftw_plan_many_dft_r2c( 1, &this->bigny, this->nplane, r_rspace , embed, this->nplane, 1,
167-
(fftw_complex*)c_rspace, embed, this->nplane, 1, FFTW_MEASURE );
168-
this->planyc2r = fftw_plan_many_dft_c2r( 1, &this->bigny, this->nplane, (fftw_complex*)c_rspace , embed, this->nplane, 1,
146+
(fftw_complex*)aux1, embed, this->nplane, 1, FFTW_MEASURE );
147+
this->planyc2r = fftw_plan_many_dft_c2r( 1, &this->bigny, this->nplane, (fftw_complex*)aux1 , embed, this->nplane, 1,
169148
r_rspace, embed, this->nplane, 1, FFTW_MEASURE );
170149

171150
// int padnpy = npy * 2;
172-
// this->planxfor = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)c_rspace, embed, padnpy, 1,
173-
// (fftw_complex *)c_rspace, embed, padnpy, 1, FFTW_FORWARD, FFTW_MEASURE );
174-
// this->planxbac = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)c_rspace, embed, padnpy, 1,
175-
// (fftw_complex *)c_rspace, embed, padnpy, 1, FFTW_BACKWARD, FFTW_MEASURE );
151+
// this->planxfor = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)aux2, embed, padnpy, 1,
152+
// (fftw_complex *)aux2, embed, padnpy, 1, FFTW_FORWARD, FFTW_MEASURE );
153+
// this->planxbac = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)aux2, embed, padnpy, 1,
154+
// (fftw_complex *)aux2, embed, padnpy, 1, FFTW_BACKWARD, FFTW_MEASURE );
176155
// this->planyr2c = fftw_plan_many_dft_r2c( 1, &this->bigny, this->nplane, r_rspace , embed, this->nplane*2, 1,
177-
// (fftw_complex*)c_rspace, embed, this->nplane, 1, FFTW_MEASURE );
178-
// this->planyc2r = fftw_plan_many_dft_c2r( 1, &this->bigny, this->nplane, (fftw_complex*)c_rspace , embed, this->nplane, 1,
156+
// (fftw_complex*)aux2, embed, this->nplane, 1, FFTW_MEASURE );
157+
// this->planyc2r = fftw_plan_many_dft_c2r( 1, &this->bigny, this->nplane, (fftw_complex*)aux2 , embed, this->nplane, 1,
179158
// r_rspace, embed, this->nplane*2, 1, FFTW_MEASURE );
180159

181160
}
182161
else
183162
{
184-
// It seems 1D+1D is much faster than 2D FFT!
185163
// this->plan2for = fftw_plan_many_dft( 2, nrank, this->nplane,
186-
// (fftw_complex*) c_rspace, embed, this->nplane, 1,
187-
// (fftw_complex*) c_rspace, embed, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE);
164+
// (fftw_complex*) aux2, embed, this->nplane, 1,
165+
// (fftw_complex*) aux2, embed, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE);
188166

189167
// this->plan2bac = fftw_plan_many_dft( 2, nrank, this->nplane,
190-
// (fftw_complex*) c_rspace, embed, this->nplane, 1,
191-
// (fftw_complex*) c_rspace, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE);
168+
// (fftw_complex*) aux2, embed, this->nplane, 1,
169+
// (fftw_complex*) aux2, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE);
192170

193171
//It is better to use in-place for stride > 1
194172
int npy = this->nplane * this->ny;
195-
this->planxfor = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)c_rspace, embed, npy, 1,
196-
(fftw_complex *)c_rspace, embed, npy, 1, FFTW_FORWARD, FFTW_MEASURE );
197-
this->planxbac = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)c_rspace, embed, npy, 1,
198-
(fftw_complex *)c_rspace, embed, npy, 1, FFTW_BACKWARD, FFTW_MEASURE );
199-
this->planyfor = fftw_plan_many_dft( 1, &this->ny, this->nplane, (fftw_complex*)c_rspace , embed, this->nplane, 1,
200-
(fftw_complex*)c_rspace, embed, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE );
201-
this->planybac = fftw_plan_many_dft( 1, &this->ny, this->nplane, (fftw_complex*)c_rspace , embed, this->nplane, 1,
202-
(fftw_complex*)c_rspace, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE );
173+
this->planxfor = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)aux2, embed, npy, 1,
174+
(fftw_complex *)aux2, embed, npy, 1, FFTW_FORWARD, FFTW_MEASURE );
175+
this->planxbac = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)aux2, embed, npy, 1,
176+
(fftw_complex *)aux2, embed, npy, 1, FFTW_BACKWARD, FFTW_MEASURE );
177+
this->planyfor = fftw_plan_many_dft( 1, &this->ny, this->nplane, (fftw_complex*)aux2 , embed, this->nplane, 1,
178+
(fftw_complex*)aux2, embed, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE );
179+
this->planybac = fftw_plan_many_dft( 1, &this->ny, this->nplane, (fftw_complex*)aux2 , embed, this->nplane, 1,
180+
(fftw_complex*)aux2, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE );
203181
}
204182

205183

@@ -219,12 +197,12 @@ void FFT :: initplanf()
219197
// fftw_complex *out, const int *onembed, int ostride, int odist, int sign, unsigned flags);
220198

221199
this->planf1for = fftwf_plan_many_dft( 1, &this->nz, this->ns,
222-
(fftwf_complex*)c_gspace, &this->nz, 1, this->nz,
223-
(fftwf_complex*)c_gspace, &this->nz, 1, this->nz, FFTW_FORWARD, FFTW_MEASURE);
200+
(fftwf_complex*)aux1, &this->nz, 1, this->nz,
201+
(fftwf_complex*)aux1, &this->nz, 1, this->nz, FFTW_FORWARD, FFTW_MEASURE);
224202

225203
this->planf1bac = fftwf_plan_many_dft( 1, &this->nz, this->ns,
226-
(fftwf_complex*)c_gspace, &this->nz, 1, this->nz,
227-
(fftwf_complex*)c_gspace, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
204+
(fftwf_complex*)aux1, &this->nz, 1, this->nz,
205+
(fftwf_complex*)aux1, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
228206

229207

230208
//---------------------------------------------------------
@@ -237,21 +215,21 @@ void FFT :: initplanf()
237215
{
238216
this->planf2r2c = fftwf_plan_many_dft_r2c( 2, nrank, this->nplane,
239217
r_rspace, nrank, this->nplane, 1,
240-
(fftwf_complex*)c_rspace, nrank, this->nplane, 1, FFTW_MEASURE);
218+
(fftwf_complex*)aux2, nrank, this->nplane, 1, FFTW_MEASURE);
241219

242220
this->planf2c2r = fftwf_plan_many_dft_c2r( 2, nrank, this->nplane,
243-
(fftwf_complex*)c_rspace, nrank, this->nplane, 1,
221+
(fftwf_complex*)aux2, nrank, this->nplane, 1,
244222
r_rspace, nrank, this->nplane, 1, FFTW_MEASURE);
245223
}
246224
else
247225
{
248226
this->planf2for = fftwf_plan_many_dft( 2, nrank, this->nplane,
249-
(fftwf_complex*)c_rspace, nrank, this->nplane, 1,
250-
(fftwf_complex*)c_rspace, nrank, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE);
227+
(fftwf_complex*)aux2, nrank, this->nplane, 1,
228+
(fftwf_complex*)aux2, nrank, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE);
251229

252230
this->planf2bac = fftwf_plan_many_dft( 2, nrank, this->nplane,
253-
(fftwf_complex*)c_rspace, nrank, this->nplane, 1,
254-
(fftwf_complex*)c_rspace, nrank, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE);
231+
(fftwf_complex*)aux2, nrank, this->nplane, 1,
232+
(fftwf_complex*)aux2, nrank, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE);
255233
}
256234
destroypf = false;
257235
}
@@ -322,7 +300,7 @@ void FFT::fftzbac(std::complex<double>* & in, std::complex<double>* & out)
322300
{
323301
// for(int i = 0 ; i < this->ns ; ++i)
324302
// {
325-
// fftw_execute_dft(this->planzbac,(fftw_complex *)&c_gspace[i*nz],(fftw_complex *)&c_gspace[i*nz]);
303+
// fftw_execute_dft(this->planzbac,(fftw_complex *)&aux1[i*nz],(fftw_complex *)&aux1[i*nz]);
326304
// }
327305
fftw_execute_dft(this->planzbac,(fftw_complex *)in, (fftw_complex *)out);
328306
return;
@@ -366,13 +344,15 @@ void FFT::fftxyr2c(double* &in, std::complex<double>* & out)
366344
// int padnpy = this->nplane * this-> ny * 2;
367345
for (int i=0; i<this->nx;++i)
368346
{
369-
fftw_execute_dft_r2c( this->planyr2c, &in[i*bignpy], (fftw_complex*)&out[i*bignpy] );
370-
// fftw_execute_dft_r2c( this->planyfor, &r_rspace[4*i*padnpy], (fftw_complex*)&c_rspace[i*padnpy] );
347+
fftw_execute_dft_r2c( this->planyr2c, &in[i*bignpy*2], (fftw_complex*)&out[i*bignpy] );
348+
if((double*)&out[i*bignpy] != &in[i*bignpy*2] ) std::cout<<i<<"wrond\n";
349+
// fftw_execute_dft_r2c( this->planyfor, &r_rspace[4*i*padnpy], (fftw_complex*)&aux2[i*padnpy] );
371350
}
372351
fftw_execute_dft( this->planxfor, (fftw_complex *)out, (fftw_complex *)out);
373352
return;
374353
}
375354

355+
376356
void FFT::fftxyc2r(std::complex<double>* & in, double* & out)
377357
{
378358
//int npy = this->nplane * this-> ny;
@@ -381,8 +361,8 @@ void FFT::fftxyc2r(std::complex<double>* & in, double* & out)
381361
fftw_execute_dft( this->planxbac, (fftw_complex *)in, (fftw_complex *)in);
382362
for (int i=0; i<this->nx;++i)
383363
{
384-
fftw_execute_dft_c2r( this->planyc2r, (fftw_complex*)&in[i*bignpy], &out[i*bignpy] );
385-
// fftw_execute_dft_c2r( this->planybac, (fftw_complex*)&c_rspace[i*padnpy], &r_rspace[4*i*padnpy] );
364+
fftw_execute_dft_c2r( this->planyc2r, (fftw_complex*)&in[i*bignpy], &out[i*bignpy*2] );
365+
// fftw_execute_dft_c2r( this->planybac, (fftw_complex*)&aux2[i*padnpy], &r_rspace[4*i*padnpy] );
386366
}
387367
return;
388368
}

source/module_pw/fft.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ class FFT
6161
int nplane; //number of x-y planes
6262
int maxgrids; // max between nz * ns and bignxy * nplane
6363
int nproc; // number of proc.
64-
std::complex<double> * c_gspace, *c_gspace2; //complex number space for g, [ns * nz]
65-
std::complex<double> * c_rspace, *c_rspace2;//complex number space for r, [nplane * nx *ny]
64+
std::complex<double> *aux1, *aux2; //fft space, [maxgrids]
6665
double *r_rspace; //real number space for r, [nplane * nx *ny]
6766
#ifdef __MIX_PRECISION
6867
std::complex<float> * cf_gspace; //complex number space for g, [ns * nz]

0 commit comments

Comments
 (0)