Skip to content

Commit 60f026d

Browse files
authored
Merge pull request #505 from Qianruipku/planewave
change in-place fft to out-of-place
2 parents 528cea6 + 3e115aa commit 60f026d

File tree

4 files changed

+27
-69
lines changed

4 files changed

+27
-69
lines changed

source/module_pw/fft.cpp

Lines changed: 14 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,14 @@ FFT::FFT()
1111
nplane = 1;
1212
mpifft = false;
1313
destroyp = true;
14-
nfftgroup = 1;
1514
gamma_only = false;
1615
c_rspace = c_gspace = NULL;
16+
c_rspace2 = c_gspace2 = NULL;
1717
r_rspace = NULL;
18-
//r_gspace = NULL
1918
#ifdef __MIX_PRECISION
2019
destroypf = true;
2120
cf_rspace = cf_gspace = NULL;
2221
rf_rspace = NULL;
23-
//rf_gspace = NULL
2422
#endif
2523
}
2624

@@ -29,17 +27,17 @@ FFT::~FFT()
2927
this->cleanFFT();
3028
if(c_gspace!=NULL) fftw_free(c_gspace);
3129
if(c_rspace!=NULL) fftw_free(c_rspace);
32-
//if(r_gspace!=NULL) fftw_free(r_gspace);
30+
if(c_gspace2!=NULL) fftw_free(c_gspace2);
31+
if(c_rspace2!=NULL) fftw_free(c_rspace2);
3332
if(r_rspace!=NULL) fftw_free(r_rspace);
3433
#ifdef __MIX_PRECISION
3534
if(cf_gspace!=NULL) fftw_free(cf_gspace);
3635
if(cf_rspace!=NULL) fftw_free(cf_rspace);
37-
//if(rf_gspace!=NULL) fftw_free(rf_gspace);
3836
if(rf_rspace!=NULL) fftw_free(rf_rspace);
3937
#endif
4038
}
4139

42-
void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in, int nfftgroup_in, bool gamma_only_in, bool mpifft_in)
40+
void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in, bool gamma_only_in, bool mpifft_in)
4341
{
4442
this->gamma_only = gamma_only_in;
4543
this->nx = nx_in;
@@ -52,19 +50,22 @@ void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in,
5250
this->mpifft = mpifft_in;
5351
this->nxy = this->nx * this-> ny;
5452
this->bignxy = this->nx * this->bigny;
55-
this->nfftgroup = nfftgroup_in;
5653
if(!this->mpifft)
5754
{
55+
//out-of-place fft is faster than in-place fft
5856
c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
59-
//r_gspace = (double *) fftw_malloc(sizeof(double) * this->nz * this->ns);
57+
c_gspace2 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
6058
c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
6159
if(this->gamma_only)
6260
{
6361
r_rspace = (double *) fftw_malloc(sizeof(double) * this->bignxy * nplane);
6462
}
63+
else
64+
{
65+
c_rspace2 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
66+
}
6567
#ifdef __MIX_PRECISION
6668
cf_gspace = (std::complex<float> *)fftw_malloc(sizeof(fftwf_complex) * this->nz * this->ns);
67-
//rf_gspace = (float *)fftw_malloc(sizeof(float) * this->nz * this->ns);
6869
cf_rspace = (std::complex<float> *)fftw_malloc(sizeof(fftwf_complex) * this->bignxy * nplane);
6970
if(this->gamma_only)
7071
{
@@ -112,19 +113,12 @@ void FFT :: initplan()
112113

113114
this->plan1for = fftw_plan_many_dft( 1, &this->nz, this->ns,
114115
(fftw_complex*) c_gspace, &this->nz, 1, this->nz,
115-
(fftw_complex*) c_gspace, &this->nz, 1, this->nz, FFTW_FORWARD, FFTW_MEASURE);
116+
(fftw_complex*) c_gspace2, &this->nz, 1, this->nz, FFTW_FORWARD, FFTW_MEASURE);
116117

117118
this->plan1bac = fftw_plan_many_dft( 1, &this->nz, this->ns,
118119
(fftw_complex*) c_gspace, &this->nz, 1, this->nz,
119-
(fftw_complex*) c_gspace, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
120+
(fftw_complex*) c_gspace2, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
120121

121-
// this->plan1r2c = fftw_plan_many_dft_r2c( 1, &this->nz, this->ns,
122-
// r_gspace, &this->nz, 1, this->nz,
123-
// (fftw_complex*) c_gspace, &this->nz, 1, this->nz, FFTW_MEASURE);
124-
125-
// this->plan1c2r = fftw_plan_many_dft_c2r( 1, &this->nz, this->ns,
126-
// (fftw_complex*) c_gspace, &this->nz, 1, this->nz,
127-
// r_gspace, &this->nz, 1, this->nz, FFTW_MEASURE);
128122

129123
//---------------------------------------------------------
130124
// 2 D
@@ -134,11 +128,11 @@ void FFT :: initplan()
134128
int *embed = NULL;
135129
this->plan2for = fftw_plan_many_dft( 2, nrank, this->nplane,
136130
(fftw_complex*) c_rspace, embed, this->nplane, 1,
137-
(fftw_complex*) c_rspace, embed, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE);
131+
(fftw_complex*) c_rspace2, embed, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE);
138132

139133
this->plan2bac = fftw_plan_many_dft( 2, nrank, this->nplane,
140134
(fftw_complex*) c_rspace, embed, this->nplane, 1,
141-
(fftw_complex*) c_rspace, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE);
135+
(fftw_complex*) c_rspace2, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE);
142136

143137
this->plan2r2c = fftw_plan_many_dft_r2c( 2, nrank, this->nplane,
144138
r_rspace, embed, this->nplane, 1,
@@ -169,13 +163,6 @@ void FFT :: initplanf()
169163
(fftwf_complex*)c_gspace, &this->nz, 1, this->nz,
170164
(fftwf_complex*)c_gspace, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
171165

172-
// this->planf1r2c = fftwf_plan_many_dft_r2c( 1, &this->nz, this->ns,
173-
// r_gspace, &this->nz, 1, this->nz,
174-
// (fftwf_complex*)c_gspace, &this->nz, 1, this->nz, FFTW_MEASURE);
175-
176-
// this->planf1c2r = fftwf_plan_many_dft_c2r( 1, &this->nz, this->ns,
177-
// (fftwf_complex*)c_gspace, &this->nz, 1, this->nz,
178-
// r_gspace, &this->nz, 1, this->nz, FFTW_MEASURE);
179166

180167
//---------------------------------------------------------
181168
// 2 D
@@ -218,8 +205,6 @@ void FFT:: cleanFFT()
218205
if(destroyp==true) return;
219206
fftw_destroy_plan(plan1for);
220207
fftw_destroy_plan(plan1bac);
221-
// fftw_destroy_plan(plan1r2c);
222-
// fftw_destroy_plan(plan1c2r);
223208
fftw_destroy_plan(plan2for);
224209
fftw_destroy_plan(plan2bac);
225210
fftw_destroy_plan(plan2r2c);
@@ -230,8 +215,6 @@ void FFT:: cleanFFT()
230215
if(destroypf==true) return;
231216
fftw_destroy_plan(planf1for);
232217
fftw_destroy_plan(planf1bac);
233-
// fftw_destroy_plan(planf1r2c);
234-
// fftw_destroy_plan(planf1c2r);
235218
fftw_destroy_plan(planf2for);
236219
fftw_destroy_plan(planf2bac);
237220
fftw_destroy_plan(planf2r2c);
@@ -252,12 +235,8 @@ void FFT::executefftw(std::string instr)
252235
fftw_execute(this->plan1bac);
253236
else if(instr == "2bac")
254237
fftw_execute(this->plan2bac);
255-
// else if(instr == "1r2c")
256-
// fftw_execute(this->plan1r2c);
257238
else if(instr == "2r2c")
258239
fftw_execute(this->plan2r2c);
259-
// else if(instr == "1c2r")
260-
// fftw_execute_dft(this->plan1c2r);
261240
else if(instr == "2c2r")
262241
fftw_execute(this->plan2c2r);
263242
else
@@ -277,12 +256,8 @@ void executefftwf(std::string instr)
277256
fftwf_execute(this->planf1bac);
278257
else if(instr == "2bac")
279258
fftwf_execute(this->planf2bac);
280-
// else if(instr == "1r2c")
281-
// fftwf_execute(this->planf1r2c);
282259
else if(instr == "2r2c")
283260
fftwf_execute(this->planf2r2c);
284-
// else if(instr == "1c2r")
285-
// fftwf_execute(this->planf1c2r);
286261
else if(instr == "2c2r")
287262
fftwf_execute(this->planf2c2r);
288263
else

source/module_pw/fft.h

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class FFT
2727

2828
FFT();
2929
~FFT();
30-
void initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in, int nfft_in, bool gamma_only_in, bool mpifft_in = false);
30+
void initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in, bool gamma_only_in, bool mpifft_in = false);
3131
void setupFFT();
3232
void cleanFFT();
3333

@@ -51,15 +51,12 @@ class FFT
5151
int bignxy;
5252
int ns; //number of sticks
5353
int nplane; //number of x-y planes
54-
int nfftgroup; // number of fft in a group
55-
std::complex<double> * c_gspace; //complex number space for g, [ns * nz]
56-
std::complex<double> * c_rspace; //complex number space for r, [nplane * nx *ny]
57-
//double *r_gspace; //real number space for g, [ns * nz]
54+
std::complex<double> * c_gspace, *c_gspace2; //complex number space for g, [ns * nz]
55+
std::complex<double> * c_rspace, *c_rspace2;//complex number space for r, [nplane * nx *ny]
5856
double *r_rspace; //real number space for r, [nplane * nx *ny]
5957
#ifdef __MIX_PRECISION
6058
std::complex<float> * cf_gspace; //complex number space for g, [ns * nz]
6159
std::complex<float> * cf_rspace; //complex number space for r, [nplane * nx *ny]
62-
//float *rf_gspace; //real number space for g, [ns * nz]
6360
float *rf_rspace; //real number space for r, [nplane * nx *ny]
6461
#endif
6562

@@ -68,8 +65,6 @@ class FFT
6865
bool gamma_only;
6966
bool destroyp;
7067
bool mpifft; // if use mpi fft, only used when define __FFTW3_MPI
71-
//fftw_plan plan1_r2c;
72-
//fftw_plan plan1_c2r;
7368
fftw_plan plan2r2c;
7469
fftw_plan plan2c2r;
7570
fftw_plan plan1for;
@@ -78,8 +73,6 @@ class FFT
7873
fftw_plan plan2bac;
7974
#ifdef __MIX_PRECISION
8075
bool destroypf;
81-
//fftwf_plan planf1_r2c;
82-
//fftwf_plan planf1_c2r;
8376
fftwf_plan planf2r2c;
8477
fftwf_plan planf2c2r;
8578
fftwf_plan planf1for;

source/module_pw/pw_basis.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,8 @@ namespace ModulePW
88
PW_Basis::PW_Basis()
99
{
1010
ig2isz = NULL;
11-
istot2ixy = NULL;
12-
//istot2ixy = NULL;
11+
istot2ixy = NULL;
1312
ixy2istot = NULL;
14-
//is2ixy = NULL;
1513
is2ixy = NULL;
1614
ixy2ip = NULL;
1715
startnsz_per = NULL;
@@ -28,10 +26,8 @@ PW_Basis::PW_Basis()
2826
PW_Basis:: ~PW_Basis()
2927
{
3028
if(ig2isz != NULL) delete[] ig2isz;
31-
//if(istot2ixy != NULL) delete[] istot2ixy;
3229
if(istot2ixy != NULL) delete[] istot2ixy;
3330
if(ixy2istot != NULL) delete[] ixy2istot;
34-
//if(is2ixy != NULL) delete[] is2ixy;
3531
if(is2ixy != NULL) delete[] is2ixy;
3632
if(ixy2ip != NULL) delete[] ixy2ip;
3733
if(startnsz_per != NULL) delete[] startnsz_per;
@@ -47,7 +43,7 @@ void PW_Basis::setuptransform()
4743
{
4844
this->distribute_r();
4945
this->distribute_g();
50-
this->ft.initfft(this->nx,this->bigny,this->nz,this->nst,this->nplane,1,this->gamma_only);
46+
this->ft.initfft(this->nx,this->bigny,this->nz,this->nst,this->nplane,this->gamma_only);
5147
this->ft.setupFFT();
5248
}
5349

source/module_pw/pw_transform.cpp

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ void PW_Basis:: real2recip(std::complex<double> * in, std::complex<double> * out
2121
}
2222
this->ft.executefftw("2for");
2323

24-
this->gatherp_scatters(this->ft.c_rspace, this->ft.c_gspace);
24+
this->gatherp_scatters(this->ft.c_rspace2, this->ft.c_gspace);
2525

2626
this->ft.executefftw("1for");
2727

2828
for(int ig = 0 ; ig < this->npw ; ++ig)
2929
{
30-
out[ig] = this->ft.c_gspace[this->ig2isz[ig]];
30+
out[ig] = this->ft.c_gspace2[this->ig2isz[ig]];
3131
}
3232
return;
3333
}
@@ -52,7 +52,7 @@ void PW_Basis:: real2recip(double * in, std::complex<double> * out)
5252

5353
for(int ig = 0 ; ig < this->npw ; ++ig)
5454
{
55-
out[ig] = this->ft.c_gspace[this->ig2isz[ig]];
55+
out[ig] = this->ft.c_gspace2[this->ig2isz[ig]];
5656
}
5757
return;
5858
}
@@ -66,23 +66,20 @@ void PW_Basis:: recip2real(std::complex<double> * in, std::complex<double> * out
6666
{
6767
assert(this->gamma_only == false);
6868
ModuleBase::GlobalFunc::ZEROS(ft.c_gspace, this->nst * this->nz);
69-
// for(int igg = 0 ; igg < this->nst * this->nz ; ++igg)
70-
// {
71-
// this->ft.c_gspace[igg] = 0.0;
72-
// }
69+
7370
for(int ig = 0 ; ig < this->npw ; ++ig)
7471
{
7572
this->ft.c_gspace[this->ig2isz[ig]] = in[ig];
7673
}
7774
this->ft.executefftw("1bac");
7875

79-
this->gathers_scatterp(this->ft.c_gspace,this->ft.c_rspace);
76+
this->gathers_scatterp(this->ft.c_gspace2,this->ft.c_rspace);
8077

8178
this->ft.executefftw("2bac");
8279

8380
for(int ir = 0 ; ir < this->nrxx ; ++ir)
8481
{
85-
out[ir] = this->ft.c_rspace[ir] / this->bignxyz;
82+
out[ir] = this->ft.c_rspace2[ir] / this->bignxyz;
8683
}
8784
return;
8885
}
@@ -96,17 +93,14 @@ void PW_Basis:: recip2real(std::complex<double> * in, double * out)
9693
{
9794
assert(this->gamma_only == true);
9895
ModuleBase::GlobalFunc::ZEROS(ft.c_gspace, this->nst * this->nz);
99-
// for(int igg = 0 ; igg < this->nst * this->nz ; ++igg)
100-
// {
101-
// this->ft.c_gspace[igg] = 0.0;
102-
// }
96+
10397
for(int ig = 0 ; ig < this->npw ; ++ig)
10498
{
10599
this->ft.c_gspace[this->ig2isz[ig]] = in[ig];
106100
}
107101
this->ft.executefftw("1bac");
108102

109-
this->gathers_scatterp(this->ft.c_gspace, this->ft.c_rspace);
103+
this->gathers_scatterp(this->ft.c_gspace2, this->ft.c_rspace);
110104

111105
this->ft.executefftw("2c2r");
112106

0 commit comments

Comments
 (0)