Skip to content

Commit 1c58b1f

Browse files
committed
change executefftw function to a series of functions
1 parent db2350b commit 1c58b1f

File tree

6 files changed

+77
-86
lines changed

6 files changed

+77
-86
lines changed

source/module_pw/fft.cpp

Lines changed: 61 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include "fft.h"
2-
#include "../module_base/tool_quit.h"
32
namespace ModulePW
43
{
54

@@ -136,7 +135,7 @@ void FFT :: initplan()
136135
// 2 D
137136
//---------------------------------------------------------
138137

139-
int nrank[2] = {this->nx,this->bigny};
138+
//int nrank[2] = {this->nx,this->bigny};
140139
int *embed = NULL;
141140
if(this->gamma_only)
142141
{
@@ -301,80 +300,78 @@ void FFT:: cleanFFT()
301300
return;
302301
}
303302

304-
void FFT::executefftw(std::string instr)
303+
void FFT::fftzfor(std::complex<double>* & in, std::complex<double>* & out)
305304
{
306-
if(instr == "1for")
307-
{
308-
// for(int i = 0 ; i < this->ns ; ++i)
309-
// {
310-
// fftw_execute_dft(this->planzfor,(fftw_complex *)&c_gspace[i*nz],(fftw_complex *)&c_gspace[i*nz]);
311-
// }
312-
fftw_execute_dft(this->planzfor,(fftw_complex *)c_gspace,(fftw_complex *)c_gspace);
313-
// fftw_execute(this->planzfor);
314-
}
315-
else if(instr == "1bac")
316-
{
317-
// for(int i = 0 ; i < this->ns ; ++i)
318-
// {
319-
// fftw_execute_dft(this->planzbac,(fftw_complex *)&c_gspace[i*nz],(fftw_complex *)&c_gspace[i*nz]);
320-
// }
321-
fftw_execute_dft(this->planzbac,(fftw_complex *)c_gspace,(fftw_complex *)c_gspace);
322-
// fftw_execute(this->planzbac);
323-
}
324-
else if(instr == "2for")
305+
// for(int i = 0 ; i < this->ns ; ++i)
306+
// {
307+
// fftw_execute_dft(this->planzfor,(fftw_complex *)&in[i*nz],(fftw_complex *)&out[i*nz]);
308+
// }
309+
fftw_execute_dft(this->planzfor,(fftw_complex *)in,(fftw_complex *)out);
310+
return;
311+
}
312+
313+
void FFT::fftzbac(std::complex<double>* & in, std::complex<double>* & out)
314+
{
315+
// for(int i = 0 ; i < this->ns ; ++i)
316+
// {
317+
// fftw_execute_dft(this->planzbac,(fftw_complex *)&c_gspace[i*nz],(fftw_complex *)&c_gspace[i*nz]);
318+
// }
319+
fftw_execute_dft(this->planzbac,(fftw_complex *)in, (fftw_complex *)out);
320+
return;
321+
}
322+
323+
void FFT::fftxyfor(std::complex<double>* & in, std::complex<double>* & out)
324+
{
325+
int npy = this->nplane * this-> ny;
326+
fftw_execute_dft( this->planxfor, (fftw_complex *)in, (fftw_complex *)out);
327+
for (int i=0; i<this->nx;++i)
325328
{
326-
int npy = this->nplane * this-> ny;
327-
fftw_execute_dft( this->planxfor, (fftw_complex *)c_rspace, (fftw_complex *)c_rspace);
328-
for (int i=0; i<this->nx;++i)
329-
{
330-
fftw_execute_dft( this->planyfor, (fftw_complex*)&c_rspace[i*npy], (fftw_complex*)&c_rspace[i*npy] );
331-
}
332-
// fftw_execute(this->plan2for);
329+
fftw_execute_dft( this->planyfor, (fftw_complex*)&in[i*npy], (fftw_complex*)&out[i*npy] );
333330
}
334-
else if(instr == "2bac")
335-
{
336-
// fftw_execute(this->plan2bac);
337-
int npy = this->nplane * this-> ny;
338-
339-
for (int i=0; i<this->nx;++i)
340-
{
341-
fftw_execute_dft( this->planybac, (fftw_complex*)&c_rspace[i*npy], (fftw_complex*)&c_rspace[i*npy] );
342-
}
343-
fftw_execute_dft( this->planxbac, (fftw_complex *)c_rspace, (fftw_complex *)c_rspace);
331+
return;
332+
}
333+
334+
void FFT::fftxybac(std::complex<double>* & in, std::complex<double>* & out)
335+
{
336+
int npy = this->nplane * this-> ny;
344337

345-
}
346-
else if(instr == "2r2c")
338+
for (int i=0; i<this->nx;++i)
347339
{
348-
// fftw_execute(this->plan2r2c);
349-
//int npy = this->nplane * this-> ny;
350-
int bignpy = this->nplane * this-> bigny;
351-
// int padnpy = this->nplane * this-> ny * 2;
352-
for (int i=0; i<this->nx;++i)
353-
{
354-
fftw_execute_dft_r2c( this->planyr2c, &r_rspace[i*bignpy], (fftw_complex*)&c_rspace[i*bignpy] );
355-
// fftw_execute_dft_r2c( this->planyfor, &r_rspace[4*i*padnpy], (fftw_complex*)&c_rspace[i*padnpy] );
356-
}
357-
fftw_execute_dft( this->planxfor, (fftw_complex *)c_rspace, (fftw_complex *)c_rspace);
340+
fftw_execute_dft( this->planybac, (fftw_complex*)&in[i*npy], (fftw_complex*)&out[i*npy] );
358341
}
359-
else if(instr == "2c2r")
342+
fftw_execute_dft( this->planxbac, (fftw_complex *)in, (fftw_complex *)out);
343+
return;
344+
}
345+
346+
void FFT::fftxyr2c(double* &in, std::complex<double>* & out)
347+
{
348+
//int npy = this->nplane * this-> ny;
349+
int bignpy = this->nplane * this-> bigny;
350+
// int padnpy = this->nplane * this-> ny * 2;
351+
for (int i=0; i<this->nx;++i)
360352
{
361-
// fftw_execute(this->plan2c2r);
362-
//int npy = this->nplane * this-> ny;
363-
int bignpy = this->nplane * this-> bigny;
364-
// int padnpy = this->nplane * this-> ny * 2;
365-
fftw_execute_dft( this->planxbac, (fftw_complex *)c_rspace, (fftw_complex *)c_rspace);
366-
for (int i=0; i<this->nx;++i)
367-
{
368-
fftw_execute_dft_c2r( this->planyc2r, (fftw_complex*)&c_rspace[i*bignpy], &r_rspace[i*bignpy] );
369-
// fftw_execute_dft_c2r( this->planybac, (fftw_complex*)&c_rspace[i*padnpy], &r_rspace[4*i*padnpy] );
370-
}
353+
fftw_execute_dft_r2c( this->planyr2c, &in[i*bignpy], (fftw_complex*)&out[i*bignpy] );
354+
// fftw_execute_dft_r2c( this->planyfor, &r_rspace[4*i*padnpy], (fftw_complex*)&c_rspace[i*padnpy] );
371355
}
372-
else
356+
fftw_execute_dft( this->planxfor, (fftw_complex *)out, (fftw_complex *)out);
357+
return;
358+
}
359+
360+
void FFT::fftxyc2r(std::complex<double>* & in, double* & out)
361+
{
362+
//int npy = this->nplane * this-> ny;
363+
int bignpy = this->nplane * this-> bigny;
364+
// int padnpy = this->nplane * this-> ny * 2;
365+
fftw_execute_dft( this->planxbac, (fftw_complex *)in, (fftw_complex *)in);
366+
for (int i=0; i<this->nx;++i)
373367
{
374-
ModuleBase::WARNING_QUIT("FFT", "Wrong input for excutefftw");
368+
fftw_execute_dft_c2r( this->planyc2r, (fftw_complex*)&in[i*bignpy], &out[i*bignpy] );
369+
// fftw_execute_dft_c2r( this->planybac, (fftw_complex*)&c_rspace[i*padnpy], &r_rspace[4*i*padnpy] );
375370
}
371+
return;
376372
}
377373

374+
378375
#ifdef __MIX_PRECISION
379376
void executefftwf(std::string instr)
380377
{
@@ -390,10 +387,6 @@ void executefftwf(std::string instr)
390387
fftwf_execute(this->planf2r2c);
391388
else if(instr == "2c2r")
392389
fftwf_execute(this->planf2c2r);
393-
else
394-
{
395-
ModuleBase::WARNING_QUIT("FFT", "Wrong input for excutefftwf");
396-
}
397390
}
398391
#endif
399392
}

source/module_pw/fft.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@ class FFT
3131
void setupFFT();
3232
void cleanFFT();
3333

34-
void executefftw(std::string instr);
34+
void fftzfor(std::complex<double>* & in, std::complex<double>* & out);
35+
void fftzbac(std::complex<double>* & in, std::complex<double>* & out);
36+
void fftxyfor(std::complex<double>* & in, std::complex<double>* & out);
37+
void fftxybac(std::complex<double>* & in, std::complex<double>* & out);
38+
void fftxyr2c(double * &in, std::complex<double>* & out);
39+
void fftxyc2r(std::complex<double>* & in, double* & out);
40+
3541
#ifdef __MIX_PRECISION
3642
void executefftwf(std::string instr);
3743
#endif

source/module_pw/pw_distributeg_method2.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ void PW_Basis::distribution_method2()
2727
// initial the variables needed by all proc.
2828
int tot_npw = 0; // total number of planewaves.
2929
this->nstot = 0; // total number of sticks.
30-
int st_start = 0; // index of the first stick on current proc.
30+
// int st_start = 0; // index of the first stick on current proc.
3131
int *st_bottom2D = NULL; // st_bottom2D[ixy], minimum z of stick on (x, y).
3232
int *st_length2D = NULL; // st_length2D[ixy], number of planewaves in stick on (x, y).
3333

source/module_pw/pw_transform.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ void PW_Basis:: real2recip(std::complex<double> * in, std::complex<double> * out
1919
{
2020
this->ft.c_rspace[ir] = in[ir];
2121
}
22-
this->ft.executefftw("2for");
22+
this->ft.fftxyfor(ft.c_rspace,ft.c_rspace);
2323

2424
this->gatherp_scatters2(this->ft.c_rspace, this->ft.c_gspace);
2525

26-
this->ft.executefftw("1for");
26+
this->ft.fftzfor(ft.c_gspace,ft.c_gspace);
2727

2828
for(int ig = 0 ; ig < this->npw ; ++ig)
2929
{
@@ -58,11 +58,11 @@ void PW_Basis:: real2recip(double * in, std::complex<double> * out)
5858
// }
5959
// }
6060

61-
this->ft.executefftw("2r2c");
61+
this->ft.fftxyr2c(ft.r_rspace,ft.c_rspace);
6262

6363
this->gatherp_scatters2(this->ft.c_rspace, this->ft.c_gspace);
6464

65-
this->ft.executefftw("1for");
65+
this->ft.fftzfor(ft.c_gspace,ft.c_gspace);
6666

6767
for(int ig = 0 ; ig < this->npw ; ++ig)
6868
{
@@ -85,11 +85,11 @@ void PW_Basis:: recip2real(std::complex<double> * in, std::complex<double> * out
8585
{
8686
this->ft.c_gspace[this->ig2isz[ig]] = in[ig];
8787
}
88-
this->ft.executefftw("1bac");
88+
this->ft.fftzbac(ft.c_gspace, ft.c_gspace);
8989

9090
this->gathers_scatterp2(this->ft.c_gspace,this->ft.c_rspace);
9191

92-
this->ft.executefftw("2bac");
92+
this->ft.fftxybac(ft.c_rspace,ft.c_rspace);
9393

9494
for(int ir = 0 ; ir < this->nrxx ; ++ir)
9595
{
@@ -113,11 +113,11 @@ void PW_Basis:: recip2real(std::complex<double> * in, double * out)
113113
{
114114
this->ft.c_gspace[this->ig2isz[ig]] = in[ig];
115115
}
116-
this->ft.executefftw("1bac");
116+
this->ft.fftzbac(ft.c_gspace, ft.c_gspace);
117117

118118
this->gathers_scatterp2(this->ft.c_gspace, this->ft.c_rspace);
119119

120-
this->ft.executefftw("2c2r");
120+
this->ft.fftxyc2r(ft.c_rspace,ft.r_rspace);
121121

122122
for(int ir = 0 ; ir < this->nrxx ; ++ir)
123123
{

source/module_pw/unittest/test_t1.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,7 @@ int main(int argc,char **argv)
5151
nx = pwtest.nx;
5252
ny = pwtest.ny;
5353
nz = pwtest.nz;
54-
int nplane = pwtest.nplane;
55-
int nxyz = nx * ny * nz;
5654
if(myrank == 0) cout<<"FFT: "<<nx<<" "<<ny<<" "<<nz<<endl;
57-
double tpiba2 = ModuleBase::TWO_PI * ModuleBase::TWO_PI / lat0 / lat0;
58-
double ggecut = wfcecut / tpiba2;
5955
ModuleBase::Matrix3 GT,G,GGT;
6056
GT = latvec.Inverse();
6157
G = GT.Transpose();

source/module_pw/unittest/test_t2.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,7 @@ int main(int argc,char **argv)
5151
nx = pwtest.nx;
5252
ny = pwtest.ny;
5353
nz = pwtest.nz;
54-
int nplane = pwtest.nplane;
55-
int nxyz = nx * ny * nz;
5654
if(myrank == 0) cout<<"FFT: "<<nx<<" "<<ny<<" "<<nz<<endl;
57-
double tpiba2 = ModuleBase::TWO_PI * ModuleBase::TWO_PI / lat0 / lat0;
58-
double ggecut = wfcecut / tpiba2;
5955
ModuleBase::Matrix3 GT,G,GGT;
6056
GT = latvec.Inverse();
6157
G = GT.Transpose();

0 commit comments

Comments
 (0)