Skip to content

Commit 85e2c05

Browse files
authored
Merge pull request #542 from Qianruipku/planewave
optimize the program
2 parents 93cf916 + 4bcbf5b commit 85e2c05

File tree

7 files changed

+96
-99
lines changed

7 files changed

+96
-99
lines changed

source/module_pw/fft.cpp

Lines changed: 73 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include "fft.h"
2-
#include "../module_base/tool_quit.h"
32
namespace ModulePW
43
{
54

@@ -55,8 +54,9 @@ void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in,
5554
if(!this->mpifft)
5655
{
5756
//It seems in-place fft is faster than out-of-place fft
58-
if(this->nproc == 1) c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
59-
else c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
57+
// if(this->nproc == 1) c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
58+
// else c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
59+
c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
6060
//c_gspace2 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
6161
if(this->gamma_only)
6262
{
@@ -71,9 +71,11 @@ void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in,
7171
}
7272
else
7373
{
74-
if(this->nproc == 1) c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
75-
else c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
74+
// if(this->nproc == 1) c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
75+
// else c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
76+
c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
7677
}
78+
c_gspace2 = c_rspace;
7779
#ifdef __MIX_PRECISION
7880
cf_gspace = (std::complex<float> *)fftw_malloc(sizeof(fftwf_complex) * this->nz * this->ns);
7981
cf_rspace = (std::complex<float> *)fftw_malloc(sizeof(fftwf_complex) * this->bignxy * nplane);
@@ -121,13 +123,14 @@ void FFT :: initplan()
121123
// fftw_complex *in, const int *inembed, int istride, int idist,
122124
// fftw_complex *out, const int *onembed, int ostride, int odist, int sign, unsigned flags);
123125

126+
//It is better to use out-of-place fft for stride = 1.
124127
this->planzfor = fftw_plan_many_dft( 1, &this->nz, this->ns,
125128
(fftw_complex*) c_gspace, &this->nz, 1, this->nz,
126-
(fftw_complex*) c_gspace, &this->nz, 1, this->nz, FFTW_FORWARD, FFTW_MEASURE);
129+
(fftw_complex*) c_gspace2, &this->nz, 1, this->nz, FFTW_FORWARD, FFTW_MEASURE);
127130

128131
this->planzbac = fftw_plan_many_dft( 1, &this->nz, this->ns,
129132
(fftw_complex*) c_gspace, &this->nz, 1, this->nz,
130-
(fftw_complex*) c_gspace, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
133+
(fftw_complex*) c_gspace2, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
131134

132135
// this->planzfor = fftw_plan_dft_1d(this->nz,(fftw_complex*) c_gspace,(fftw_complex*) c_gspace, FFTW_FORWARD, FFTW_MEASURE);
133136
// this->planzbac = fftw_plan_dft_1d(this->nz,(fftw_complex*) c_gspace,(fftw_complex*) c_gspace,FFTW_BACKWARD, FFTW_MEASURE);
@@ -136,7 +139,7 @@ void FFT :: initplan()
136139
// 2 D
137140
//---------------------------------------------------------
138141

139-
int nrank[2] = {this->nx,this->bigny};
142+
//int nrank[2] = {this->nx,this->bigny};
140143
int *embed = NULL;
141144
if(this->gamma_only)
142145
{
@@ -184,6 +187,8 @@ void FFT :: initplan()
184187
// this->plan2bac = fftw_plan_many_dft( 2, nrank, this->nplane,
185188
// (fftw_complex*) c_rspace, embed, this->nplane, 1,
186189
// (fftw_complex*) c_rspace, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE);
190+
191+
//It is better to use in-place for stride > 1
187192
int npy = this->nplane * this->ny;
188193
this->planxfor = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)c_rspace, embed, npy, 1,
189194
(fftw_complex *)c_rspace, embed, npy, 1, FFTW_FORWARD, FFTW_MEASURE );
@@ -301,80 +306,78 @@ void FFT:: cleanFFT()
301306
return;
302307
}
303308

304-
void FFT::executefftw(std::string instr)
309+
void FFT::fftzfor(std::complex<double>* & in, std::complex<double>* & out)
305310
{
306-
if(instr == "1for")
307-
{
308-
// for(int i = 0 ; i < this->ns ; ++i)
309-
// {
310-
// fftw_execute_dft(this->planzfor,(fftw_complex *)&c_gspace[i*nz],(fftw_complex *)&c_gspace[i*nz]);
311-
// }
312-
fftw_execute_dft(this->planzfor,(fftw_complex *)c_gspace,(fftw_complex *)c_gspace);
313-
// fftw_execute(this->planzfor);
314-
}
315-
else if(instr == "1bac")
316-
{
317-
// for(int i = 0 ; i < this->ns ; ++i)
318-
// {
319-
// fftw_execute_dft(this->planzbac,(fftw_complex *)&c_gspace[i*nz],(fftw_complex *)&c_gspace[i*nz]);
320-
// }
321-
fftw_execute_dft(this->planzbac,(fftw_complex *)c_gspace,(fftw_complex *)c_gspace);
322-
// fftw_execute(this->planzbac);
323-
}
324-
else if(instr == "2for")
311+
// for(int i = 0 ; i < this->ns ; ++i)
312+
// {
313+
// fftw_execute_dft(this->planzfor,(fftw_complex *)&in[i*nz],(fftw_complex *)&out[i*nz]);
314+
// }
315+
fftw_execute_dft(this->planzfor,(fftw_complex *)in,(fftw_complex *)out);
316+
return;
317+
}
318+
319+
void FFT::fftzbac(std::complex<double>* & in, std::complex<double>* & out)
320+
{
321+
// for(int i = 0 ; i < this->ns ; ++i)
322+
// {
323+
// fftw_execute_dft(this->planzbac,(fftw_complex *)&c_gspace[i*nz],(fftw_complex *)&c_gspace[i*nz]);
324+
// }
325+
fftw_execute_dft(this->planzbac,(fftw_complex *)in, (fftw_complex *)out);
326+
return;
327+
}
328+
329+
void FFT::fftxyfor(std::complex<double>* & in, std::complex<double>* & out)
330+
{
331+
int npy = this->nplane * this-> ny;
332+
fftw_execute_dft( this->planxfor, (fftw_complex *)in, (fftw_complex *)out);
333+
for (int i=0; i<this->nx;++i)
325334
{
326-
int npy = this->nplane * this-> ny;
327-
fftw_execute_dft( this->planxfor, (fftw_complex *)c_rspace, (fftw_complex *)c_rspace);
328-
for (int i=0; i<this->nx;++i)
329-
{
330-
fftw_execute_dft( this->planyfor, (fftw_complex*)&c_rspace[i*npy], (fftw_complex*)&c_rspace[i*npy] );
331-
}
332-
// fftw_execute(this->plan2for);
335+
fftw_execute_dft( this->planyfor, (fftw_complex*)&in[i*npy], (fftw_complex*)&out[i*npy] );
333336
}
334-
else if(instr == "2bac")
335-
{
336-
// fftw_execute(this->plan2bac);
337-
int npy = this->nplane * this-> ny;
338-
339-
for (int i=0; i<this->nx;++i)
340-
{
341-
fftw_execute_dft( this->planybac, (fftw_complex*)&c_rspace[i*npy], (fftw_complex*)&c_rspace[i*npy] );
342-
}
343-
fftw_execute_dft( this->planxbac, (fftw_complex *)c_rspace, (fftw_complex *)c_rspace);
337+
return;
338+
}
339+
340+
void FFT::fftxybac(std::complex<double>* & in, std::complex<double>* & out)
341+
{
342+
int npy = this->nplane * this-> ny;
344343

345-
}
346-
else if(instr == "2r2c")
344+
for (int i=0; i<this->nx;++i)
347345
{
348-
// fftw_execute(this->plan2r2c);
349-
//int npy = this->nplane * this-> ny;
350-
int bignpy = this->nplane * this-> bigny;
351-
// int padnpy = this->nplane * this-> ny * 2;
352-
for (int i=0; i<this->nx;++i)
353-
{
354-
fftw_execute_dft_r2c( this->planyr2c, &r_rspace[i*bignpy], (fftw_complex*)&c_rspace[i*bignpy] );
355-
// fftw_execute_dft_r2c( this->planyfor, &r_rspace[4*i*padnpy], (fftw_complex*)&c_rspace[i*padnpy] );
356-
}
357-
fftw_execute_dft( this->planxfor, (fftw_complex *)c_rspace, (fftw_complex *)c_rspace);
346+
fftw_execute_dft( this->planybac, (fftw_complex*)&in[i*npy], (fftw_complex*)&out[i*npy] );
358347
}
359-
else if(instr == "2c2r")
348+
fftw_execute_dft( this->planxbac, (fftw_complex *)in, (fftw_complex *)out);
349+
return;
350+
}
351+
352+
void FFT::fftxyr2c(double* &in, std::complex<double>* & out)
353+
{
354+
//int npy = this->nplane * this-> ny;
355+
int bignpy = this->nplane * this-> bigny;
356+
// int padnpy = this->nplane * this-> ny * 2;
357+
for (int i=0; i<this->nx;++i)
360358
{
361-
// fftw_execute(this->plan2c2r);
362-
//int npy = this->nplane * this-> ny;
363-
int bignpy = this->nplane * this-> bigny;
364-
// int padnpy = this->nplane * this-> ny * 2;
365-
fftw_execute_dft( this->planxbac, (fftw_complex *)c_rspace, (fftw_complex *)c_rspace);
366-
for (int i=0; i<this->nx;++i)
367-
{
368-
fftw_execute_dft_c2r( this->planyc2r, (fftw_complex*)&c_rspace[i*bignpy], &r_rspace[i*bignpy] );
369-
// fftw_execute_dft_c2r( this->planybac, (fftw_complex*)&c_rspace[i*padnpy], &r_rspace[4*i*padnpy] );
370-
}
359+
fftw_execute_dft_r2c( this->planyr2c, &in[i*bignpy], (fftw_complex*)&out[i*bignpy] );
360+
// fftw_execute_dft_r2c( this->planyfor, &r_rspace[4*i*padnpy], (fftw_complex*)&c_rspace[i*padnpy] );
371361
}
372-
else
362+
fftw_execute_dft( this->planxfor, (fftw_complex *)out, (fftw_complex *)out);
363+
return;
364+
}
365+
366+
void FFT::fftxyc2r(std::complex<double>* & in, double* & out)
367+
{
368+
//int npy = this->nplane * this-> ny;
369+
int bignpy = this->nplane * this-> bigny;
370+
// int padnpy = this->nplane * this-> ny * 2;
371+
fftw_execute_dft( this->planxbac, (fftw_complex *)in, (fftw_complex *)in);
372+
for (int i=0; i<this->nx;++i)
373373
{
374-
ModuleBase::WARNING_QUIT("FFT", "Wrong input for excutefftw");
374+
fftw_execute_dft_c2r( this->planyc2r, (fftw_complex*)&in[i*bignpy], &out[i*bignpy] );
375+
// fftw_execute_dft_c2r( this->planybac, (fftw_complex*)&c_rspace[i*padnpy], &r_rspace[4*i*padnpy] );
375376
}
377+
return;
376378
}
377379

380+
378381
#ifdef __MIX_PRECISION
379382
void executefftwf(std::string instr)
380383
{
@@ -390,10 +393,6 @@ void executefftwf(std::string instr)
390393
fftwf_execute(this->planf2r2c);
391394
else if(instr == "2c2r")
392395
fftwf_execute(this->planf2c2r);
393-
else
394-
{
395-
ModuleBase::WARNING_QUIT("FFT", "Wrong input for excutefftwf");
396-
}
397396
}
398397
#endif
399398
}

source/module_pw/fft.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@ class FFT
3131
void setupFFT();
3232
void cleanFFT();
3333

34-
void executefftw(std::string instr);
34+
void fftzfor(std::complex<double>* & in, std::complex<double>* & out);
35+
void fftzbac(std::complex<double>* & in, std::complex<double>* & out);
36+
void fftxyfor(std::complex<double>* & in, std::complex<double>* & out);
37+
void fftxybac(std::complex<double>* & in, std::complex<double>* & out);
38+
void fftxyr2c(double * &in, std::complex<double>* & out);
39+
void fftxyc2r(std::complex<double>* & in, double* & out);
40+
3541
#ifdef __MIX_PRECISION
3642
void executefftwf(std::string instr);
3743
#endif

source/module_pw/pw_distributeg_method1.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ void PW_Basis::divide_sticks(
312312
// find the ip of core containing the least planewaves.
313313
for (int ip = 0; ip < this->poolnproc; ++ip)
314314
{
315-
const int non_zero_grid = nst_per[ip] * this->nz; // number of reciprocal planewaves on this core.
315+
//const int non_zero_grid = nst_per[ip] * this->nz; // number of reciprocal planewaves on this core.
316316
const int npwmin = npw_per[ipmin];
317317
const int npw_ip = npw_per[ip];
318318
const int nstmin = nst_per[ipmin];

source/module_pw/pw_distributeg_method2.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ void PW_Basis::distribution_method2()
2727
// initial the variables needed by all proc.
2828
int tot_npw = 0; // total number of planewaves.
2929
this->nstot = 0; // total number of sticks.
30-
int st_start = 0; // index of the first stick on current proc.
30+
// int st_start = 0; // index of the first stick on current proc.
3131
int *st_bottom2D = NULL; // st_bottom2D[ixy], minimum z of stick on (x, y).
3232
int *st_length2D = NULL; // st_length2D[ixy], number of planewaves in stick on (x, y).
3333

source/module_pw/pw_transform.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ void PW_Basis:: real2recip(std::complex<double> * in, std::complex<double> * out
1717
assert(this->gamma_only == false);
1818
for(int ir = 0 ; ir < this->nrxx ; ++ir)
1919
{
20-
this->ft.c_rspace[ir] = in[ir];
20+
this->ft.c_gspace[ir] = in[ir];
2121
}
22-
this->ft.executefftw("2for");
22+
this->ft.fftxyfor(ft.c_gspace,ft.c_gspace);
2323

24-
this->gatherp_scatters2(this->ft.c_rspace, this->ft.c_gspace);
24+
this->gatherp_scatters2(this->ft.c_gspace, this->ft.c_gspace2);
2525

26-
this->ft.executefftw("1for");
26+
this->ft.fftzfor(ft.c_gspace2,ft.c_gspace);
2727

2828
for(int ig = 0 ; ig < this->npw ; ++ig)
2929
{
@@ -58,11 +58,11 @@ void PW_Basis:: real2recip(double * in, std::complex<double> * out)
5858
// }
5959
// }
6060

61-
this->ft.executefftw("2r2c");
61+
this->ft.fftxyr2c(ft.r_rspace,ft.c_gspace);
6262

63-
this->gatherp_scatters2(this->ft.c_rspace, this->ft.c_gspace);
63+
this->gatherp_scatters2(this->ft.c_gspace, this->ft.c_gspace2);
6464

65-
this->ft.executefftw("1for");
65+
this->ft.fftzfor(ft.c_gspace2,ft.c_gspace);
6666

6767
for(int ig = 0 ; ig < this->npw ; ++ig)
6868
{
@@ -85,15 +85,15 @@ void PW_Basis:: recip2real(std::complex<double> * in, std::complex<double> * out
8585
{
8686
this->ft.c_gspace[this->ig2isz[ig]] = in[ig];
8787
}
88-
this->ft.executefftw("1bac");
88+
this->ft.fftzbac(ft.c_gspace, ft.c_gspace2);
8989

90-
this->gathers_scatterp2(this->ft.c_gspace,this->ft.c_rspace);
90+
this->gathers_scatterp2(this->ft.c_gspace2,this->ft.c_gspace);
9191

92-
this->ft.executefftw("2bac");
92+
this->ft.fftxybac(ft.c_gspace,ft.c_gspace);
9393

9494
for(int ir = 0 ; ir < this->nrxx ; ++ir)
9595
{
96-
out[ir] = this->ft.c_rspace[ir] / double(this->bignxyz);
96+
out[ir] = this->ft.c_gspace[ir] / double(this->bignxyz);
9797
}
9898

9999
return;
@@ -113,11 +113,11 @@ void PW_Basis:: recip2real(std::complex<double> * in, double * out)
113113
{
114114
this->ft.c_gspace[this->ig2isz[ig]] = in[ig];
115115
}
116-
this->ft.executefftw("1bac");
116+
this->ft.fftzbac(ft.c_gspace, ft.c_gspace2);
117117

118-
this->gathers_scatterp2(this->ft.c_gspace, this->ft.c_rspace);
118+
this->gathers_scatterp2(this->ft.c_gspace2, this->ft.c_gspace);
119119

120-
this->ft.executefftw("2c2r");
120+
this->ft.fftxyc2r(ft.c_gspace,ft.r_rspace);
121121

122122
for(int ir = 0 ; ir < this->nrxx ; ++ir)
123123
{

source/module_pw/unittest/test_t1.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,7 @@ int main(int argc,char **argv)
5151
nx = pwtest.nx;
5252
ny = pwtest.ny;
5353
nz = pwtest.nz;
54-
int nplane = pwtest.nplane;
55-
int nxyz = nx * ny * nz;
5654
if(myrank == 0) cout<<"FFT: "<<nx<<" "<<ny<<" "<<nz<<endl;
57-
double tpiba2 = ModuleBase::TWO_PI * ModuleBase::TWO_PI / lat0 / lat0;
58-
double ggecut = wfcecut / tpiba2;
5955
ModuleBase::Matrix3 GT,G,GGT;
6056
GT = latvec.Inverse();
6157
G = GT.Transpose();

source/module_pw/unittest/test_t2.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,7 @@ int main(int argc,char **argv)
5151
nx = pwtest.nx;
5252
ny = pwtest.ny;
5353
nz = pwtest.nz;
54-
int nplane = pwtest.nplane;
55-
int nxyz = nx * ny * nz;
5654
if(myrank == 0) cout<<"FFT: "<<nx<<" "<<ny<<" "<<nz<<endl;
57-
double tpiba2 = ModuleBase::TWO_PI * ModuleBase::TWO_PI / lat0 / lat0;
58-
double ggecut = wfcecut / tpiba2;
5955
ModuleBase::Matrix3 GT,G,GGT;
6056
GT = latvec.Inverse();
6157
G = GT.Transpose();

0 commit comments

Comments
 (0)