Skip to content

Commit e868f9c

Browse files
committed
Merge branch 'planewave' of https://github.com/Qianruipku/abacus-develop into planewave
2 parents a9f10b8 + adce797 commit e868f9c

17 files changed

+663
-293
lines changed

source/module_base/global_function.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@ template<class T, class TI>
117117
inline void ZEROS(std::complex<T> *u,const TI n) // Peize Lin change int to TI at 2020.03.03
118118
{
119119
assert(n>=0);
120-
assert(u!=0);
121120
for (TI i=0;i<n;i++)
122121
{
123122
u[i] = std::complex<T>(0.0,0.0);

source/module_pw/fft.cpp

Lines changed: 189 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include "fft.h"
2-
#include "../module_base/tool_quit.h"
32
namespace ModulePW
43
{
54

@@ -12,8 +11,7 @@ FFT::FFT()
1211
mpifft = false;
1312
destroyp = true;
1413
gamma_only = false;
15-
c_rspace = c_gspace = NULL;
16-
c_rspace2 = c_gspace2 = NULL;
14+
aux2 = aux1 = NULL;
1715
r_rspace = NULL;
1816
#ifdef __MIX_PRECISION
1917
destroypf = true;
@@ -25,19 +23,16 @@ FFT::FFT()
2523
// Destructor: destroys the FFTW plans through cleanFFT(), then releases the
// FFTW-allocated work buffers, each only if its pointer is non-NULL.
// In this diff the '-' lines are the pre-commit frees (c_gspace, c_rspace,
// c_gspace2, c_rspace2, r_rspace) and the '+' lines are their replacement
// (aux1, aux2). After the commit, initfft() sets r_rspace to a cast of aux1 in
// the non-MPI branch, so it is no longer freed separately.
// NOTE(review): the MPI branch of initfft() is not visible here - confirm that
// r_rspace never owns a separate allocation there.
FFT::~FFT()
2624
{
2725
this->cleanFFT();
28-
if(c_gspace!=NULL) fftw_free(c_gspace);
29-
if(c_rspace!=NULL) fftw_free(c_rspace);
30-
if(c_gspace2!=NULL) fftw_free(c_gspace2);
31-
if(c_rspace2!=NULL) fftw_free(c_rspace2);
32-
if(r_rspace!=NULL) fftw_free(r_rspace);
26+
if(aux1!=NULL) fftw_free(aux1);
27+
if(aux2!=NULL) fftw_free(aux2);
3328
#ifdef __MIX_PRECISION
3429
if(cf_gspace!=NULL) fftw_free(cf_gspace);
3530
if(cf_rspace!=NULL) fftw_free(cf_rspace);
3631
if(rf_rspace!=NULL) fftw_free(rf_rspace);
3732
#endif
3833
}
3934

40-
void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in, bool gamma_only_in, bool mpifft_in)
35+
void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int liy_in, int riy_in, int ns_in, int nplane_in, int nproc_in, bool gamma_only_in, bool mpifft_in)
4136
{
4237
this->gamma_only = gamma_only_in;
4338
this->nx = nx_in;
@@ -46,24 +41,19 @@ void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int ns_in, int nplane_in,
4641
else this->ny = this->bigny;
4742
this->nz = nz_in;
4843
this->ns = ns_in;
44+
this->liy = liy_in;
45+
this->riy = riy_in;
4946
this->nplane = nplane_in;
47+
this->nproc = nproc_in;
5048
this->mpifft = mpifft_in;
5149
this->nxy = this->nx * this-> ny;
5250
this->bignxy = this->nx * this->bigny;
51+
this->maxgrids = (this->nz * this->ns > this->bignxy * nplane) ? this->nz * this->ns : this->bignxy * nplane;
5352
if(!this->mpifft)
5453
{
55-
//out-of-place fft is faster than in-place fft
56-
c_gspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
57-
c_gspace2 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->nz * this->ns);
58-
c_rspace = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
59-
if(this->gamma_only)
60-
{
61-
r_rspace = (double *) fftw_malloc(sizeof(double) * this->bignxy * nplane);
62-
}
63-
else
64-
{
65-
c_rspace2 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * this->bignxy * nplane);
66-
}
54+
aux1 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
55+
aux2 = (std::complex<double> *) fftw_malloc(sizeof(fftw_complex) * maxgrids);
56+
r_rspace = (double *) aux1;
6757
#ifdef __MIX_PRECISION
6858
cf_gspace = (std::complex<float> *)fftw_malloc(sizeof(fftwf_complex) * this->nz * this->ns);
6959
cf_rspace = (std::complex<float> *)fftw_malloc(sizeof(fftwf_complex) * this->bignxy * nplane);
@@ -111,40 +101,85 @@ void FFT :: initplan()
111101
// fftw_complex *in, const int *inembed, int istride, int idist,
112102
// fftw_complex *out, const int *onembed, int ostride, int odist, int sign, unsigned flags);
113103

114-
this->plan1for = fftw_plan_many_dft( 1, &this->nz, this->ns,
115-
(fftw_complex*) c_gspace, &this->nz, 1, this->nz,
116-
(fftw_complex*) c_gspace2, &this->nz, 1, this->nz, FFTW_FORWARD, FFTW_MEASURE);
104+
//It is better to use out-of-place fft for stride = 1.
105+
this->planzfor = fftw_plan_many_dft( 1, &this->nz, this->ns,
106+
(fftw_complex*) aux1, &this->nz, 1, this->nz,
107+
(fftw_complex*) aux2, &this->nz, 1, this->nz, FFTW_FORWARD, FFTW_MEASURE);
117108

118-
this->plan1bac = fftw_plan_many_dft( 1, &this->nz, this->ns,
119-
(fftw_complex*) c_gspace, &this->nz, 1, this->nz,
120-
(fftw_complex*) c_gspace2, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
109+
this->planzbac = fftw_plan_many_dft( 1, &this->nz, this->ns,
110+
(fftw_complex*) aux1, &this->nz, 1, this->nz,
111+
(fftw_complex*) aux2, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
112+
113+
// this->planzfor = fftw_plan_dft_1d(this->nz,(fftw_complex*) aux1,(fftw_complex*) aux1, FFTW_FORWARD, FFTW_MEASURE);
114+
// this->planzbac = fftw_plan_dft_1d(this->nz,(fftw_complex*) aux1,(fftw_complex*) aux1,FFTW_BACKWARD, FFTW_MEASURE);
121115

122116
//---------------------------------------------------------
123117
// 2 D
124118
//---------------------------------------------------------
125119

126-
int nrank[2] = {this->nx,this->bigny};
120+
//int nrank[2] = {this->nx,this->bigny};
127121
int *embed = NULL;
122+
// It seems 1D+1D is much faster than 2D FFT!
128123
if(this->gamma_only)
129124
{
130-
this->plan2r2c = fftw_plan_many_dft_r2c( 2, nrank, this->nplane,
131-
r_rspace, embed, this->nplane, 1,
132-
(fftw_complex*) c_rspace, embed, this->nplane, 1, FFTW_MEASURE);
125+
// int padnpy = this->nplane * this->ny * 2;
126+
// int rankc[2] = {this->nx, this->padnpy};
127+
// int rankd[2] = {this->nx, this->padnpy*2};
128+
// // It seems 1D+1D is much faster than 2D FFT!
129+
// this->plan2r2c = fftw_plan_many_dft_r2c( 2, nrank, this->nplane,
130+
// r_rspace, rankd, this->nplane, 1,
131+
// (fftw_complex*) aux2, rankc, this->nplane, 1, FFTW_MEASURE);
132+
133+
// this->plan2c2r = fftw_plan_many_dft_c2r( 2, nrank, this->nplane,
134+
// (fftw_complex*) aux2, rankc, this->nplane, 1,
135+
// r_rspace, rankd, this->nplane, 1, FFTW_MEASURE);
136+
137+
// int npy = this->nplane * this->ny;
138+
int bignpy = this->nplane * this->bigny;
139+
this->planxfor = fftw_plan_many_dft( 1, &this->nx, this->nplane, (fftw_complex *)aux2, embed, bignpy, 1,
140+
(fftw_complex *)aux2, embed, bignpy, 1, FFTW_FORWARD, FFTW_MEASURE );
141+
this->planxbac = fftw_plan_many_dft( 1, &this->nx, this->nplane, (fftw_complex *)aux2, embed, bignpy, 1,
142+
(fftw_complex *)aux2, embed, bignpy, 1, FFTW_BACKWARD, FFTW_MEASURE );
143+
this->planyr2c = fftw_plan_many_dft_r2c( 1, &this->bigny, this->nplane, r_rspace , embed, this->nplane, 1,
144+
(fftw_complex*)aux1, embed, this->nplane, 1, FFTW_MEASURE );
145+
this->planyc2r = fftw_plan_many_dft_c2r( 1, &this->bigny, this->nplane, (fftw_complex*)aux1 , embed, this->nplane, 1,
146+
r_rspace, embed, this->nplane, 1, FFTW_MEASURE );
147+
148+
// int padnpy = npy * 2;
149+
// this->planxfor = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)aux2, embed, padnpy, 1,
150+
// (fftw_complex *)aux2, embed, padnpy, 1, FFTW_FORWARD, FFTW_MEASURE );
151+
// this->planxbac = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)aux2, embed, padnpy, 1,
152+
// (fftw_complex *)aux2, embed, padnpy, 1, FFTW_BACKWARD, FFTW_MEASURE );
153+
// this->planyr2c = fftw_plan_many_dft_r2c( 1, &this->bigny, this->nplane, r_rspace , embed, this->nplane*2, 1,
154+
// (fftw_complex*)aux2, embed, this->nplane, 1, FFTW_MEASURE );
155+
// this->planyc2r = fftw_plan_many_dft_c2r( 1, &this->bigny, this->nplane, (fftw_complex*)aux2 , embed, this->nplane, 1,
156+
// r_rspace, embed, this->nplane*2, 1, FFTW_MEASURE );
133157

134-
this->plan2c2r = fftw_plan_many_dft_c2r( 2, nrank, this->nplane,
135-
(fftw_complex*) c_rspace, embed, this->nplane, 1,
136-
r_rspace, embed, this->nplane, 1, FFTW_MEASURE);
137158
}
138159
else
139160
{
140-
this->plan2for = fftw_plan_many_dft( 2, nrank, this->nplane,
141-
(fftw_complex*) c_rspace, embed, this->nplane, 1,
142-
(fftw_complex*) c_rspace2, embed, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE);
161+
// this->plan2for = fftw_plan_many_dft( 2, nrank, this->nplane,
162+
// (fftw_complex*) aux2, embed, this->nplane, 1,
163+
// (fftw_complex*) aux2, embed, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE);
143164

144-
this->plan2bac = fftw_plan_many_dft( 2, nrank, this->nplane,
145-
(fftw_complex*) c_rspace, embed, this->nplane, 1,
146-
(fftw_complex*) c_rspace2, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE);
165+
// this->plan2bac = fftw_plan_many_dft( 2, nrank, this->nplane,
166+
// (fftw_complex*) aux2, embed, this->nplane, 1,
167+
// (fftw_complex*) aux2, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE);
168+
169+
//It is better to use in-place for stride > 1
170+
int bignpy = this->nplane * this->bigny;
171+
this->planxfor = fftw_plan_many_dft( 1, &this->nx, this->nplane, (fftw_complex *)aux2, embed, bignpy, 1,
172+
(fftw_complex *)aux2, embed, bignpy, 1, FFTW_FORWARD, FFTW_MEASURE );
173+
this->planxbac = fftw_plan_many_dft( 1, &this->nx, this->nplane, (fftw_complex *)aux2, embed, bignpy, 1,
174+
(fftw_complex *)aux2, embed, bignpy, 1, FFTW_BACKWARD, FFTW_MEASURE );
175+
this->planyfor = fftw_plan_many_dft( 1, &this->bigny, this->nplane, (fftw_complex*)aux2 , embed, this->nplane, 1,
176+
(fftw_complex*)aux2, embed, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE );
177+
this->planybac = fftw_plan_many_dft( 1, &this->bigny, this->nplane, (fftw_complex*)aux2 , embed, this->nplane, 1,
178+
(fftw_complex*)aux2, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE );
147179
}
180+
181+
182+
148183
destroyp = false;
149184
}
150185

@@ -160,12 +195,12 @@ void FFT :: initplanf()
160195
// fftw_complex *out, const int *onembed, int ostride, int odist, int sign, unsigned flags);
161196

162197
this->planf1for = fftwf_plan_many_dft( 1, &this->nz, this->ns,
163-
(fftwf_complex*)c_gspace, &this->nz, 1, this->nz,
164-
(fftwf_complex*)c_gspace, &this->nz, 1, this->nz, FFTW_FORWARD, FFTW_MEASURE);
198+
(fftwf_complex*)aux1, &this->nz, 1, this->nz,
199+
(fftwf_complex*)aux1, &this->nz, 1, this->nz, FFTW_FORWARD, FFTW_MEASURE);
165200

166201
this->planf1bac = fftwf_plan_many_dft( 1, &this->nz, this->ns,
167-
(fftwf_complex*)c_gspace, &this->nz, 1, this->nz,
168-
(fftwf_complex*)c_gspace, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
202+
(fftwf_complex*)aux1, &this->nz, 1, this->nz,
203+
(fftwf_complex*)aux1, &this->nz, 1, this->nz, FFTW_BACKWARD, FFTW_MEASURE);
169204

170205

171206
//---------------------------------------------------------
@@ -178,21 +213,21 @@ void FFT :: initplanf()
178213
{
179214
this->planf2r2c = fftwf_plan_many_dft_r2c( 2, nrank, this->nplane,
180215
r_rspace, nrank, this->nplane, 1,
181-
(fftwf_complex*)c_rspace, nrank, this->nplane, 1, FFTW_MEASURE);
216+
(fftwf_complex*)aux2, nrank, this->nplane, 1, FFTW_MEASURE);
182217

183218
this->planf2c2r = fftwf_plan_many_dft_c2r( 2, nrank, this->nplane,
184-
(fftwf_complex*)c_rspace, nrank, this->nplane, 1,
219+
(fftwf_complex*)aux2, nrank, this->nplane, 1,
185220
r_rspace, nrank, this->nplane, 1, FFTW_MEASURE);
186221
}
187222
else
188223
{
189224
this->planf2for = fftwf_plan_many_dft( 2, nrank, this->nplane,
190-
(fftwf_complex*)c_rspace, nrank, this->nplane, 1,
191-
(fftwf_complex*)c_rspace, nrank, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE);
225+
(fftwf_complex*)aux2, nrank, this->nplane, 1,
226+
(fftwf_complex*)aux2, nrank, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE);
192227

193228
this->planf2bac = fftwf_plan_many_dft( 2, nrank, this->nplane,
194-
(fftwf_complex*)c_rspace, nrank, this->nplane, 1,
195-
(fftwf_complex*)c_rspace, nrank, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE);
229+
(fftwf_complex*)aux2, nrank, this->nplane, 1,
230+
(fftwf_complex*)aux2, nrank, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE);
196231
}
197232
destroypf = false;
198233
}
@@ -213,17 +248,19 @@ void FFT :: initplanf_mpi()
213248
void FFT:: cleanFFT()
214249
{
215250
if(destroyp==true) return;
216-
fftw_destroy_plan(plan1for);
217-
fftw_destroy_plan(plan1bac);
251+
fftw_destroy_plan(planzfor);
252+
fftw_destroy_plan(planzbac);
253+
fftw_destroy_plan(planxfor);
254+
fftw_destroy_plan(planxbac);
218255
if(this->gamma_only)
219256
{
220-
fftw_destroy_plan(plan2r2c);
221-
fftw_destroy_plan(plan2c2r);
257+
fftw_destroy_plan(planyr2c);
258+
fftw_destroy_plan(planyc2r);
222259
}
223260
else
224261
{
225-
fftw_destroy_plan(plan2for);
226-
fftw_destroy_plan(plan2bac);
262+
fftw_destroy_plan(planyfor);
263+
fftw_destroy_plan(planybac);
227264
}
228265
destroyp = true;
229266

@@ -247,26 +284,107 @@ void FFT:: cleanFFT()
247284
return;
248285
}
249286

250-
void FFT::executefftw(std::string instr)
287+
// Executes the forward z-direction plan (planzfor) on caller-supplied arrays
// using FFTW's new-array execute call.
// initplan() builds planzfor as ns batched 1-D transforms of length nz
// (stride 1, distance nz), planned out-of-place from aux1 to aux2.
//   in  : source array, passed to FFTW as fftw_complex*
//   out : destination array, passed to FFTW as fftw_complex*
// NOTE(review): FFTW's new-array execute interface expects arrays that match
// the planning arrays in alignment and in being distinct (out-of-place) -
// confirm that callers never pass in == out.
// The '-' lines below belong to the removed executefftw() string dispatcher.
void FFT::fftzfor(std::complex<double>* & in, std::complex<double>* & out)
251288
{
252-
if(instr == "1for")
253-
fftw_execute(this->plan1for);
254-
else if(instr == "2for")
255-
fftw_execute(this->plan2for);
256-
else if(instr == "1bac")
257-
fftw_execute(this->plan1bac);
258-
else if(instr == "2bac")
259-
fftw_execute(this->plan2bac);
260-
else if(instr == "2r2c")
261-
fftw_execute(this->plan2r2c);
262-
else if(instr == "2c2r")
263-
fftw_execute(this->plan2c2r);
264-
else
289+
// Disabled alternative: one execute call per stick of length nz.
// for(int i = 0 ; i < this->ns ; ++i)
290+
// {
291+
// fftw_execute_dft(this->planzfor,(fftw_complex *)&in[i*nz],(fftw_complex *)&out[i*nz]);
292+
// }
293+
// A single call covers all ns sticks because the plan itself is batched.
fftw_execute_dft(this->planzfor,(fftw_complex *)in,(fftw_complex *)out);
294+
return;
295+
}
296+
297+
// Executes the backward z-direction plan (planzbac) on caller-supplied arrays
// using FFTW's new-array execute call.
// initplan() builds planzbac with the same geometry as planzfor (ns batched
// 1-D transforms of length nz, stride 1, distance nz, out-of-place from aux1
// to aux2) but with FFTW_BACKWARD.
//   in  : source array, passed to FFTW as fftw_complex*
//   out : destination array, passed to FFTW as fftw_complex*
// No scaling is applied in this function.
// NOTE(review): FFTW transforms are unnormalized, so any 1/N factor must be
// applied by the caller - confirm. The in/out arrays should also match the
// planning arrays in alignment and be distinct from each other.
void FFT::fftzbac(std::complex<double>* & in, std::complex<double>* & out)
298+
{
299+
// Disabled alternative: one in-place execute call per stick on aux1.
// for(int i = 0 ; i < this->ns ; ++i)
300+
// {
301+
// fftw_execute_dft(this->planzbac,(fftw_complex *)&aux1[i*nz],(fftw_complex *)&aux1[i*nz]);
302+
// }
303+
fftw_execute_dft(this->planzbac,(fftw_complex *)in, (fftw_complex *)out);
304+
return;
305+
}
306+
307+
// Forward complex transform in the xy plane, done as a y pass followed by an
// x pass (1-D + 1-D instead of one 2-D plan).
// planyfor is created only in the non-gamma_only branch of initplan().
//   in  : source array
//   out : destination array
// Both passes read from in and write to out, so the x pass only sees the
// y-transformed data when in and out are the same array. Both plans were also
// created in-place on aux2.
// NOTE(review): confirm that every caller passes in == out; FFTW's new-array
// execute additionally expects the same alignment as the planning array.
void FFT::fftxyfor(std::complex<double>* & in, std::complex<double>* & out)
308+
{
309+
// Number of elements in one x-slab (bigny rows of nplane elements).
int bignpy = this->nplane * this-> bigny;
310+
// y pass: one call per x index; planyfor performs nplane transforms of
// length bigny with element stride nplane inside the slab.
for (int i=0; i<this->nx;++i)
311+
{
312+
fftw_execute_dft( this->planyfor, (fftw_complex *)&in[i*bignpy], (fftw_complex *)&out[i*bignpy]);
313+
}
314+
315+
// x pass: only y indices 0..liy and riy..ny-1 are transformed; indices
// strictly between liy and riy are left untouched.
// NOTE(review): presumably those rows hold no data - confirm, and confirm
// liy < riy so that no row is transformed twice.
for (int i=0; i<=this->liy;++i)
316+
{
317+
fftw_execute_dft( this->planxfor, (fftw_complex *)&in[i*nplane], (fftw_complex *)&out[i*nplane]);
318+
}
319+
for (int i=this->riy; i<this->ny;++i)
320+
{
321+
fftw_execute_dft( this->planxfor, (fftw_complex *)&in[i*nplane], (fftw_complex *)&out[i*nplane]);
322+
}
323+
return;
324+
}
325+
326+
// Backward complex transform in the xy plane: x pass first, then y pass (the
// reverse order of fftxyfor).
// planybac is created only in the non-gamma_only branch of initplan().
//   in  : source array
//   out : destination array
// Both passes read from in and write to out, so the y pass only sees the
// x-transformed data when in and out are the same array. Both plans were also
// created in-place on aux2.
// NOTE(review): confirm that every caller passes in == out. No scaling is
// applied in this function.
// The '-' line below belongs to the removed executefftw() dispatcher.
void FFT::fftxybac(std::complex<double>* & in, std::complex<double>* & out)
327+
{
328+
// Number of elements in one x-slab (bigny rows of nplane elements).
int bignpy = this->nplane * this-> bigny;
329+
// x-direction: only y indices 0..liy and riy..ny-1 are transformed.
330+
for (int i=0; i<=this->liy;++i)
331+
{
332+
fftw_execute_dft( this->planxbac, (fftw_complex *)&in[i*nplane], (fftw_complex *)&out[i*nplane]);
333+
}
334+
for (int i=this->riy; i<this->ny;++i)
265335
{
266-
ModuleBase::WARNING_QUIT("FFT", "Wrong input for excutefftw");
336+
fftw_execute_dft( this->planxbac, (fftw_complex *)&in[i*nplane], (fftw_complex *)&out[i*nplane]);
267337
}
338+
339+
// y-direction: one call per x index; planybac performs nplane transforms
// of length bigny with element stride nplane inside the slab.
340+
for (int i=0; i<this->nx;++i)
341+
{
342+
fftw_execute_dft( this->planybac, (fftw_complex*)&in[i*bignpy], (fftw_complex*)&out[i*bignpy] );
343+
}
344+
return;
268345
}
269346

347+
// Real-to-complex forward transform in the xy plane: a real-to-complex y pass
// followed by a complex x pass.
// planyr2c is created only in the gamma_only branch of initplan(), planned
// from r_rspace to aux1.
//   in  : real source array
//   out : complex destination array; the x pass runs in place on it
// NOTE(review): initfft() sets r_rspace to a cast of aux1, so the plan looks
// like an in-place r2c plan - confirm that callers pass arrays with the same
// aliasing and alignment that the plan was created with.
void FFT::fftxyr2c(double* &in, std::complex<double>* & out)
348+
{
349+
// Number of complex elements in one x-slab.
int bignpy = this->nplane * this-> bigny;
350+
351+
// y pass: the real slab starts at double offset i*bignpy*2, which is the
// same byte offset as complex element i*bignpy of the output slab.
for (int i=0; i<this->nx;++i)
352+
{
353+
fftw_execute_dft_r2c( this->planyr2c, &in[i*bignpy*2], (fftw_complex*)&out[i*bignpy] );
354+
}
355+
356+
// x pass, in place on out: only y indices 0..liy and riy..ny-1 are
// transformed; indices strictly between liy and riy are left untouched.
for (int i=0; i<=this->liy;++i)
357+
{
358+
fftw_execute_dft( this->planxfor, (fftw_complex *)&out[i*nplane], (fftw_complex *)&out[i*nplane]);
359+
}
360+
for (int i=this->riy; i<this->ny;++i)
361+
{
362+
fftw_execute_dft( this->planxfor, (fftw_complex *)&out[i*nplane], (fftw_complex *)&out[i*nplane]);
363+
}
364+
return;
365+
}
366+
367+
368+
// Complex-to-real backward transform in the xy plane: a complex x pass
// followed by a complex-to-real y pass (the reverse of fftxyr2c).
// planyc2r is created only in the gamma_only branch of initplan(), planned
// from aux1 to r_rspace.
//   in  : complex source array; it is OVERWRITTEN by the in-place x pass
//   out : real destination array
// No scaling is applied in this function.
// NOTE(review): FFTW's c2r transforms may also destroy their input - confirm
// that callers do not reuse `in` afterwards.
void FFT::fftxyc2r(std::complex<double>* & in, double* & out)
369+
{
370+
// Number of complex elements in one x-slab.
int bignpy = this->nplane * this-> bigny;
371+
// x pass, in place on in: only y indices 0..liy and riy..ny-1.
for (int i=0; i<=this->liy;++i)
372+
{
373+
fftw_execute_dft( this->planxbac, (fftw_complex *)&in[i*nplane], (fftw_complex *)&in[i*nplane]);
374+
}
375+
for (int i=this->riy; i<this->ny;++i)
376+
{
377+
fftw_execute_dft( this->planxbac, (fftw_complex *)&in[i*nplane], (fftw_complex *)&in[i*nplane]);
378+
}
379+
380+
// y pass: the real slab starts at double offset i*bignpy*2, which is the
// same byte offset as complex element i*bignpy of the input slab.
for (int i=0; i<this->nx;++i)
381+
{
382+
fftw_execute_dft_c2r( this->planyc2r, (fftw_complex*)&in[i*bignpy], &out[i*bignpy*2] );
383+
}
384+
return;
385+
}
386+
387+
270388
#ifdef __MIX_PRECISION
271389
void executefftwf(std::string instr)
272390
{
@@ -282,10 +400,6 @@ void executefftwf(std::string instr)
282400
fftwf_execute(this->planf2r2c);
283401
else if(instr == "2c2r")
284402
fftwf_execute(this->planf2c2r);
285-
else
286-
{
287-
ModuleBase::WARNING_QUIT("FFT", "Wrong input for excutefftwf");
288-
}
289403
}
290404
#endif
291405
}

0 commit comments

Comments
 (0)