Skip to content

Commit 4bcbf5b

Browse files
committed
< prof >
change in-place fft to out-of-place fft for fftz
1 parent 1c58b1f commit 4bcbf5b

File tree

3 files changed

+27
-21
lines changed

3 files changed

+27
-21
lines changed

source/module_pw/fft.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,9 @@ void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in,
5454
if(!this->mpifft)
5555
{
5656
//In-place fft seemed faster in general, but the z-direction fft (stride = 1) is now done out-of-place; see initplan().
57-
if(this->nproc == 1) c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
58-
else c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
57+
// if(this->nproc == 1) c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
58+
// else c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
59+
c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
5960
//c_gspace2 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
6061
if(this->gamma_only)
6162
{
@@ -70,9 +71,11 @@ void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in,
7071
}
7172
else
7273
{
73-
if(this->nproc == 1) c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
74-
else c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
74+
// if(this->nproc == 1) c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
75+
// else c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
76+
c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
7577
}
78+
c_gspace2 = c_rspace;
7679
#ifdef __MIX_PRECISION
7780
cf_gspace = (std::complex<float> *)fftw_malloc(sizeof(fftwf_complex) * this->nz * this->ns);
7881
cf_rspace = (std::complex<float> *)fftw_malloc(sizeof(fftwf_complex) * this->bignxy * nplane);
@@ -120,13 +123,14 @@ void FFT :: initplan()
120123
// fftw_complex *in, const int *inembed, int istride, int idist,
121124
// fftw_complex *out, const int *onembed, int ostride, int odist, int sign, unsigned flags);
122125

126+
//It is better to use out-of-place fft for stride = 1.
123127
this->planzfor = fftw_plan_many_dft( 1, &this->nz, this->ns,
124128
(fftw_complex*) c_gspace, &this->nz, 1, this->nz,
125-
(fftw_complex*) c_gspace, &this->nz, 1, this->nz, FFTW_FORWARD, FFTW_MEASURE);
129+
(fftw_complex*) c_gspace2, &this->nz, 1, this->nz, FFTW_FORWARD, FFTW_MEASURE);
126130

127131
this->planzbac = fftw_plan_many_dft( 1, &this->nz, this->ns,
128132
(fftw_complex*) c_gspace, &this->nz, 1, this->nz,
129-
(fftw_complex*) c_gspace, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
133+
(fftw_complex*) c_gspace2, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
130134

131135
// this->planzfor = fftw_plan_dft_1d(this->nz,(fftw_complex*) c_gspace,(fftw_complex*) c_gspace, FFTW_FORWARD, FFTW_MEASURE);
132136
// this->planzbac = fftw_plan_dft_1d(this->nz,(fftw_complex*) c_gspace,(fftw_complex*) c_gspace,FFTW_BACKWARD, FFTW_MEASURE);
@@ -183,6 +187,8 @@ void FFT :: initplan()
183187
// this->plan2bac = fftw_plan_many_dft( 2, nrank, this->nplane,
184188
// (fftw_complex*) c_rspace, embed, this->nplane, 1,
185189
// (fftw_complex*) c_rspace, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE);
190+
191+
//It is better to use in-place fft for stride > 1.
186192
int npy = this->nplane * this->ny;
187193
this->planxfor = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)c_rspace, embed, npy, 1,
188194
(fftw_complex *)c_rspace, embed, npy, 1, FFTW_FORWARD, FFTW_MEASURE );

source/module_pw/pw_distributeg_method1.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ void PW_Basis::divide_sticks(
312312
// find the ip of core containing the least planewaves.
313313
for (int ip = 0; ip < this->poolnproc; ++ip)
314314
{
315-
const int non_zero_grid = nst_per[ip] * this->nz; // number of reciprocal planewaves on this core.
315+
//const int non_zero_grid = nst_per[ip] * this->nz; // number of reciprocal planewaves on this core.
316316
const int npwmin = npw_per[ipmin];
317317
const int npw_ip = npw_per[ip];
318318
const int nstmin = nst_per[ipmin];

source/module_pw/pw_transform.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ void PW_Basis:: real2recip(std::complex<double> * in, std::complex<double> * out
1717
assert(this->gamma_only == false);
1818
for(int ir = 0 ; ir < this->nrxx ; ++ir)
1919
{
20-
this->ft.c_rspace[ir] = in[ir];
20+
this->ft.c_gspace[ir] = in[ir];
2121
}
22-
this->ft.fftxyfor(ft.c_rspace,ft.c_rspace);
22+
this->ft.fftxyfor(ft.c_gspace,ft.c_gspace);
2323

24-
this->gatherp_scatters2(this->ft.c_rspace, this->ft.c_gspace);
24+
this->gatherp_scatters2(this->ft.c_gspace, this->ft.c_gspace2);
2525

26-
this->ft.fftzfor(ft.c_gspace,ft.c_gspace);
26+
this->ft.fftzfor(ft.c_gspace2,ft.c_gspace);
2727

2828
for(int ig = 0 ; ig < this->npw ; ++ig)
2929
{
@@ -58,11 +58,11 @@ void PW_Basis:: real2recip(double * in, std::complex<double> * out)
5858
// }
5959
// }
6060

61-
this->ft.fftxyr2c(ft.r_rspace,ft.c_rspace);
61+
this->ft.fftxyr2c(ft.r_rspace,ft.c_gspace);
6262

63-
this->gatherp_scatters2(this->ft.c_rspace, this->ft.c_gspace);
63+
this->gatherp_scatters2(this->ft.c_gspace, this->ft.c_gspace2);
6464

65-
this->ft.fftzfor(ft.c_gspace,ft.c_gspace);
65+
this->ft.fftzfor(ft.c_gspace2,ft.c_gspace);
6666

6767
for(int ig = 0 ; ig < this->npw ; ++ig)
6868
{
@@ -85,15 +85,15 @@ void PW_Basis:: recip2real(std::complex<double> * in, std::complex<double> * out
8585
{
8686
this->ft.c_gspace[this->ig2isz[ig]] = in[ig];
8787
}
88-
this->ft.fftzbac(ft.c_gspace, ft.c_gspace);
88+
this->ft.fftzbac(ft.c_gspace, ft.c_gspace2);
8989

90-
this->gathers_scatterp2(this->ft.c_gspace,this->ft.c_rspace);
90+
this->gathers_scatterp2(this->ft.c_gspace2,this->ft.c_gspace);
9191

92-
this->ft.fftxybac(ft.c_rspace,ft.c_rspace);
92+
this->ft.fftxybac(ft.c_gspace,ft.c_gspace);
9393

9494
for(int ir = 0 ; ir < this->nrxx ; ++ir)
9595
{
96-
out[ir] = this->ft.c_rspace[ir] / double(this->bignxyz);
96+
out[ir] = this->ft.c_gspace[ir] / double(this->bignxyz);
9797
}
9898

9999
return;
@@ -113,11 +113,11 @@ void PW_Basis:: recip2real(std::complex<double> * in, double * out)
113113
{
114114
this->ft.c_gspace[this->ig2isz[ig]] = in[ig];
115115
}
116-
this->ft.fftzbac(ft.c_gspace, ft.c_gspace);
116+
this->ft.fftzbac(ft.c_gspace, ft.c_gspace2);
117117

118-
this->gathers_scatterp2(this->ft.c_gspace, this->ft.c_rspace);
118+
this->gathers_scatterp2(this->ft.c_gspace2, this->ft.c_gspace);
119119

120-
this->ft.fftxyc2r(ft.c_rspace,ft.r_rspace);
120+
this->ft.fftxyc2r(ft.c_gspace,ft.r_rspace);
121121

122122
for(int ir = 0 ; ir < this->nrxx ; ++ir)
123123
{

0 commit comments

Comments
 (0)