Skip to content

Commit 3e3e9bc

Browse files
committed
Merge branch 'planewave' of https://github.com/Qianruipku/abacus-develop into planewave
2 parents 8edadcc + 7d005bd commit 3e3e9bc

File tree

18 files changed

+715
-156
lines changed

18 files changed

+715
-156
lines changed

source/module_pw/fft.cpp

Lines changed: 132 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ FFT::FFT()
1515
r_rspace = NULL;
1616
#ifdef __MIX_PRECISION
1717
destroypf = true;
18-
cf_rspace = cf_gspace = NULL;
18+
auxf2 = auxf1 = NULL;
1919
rf_rspace = NULL;
2020
#endif
2121
}
@@ -26,9 +26,9 @@ FFT::~FFT()
2626
if(aux1!=NULL) fftw_free(aux1);
2727
if(aux2!=NULL) fftw_free(aux2);
2828
#ifdef __MIX_PRECISION
29-
if(cf_gspace!=NULL) fftw_free(cf_gspace);
30-
if(cf_rspace!=NULL) fftw_free(cf_rspace);
31-
if(rf_rspace!=NULL) fftw_free(rf_rspace);
29+
this->cleanfFFT();
30+
if(auxf1!=NULL) fftw_free(auxf1);
31+
if(auxf2!=NULL) fftw_free(auxf2);
3232
#endif
3333
}
3434

@@ -55,12 +55,9 @@ void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int liy_in, int riy_in, i
5555
aux2 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
5656
r_rspace = (double *) aux1;
5757
#ifdef __MIX_PRECISION
58-
cf_gspace = (std::complex<float> *)fftw_malloc(sizeof(fftwf_complex) * this->nz * this->ns);
59-
cf_rspace = (std::complex<float> *)fftw_malloc(sizeof(fftwf_complex) * this->bignxy * nplane);
60-
if(this->gamma_only)
61-
{
62-
rf_rspace = (float *)fftw_malloc(sizeof(float) * this->bignxy * nplane);
63-
}
58+
auxf1 = (std::complex<float> *) fftw_malloc(sizeof(fftwf_complex) * maxgrids);
59+
auxf2 = (std::complex<float> *) fftw_malloc(sizeof(fftwf_complex) * maxgrids);
60+
rf_rspace = (float *) auxf1;
6461
#endif
6562
}
6663
else
@@ -190,44 +187,41 @@ void FFT :: initplanf()
190187
// 1 D
191188
//---------------------------------------------------------
192189

193-
// fftwf_plan_many_dft(int rank, const int *n, int howmany,
190+
// fftw_plan_many_dft(int rank, const int *n, int howmany,
194191
// fftw_complex *in, const int *inembed, int istride, int idist,
195192
// fftw_complex *out, const int *onembed, int ostride, int odist, int sign, unsigned flags);
196193

197-
this->planf1for = fftwf_plan_many_dft( 1, &this->nz, this->ns,
198-
(fftwf_complex*)aux1, &this->nz, 1, this->nz,
199-
(fftwf_complex*)aux1, &this->nz, 1, this->nz, FFTW_FORWARD, FFTW_MEASURE);
200-
201-
this->planf1bac = fftwf_plan_many_dft( 1, &this->nz, this->ns,
202-
(fftwf_complex*)aux1, &this->nz, 1, this->nz,
203-
(fftwf_complex*)aux1, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
194+
//It is better to use out-of-place fft for stride = 1.
195+
this->planfzfor = fftwf_plan_many_dft( 1, &this->nz, this->ns,
196+
(fftwf_complex*) auxf1, &this->nz, 1, this->nz,
197+
(fftwf_complex*) auxf2, &this->nz, 1, this->nz, FFTW_FORWARD, FFTW_MEASURE);
204198

205-
199+
this->planfzbac = fftwf_plan_many_dft( 1, &this->nz, this->ns,
200+
(fftwf_complex*) auxf1, &this->nz, 1, this->nz,
201+
(fftwf_complex*) auxf2, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
206202
//---------------------------------------------------------
207203
// 2 D
208204
//---------------------------------------------------------
209205

210-
int nrank[2] = {this->nx,this->bigny};
211-
206+
int *embed = NULL;
207+
int bignpy = this->nplane * this->bigny;
208+
this->planfxfor = fftwf_plan_many_dft( 1, &this->nx, this->nplane, (fftwf_complex *)auxf2, embed, bignpy, 1,
209+
(fftwf_complex *)auxf2, embed, bignpy, 1, FFTW_FORWARD, FFTW_MEASURE );
210+
this->planfxbac = fftwf_plan_many_dft( 1, &this->nx, this->nplane, (fftwf_complex *)auxf2, embed, bignpy, 1,
211+
(fftwf_complex *)auxf2, embed, bignpy, 1, FFTW_BACKWARD, FFTW_MEASURE );
212212
if(this->gamma_only)
213213
{
214-
this->planf2r2c = fftwf_plan_many_dft_r2c( 2, nrank, this->nplane,
215-
r_rspace, nrank, this->nplane, 1,
216-
(fftwf_complex*)aux2, nrank, this->nplane, 1, FFTW_MEASURE);
217-
218-
this->planf2c2r = fftwf_plan_many_dft_c2r( 2, nrank, this->nplane,
219-
(fftwf_complex*)aux2, nrank, this->nplane, 1,
220-
r_rspace, nrank, this->nplane, 1, FFTW_MEASURE);
214+
this->planfyr2c = fftwf_plan_many_dft_r2c( 1, &this->bigny, this->nplane, rf_rspace , embed, this->nplane, 1,
215+
(fftwf_complex*)auxf1, embed, this->nplane, 1, FFTW_MEASURE );
216+
this->planfyc2r = fftwf_plan_many_dft_c2r( 1, &this->bigny, this->nplane, (fftwf_complex*)auxf1 , embed, this->nplane, 1,
217+
rf_rspace, embed, this->nplane, 1, FFTW_MEASURE );
221218
}
222219
else
223220
{
224-
this->planf2for = fftwf_plan_many_dft( 2, nrank, this->nplane,
225-
(fftwf_complex*)aux2, nrank, this->nplane, 1,
226-
(fftwf_complex*)aux2, nrank, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE);
227-
228-
this->planf2bac = fftwf_plan_many_dft( 2, nrank, this->nplane,
229-
(fftwf_complex*)aux2, nrank, this->nplane, 1,
230-
(fftwf_complex*)aux2, nrank, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE);
221+
this->planfyfor = fftwf_plan_many_dft( 1, &this->bigny, this->nplane, (fftwf_complex*)auxf2 , embed, this->nplane, 1,
222+
(fftwf_complex*)auxf2, embed, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE );
223+
this->planfybac = fftwf_plan_many_dft( 1, &this->bigny, this->nplane, (fftwf_complex*)auxf2 , embed, this->nplane, 1,
224+
(fftwf_complex*)auxf2, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE );
231225
}
232226
destroypf = false;
233227
}
@@ -263,26 +257,31 @@ void FFT:: cleanFFT()
263257
fftw_destroy_plan(planybac);
264258
}
265259
destroyp = true;
266-
260+
return;
261+
}
267262
#ifdef __MIX_PRECISION
263+
void FFT:: cleanfFFT()
264+
{
268265
if(destroypf==true) return;
269-
fftw_destroy_plan(planf1for);
270-
fftw_destroy_plan(planf1bac);
266+
fftwf_destroy_plan(planfzfor);
267+
fftwf_destroy_plan(planfzbac);
268+
fftwf_destroy_plan(planfxfor);
269+
fftwf_destroy_plan(planfxbac);
271270
if(this->gamma_only)
272271
{
273-
fftw_destroy_plan(planf2r2c);
274-
fftw_destroy_plan(planf2c2r);
272+
fftwf_destroy_plan(planfyr2c);
273+
fftwf_destroy_plan(planfyc2r);
275274
}
276275
else
277276
{
278-
fftw_destroy_plan(planf2for);
279-
fftw_destroy_plan(planf2bac);
277+
fftwf_destroy_plan(planfyfor);
278+
fftwf_destroy_plan(planfybac);
280279
}
281280
destroypf = true;
282-
#endif
283-
284281
return;
285282
}
283+
#endif
284+
286285

287286
void FFT::fftzfor(std::complex<double>* & in, std::complex<double>* & out)
288287
{
@@ -386,20 +385,96 @@ void FFT::fftxyc2r(std::complex<double>* & in, double* & out)
386385

387386

388387
#ifdef __MIX_PRECISION
389-
void executefftwf(std::string instr)
388+
void FFT::fftfzfor(std::complex<float>* & in, std::complex<float>* & out)
390389
{
391-
if(instr == "1for")
392-
fftwf_execute(this->planf1for);
393-
else if(instr == "2for")
394-
fftwf_execute(this->planf2for);
395-
else if(instr == "1bac")
396-
fftwf_execute(this->planf1bac);
397-
else if(instr == "2bac")
398-
fftwf_execute(this->planf2bac);
399-
else if(instr == "2r2c")
400-
fftwf_execute(this->planf2r2c);
401-
else if(instr == "2c2r")
402-
fftwf_execute(this->planf2c2r);
390+
fftwf_execute_dft(this->planfzfor,(fftwf_complex *)in,(fftwf_complex *)out);
391+
return;
392+
}
393+
394+
void FFT::fftfzbac(std::complex<float>* & in, std::complex<float>* & out)
395+
{
396+
fftwf_execute_dft(this->planfzbac,(fftwf_complex *)in, (fftwf_complex *)out);
397+
return;
398+
}
399+
400+
void FFT::fftfxyfor(std::complex<float>* & in, std::complex<float>* & out)
401+
{
402+
int bignpy = this->nplane * this-> bigny;
403+
for (int i=0; i<this->nx;++i)
404+
{
405+
fftwf_execute_dft( this->planfyfor, (fftwf_complex *)&in[i*bignpy], (fftwf_complex *)&out[i*bignpy]);
406+
}
407+
408+
for (int i=0; i<=this->liy;++i)
409+
{
410+
fftwf_execute_dft( this->planfxfor, (fftwf_complex *)&in[i*nplane], (fftwf_complex *)&out[i*nplane]);
411+
}
412+
for (int i=this->riy; i<this->ny;++i)
413+
{
414+
fftwf_execute_dft( this->planfxfor, (fftwf_complex *)&in[i*nplane], (fftwf_complex *)&out[i*nplane]);
415+
}
416+
return;
417+
}
418+
419+
void FFT::fftfxybac(std::complex<float>* & in, std::complex<float>* & out)
420+
{
421+
int bignpy = this->nplane * this-> bigny;
422+
//x-direction
423+
for (int i=0; i<=this->liy;++i)
424+
{
425+
fftwf_execute_dft( this->planfxbac, (fftwf_complex *)&in[i*nplane], (fftwf_complex *)&out[i*nplane]);
426+
}
427+
for (int i=this->riy; i<this->ny;++i)
428+
{
429+
fftwf_execute_dft( this->planfxbac, (fftwf_complex *)&in[i*nplane], (fftwf_complex *)&out[i*nplane]);
430+
}
431+
432+
////y-direction
433+
for (int i=0; i<this->nx;++i)
434+
{
435+
fftwf_execute_dft( this->planfybac, (fftwf_complex*)&in[i*bignpy], (fftwf_complex*)&out[i*bignpy] );
436+
}
437+
return;
438+
}
439+
440+
void FFT::fftfxyr2c(float* &in, std::complex<float>* & out)
441+
{
442+
int bignpy = this->nplane * this-> bigny;
443+
444+
for (int i=0; i<this->nx;++i)
445+
{
446+
fftwf_execute_dft_r2c( this->planfyr2c, &in[i*bignpy*2], (fftwf_complex*)&out[i*bignpy] );
447+
}
448+
449+
for (int i=0; i<=this->liy;++i)
450+
{
451+
fftwf_execute_dft( this->planfxfor, (fftwf_complex *)&out[i*nplane], (fftwf_complex *)&out[i*nplane]);
452+
}
453+
for (int i=this->riy; i<this->ny;++i)
454+
{
455+
fftwf_execute_dft( this->planfxfor, (fftwf_complex *)&out[i*nplane], (fftwf_complex *)&out[i*nplane]);
456+
}
457+
return;
458+
}
459+
460+
461+
void FFT::fftfxyc2r(std::complex<float>* & in, float* & out)
462+
{
463+
int bignpy = this->nplane * this-> bigny;
464+
for (int i=0; i<=this->liy;++i)
465+
{
466+
fftwf_execute_dft( this->planfxbac, (fftwf_complex *)&in[i*nplane], (fftwf_complex *)&in[i*nplane]);
467+
}
468+
for (int i=this->riy; i<this->ny;++i)
469+
{
470+
fftwf_execute_dft( this->planfxbac, (fftwf_complex *)&in[i*nplane], (fftwf_complex *)&in[i*nplane]);
471+
}
472+
473+
for (int i=0; i<this->nx;++i)
474+
{
475+
fftwf_execute_dft_c2r( this->planfyc2r, (fftwf_complex*)&in[i*bignpy], &out[i*bignpy*2] );
476+
}
477+
return;
403478
}
404479
#endif
405480
}

source/module_pw/fft.h

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
//#include "fftw3-mpi_mkl.h"
1111
#endif
1212

13-
#ifdef __MIX_PRECISION
14-
#include "fftw3f.h"
15-
#if defined(__FFTW3_MPI) && defined(__MPI)
16-
#include "fftw3f-mpi.h"
17-
//#include "fftw3-mpi_mkl.h"
18-
#endif
19-
#endif
13+
// #ifdef __MIX_PRECISION
14+
// #include "fftw3f.h"
15+
// #if defined(__FFTW3_MPI) && defined(__MPI)
16+
// #include "fftw3f-mpi.h"
17+
// //#include "fftw3-mpi_mkl.h"
18+
// #endif
19+
// #endif
2020

2121
namespace ModulePW
2222
{
@@ -40,10 +40,16 @@ class FFT
4040
void fftxyc2r(std::complex<double>* & in, double* & out);
4141

4242
#ifdef __MIX_PRECISION
43-
void executefftwf(std::string instr);
43+
void cleanfFFT();
44+
void fftfzfor(std::complex<float>* & in, std::complex<float>* & out);
45+
void fftfzbac(std::complex<float>* & in, std::complex<float>* & out);
46+
void fftfxyfor(std::complex<float>* & in, std::complex<float>* & out);
47+
void fftfxybac(std::complex<float>* & in, std::complex<float>* & out);
48+
void fftfxyr2c(float * &in, std::complex<float>* & out);
49+
void fftfxyc2r(std::complex<float>* & in, float* & out);
4450
#endif
4551

46-
private:
52+
public:
4753
void initplan();
4854
void initplan_mpi();
4955
#ifdef __MIX_PRECISION
@@ -64,8 +70,7 @@ class FFT
6470
std::complex<double> *aux1, *aux2; //fft space, [maxgrids]
6571
double *r_rspace; //real number space for r, [nplane * nx *ny]
6672
#ifdef __MIX_PRECISION
67-
std::complex<float> * cf_gspace; //complex number space for g, [ns * nz]
68-
std::complex<float> * cf_rspace; //complex number space for r, [nplane * nx *ny]
73+
std::complex<float> *auxf1, *auxf2; //fft space, [maxgrids]
6974
float *rf_rspace; //real number space for r, [nplane * nx *ny]
7075
#endif
7176

@@ -90,12 +95,14 @@ class FFT
9095
fftw_plan planyc2r;
9196
#ifdef __MIX_PRECISION
9297
bool destroypf;
93-
fftwf_plan planf2r2c;
94-
fftwf_plan planf2c2r;
95-
fftwf_plan planf1for;
96-
fftwf_plan planf1bac;
97-
fftwf_plan planf2for;
98-
fftwf_plan planf2bac;
98+
fftwf_plan planfzfor;
99+
fftwf_plan planfzbac;
100+
fftwf_plan planfxfor;
101+
fftwf_plan planfxbac;
102+
fftwf_plan planfyfor;
103+
fftwf_plan planfybac;
104+
fftwf_plan planfyr2c;
105+
fftwf_plan planfyc2r;
99106
#endif
100107

101108
};

source/module_pw/pw_basis.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,20 @@ class PW_Basis
193193
void recip2real(std::complex<double> * in, double *out); //in:(nz, ns) ; out(nplane,nx*ny)
194194
void recip2real(std::complex<double> * in, std::complex<double> * out); //in:(nz, ns) ; out(nplane,nx*ny)
195195

196-
void gatherp_scatters(std::complex<double> *in, std::complex<double> *out); //gather planes and scatter sticks of all processors
197-
void gathers_scatterp(std::complex<double> *in, std::complex<double> *out); //gather sticks of and scatter planes of all processors
196+
#ifdef __MIX_PRECISION
197+
void real2recip(float * in, std::complex<float> * out); //in:(nplane,nx*ny) ; out(nz, ns)
198+
void real2recip(std::complex<float> * in, std::complex<float> * out); //in:(nplane,nx*ny) ; out(nz, ns)
199+
void recip2real(std::complex<float> * in, float *out); //in:(nz, ns) ; out(nplane,nx*ny)
200+
void recip2real(std::complex<float> * in, std::complex<float> * out); //in:(nz, ns) ; out(nplane,nx*ny)
201+
#endif
202+
template<typename T>
203+
void gatherp_scatters(std::complex<T> *in, std::complex<T> *out); //gather planes and scatter sticks of all processors
204+
template<typename T>
205+
void gathers_scatterp(std::complex<T> *in, std::complex<T> *out); //gather sticks of and scatter planes of all processors
198206
// void gathers_scatterp2(std::complex<double> *in, std::complex<double> *out); //gather sticks of and scatter planes of all processors
199207
// void gatherp_scatters2(std::complex<double> *in, std::complex<double> *out); //gather sticks of and scatter planes of all processors
200-
void gatherp_scatters_gamma(std::complex<double> *in, std::complex<double> *out); //gather planes and scatter sticks of all processors, used when gamma_only
201-
void gathers_scatterp_gamma(std::complex<double> *in, std::complex<double> *out); //gather sticks of and scatter planes of all processors, used when gamma only
208+
// void gatherp_scatters_gamma(std::complex<double> *in, std::complex<double> *out); //gather planes and scatter sticks of all processors, used when gamma_only
209+
// void gathers_scatterp_gamma(std::complex<double> *in, std::complex<double> *out); //gather sticks of and scatter planes of all processors, used when gamma only
202210

203211

204212

0 commit comments

Comments
 (0)