Skip to content

Commit adce797

Browse files
committed
< prof >
add riy and liy for gamma_only, thus half time is reduce for y fft change rix/lix to riy/liy
1 parent bfebade commit adce797

File tree

7 files changed

+67
-50
lines changed

7 files changed

+67
-50
lines changed

source/module_pw/fft.cpp

Lines changed: 54 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
#include "fft.h"
2-
#include <iostream>
3-
42
namespace ModulePW
53
{
64

@@ -34,7 +32,7 @@ FFT::~FFT()
3432
#endif
3533
}
3634

37-
void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int lix_in, int rix_in, int ns_in, int nplane_in, int nproc_in, bool gamma_only_in, bool mpifft_in)
35+
void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int liy_in, int riy_in, int ns_in, int nplane_in, int nproc_in, bool gamma_only_in, bool mpifft_in)
3836
{
3937
this->gamma_only = gamma_only_in;
4038
this->nx = nx_in;
@@ -43,8 +41,8 @@ void FFT:: initfft(int nx_in, int bigny_in, int nz_in, int lix_in, int rix_in, i
4341
else this->ny = this->bigny;
4442
this->nz = nz_in;
4543
this->ns = ns_in;
46-
this->lix = lix_in;
47-
this->rix = rix_in;
44+
this->liy = liy_in;
45+
this->riy = riy_in;
4846
this->nplane = nplane_in;
4947
this->nproc = nproc_in;
5048
this->mpifft = mpifft_in;
@@ -136,11 +134,11 @@ void FFT :: initplan()
136134
// (fftw_complex*) aux2, rankc, this->nplane, 1,
137135
// r_rspace, rankd, this->nplane, 1, FFTW_MEASURE);
138136

139-
int npy = this->nplane * this->ny;
137+
// int npy = this->nplane * this->ny;
140138
int bignpy = this->nplane * this->bigny;
141-
this->planxfor = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)aux2, embed, bignpy, 1,
139+
this->planxfor = fftw_plan_many_dft( 1, &this->nx, this->nplane, (fftw_complex *)aux2, embed, bignpy, 1,
142140
(fftw_complex *)aux2, embed, bignpy, 1, FFTW_FORWARD, FFTW_MEASURE );
143-
this->planxbac = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)aux2, embed, bignpy, 1,
141+
this->planxbac = fftw_plan_many_dft( 1, &this->nx, this->nplane, (fftw_complex *)aux2, embed, bignpy, 1,
144142
(fftw_complex *)aux2, embed, bignpy, 1, FFTW_BACKWARD, FFTW_MEASURE );
145143
this->planyr2c = fftw_plan_many_dft_r2c( 1, &this->bigny, this->nplane, r_rspace , embed, this->nplane, 1,
146144
(fftw_complex*)aux1, embed, this->nplane, 1, FFTW_MEASURE );
@@ -169,14 +167,14 @@ void FFT :: initplan()
169167
// (fftw_complex*) aux2, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE);
170168

171169
//It is better to use in-place for stride > 1
172-
int npy = this->nplane * this->ny;
173-
this->planxfor = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)aux2, embed, npy, 1,
174-
(fftw_complex *)aux2, embed, npy, 1, FFTW_FORWARD, FFTW_MEASURE );
175-
this->planxbac = fftw_plan_many_dft( 1, &this->nx, npy, (fftw_complex *)aux2, embed, npy, 1,
176-
(fftw_complex *)aux2, embed, npy, 1, FFTW_BACKWARD, FFTW_MEASURE );
177-
this->planyfor = fftw_plan_many_dft( 1, &this->ny, this->nplane, (fftw_complex*)aux2 , embed, this->nplane, 1,
170+
int bignpy = this->nplane * this->bigny;
171+
this->planxfor = fftw_plan_many_dft( 1, &this->nx, this->nplane, (fftw_complex *)aux2, embed, bignpy, 1,
172+
(fftw_complex *)aux2, embed, bignpy, 1, FFTW_FORWARD, FFTW_MEASURE );
173+
this->planxbac = fftw_plan_many_dft( 1, &this->nx, this->nplane, (fftw_complex *)aux2, embed, bignpy, 1,
174+
(fftw_complex *)aux2, embed, bignpy, 1, FFTW_BACKWARD, FFTW_MEASURE );
175+
this->planyfor = fftw_plan_many_dft( 1, &this->bigny, this->nplane, (fftw_complex*)aux2 , embed, this->nplane, 1,
178176
(fftw_complex*)aux2, embed, this->nplane, 1, FFTW_FORWARD, FFTW_MEASURE );
179-
this->planybac = fftw_plan_many_dft( 1, &this->ny, this->nplane, (fftw_complex*)aux2 , embed, this->nplane, 1,
177+
this->planybac = fftw_plan_many_dft( 1, &this->bigny, this->nplane, (fftw_complex*)aux2 , embed, this->nplane, 1,
180178
(fftw_complex*)aux2, embed, this->nplane, 1, FFTW_BACKWARD, FFTW_MEASURE );
181179
}
182180

@@ -308,61 +306,80 @@ void FFT::fftzbac(std::complex<double>* & in, std::complex<double>* & out)
308306

309307
void FFT::fftxyfor(std::complex<double>* & in, std::complex<double>* & out)
310308
{
311-
int npy = this->nplane * this-> ny;
312-
fftw_execute_dft( this->planxfor, (fftw_complex *)in, (fftw_complex *)out);
313-
for (int i=0; i<=this->lix;++i)
309+
int bignpy = this->nplane * this-> bigny;
310+
for (int i=0; i<this->nx;++i)
314311
{
315-
fftw_execute_dft( this->planyfor, (fftw_complex*)&in[i*npy], (fftw_complex*)&out[i*npy] );
312+
fftw_execute_dft( this->planyfor, (fftw_complex *)&in[i*bignpy], (fftw_complex *)&out[i*bignpy]);
316313
}
317-
for (int i=this->rix; i<this->nx;++i)
314+
315+
for (int i=0; i<=this->liy;++i)
318316
{
319-
fftw_execute_dft( this->planyfor, (fftw_complex*)&in[i*npy], (fftw_complex*)&out[i*npy] );
317+
fftw_execute_dft( this->planxfor, (fftw_complex *)&in[i*nplane], (fftw_complex *)&out[i*nplane]);
318+
}
319+
for (int i=this->riy; i<this->ny;++i)
320+
{
321+
fftw_execute_dft( this->planxfor, (fftw_complex *)&in[i*nplane], (fftw_complex *)&out[i*nplane]);
320322
}
321323
return;
322324
}
323325

324326
void FFT::fftxybac(std::complex<double>* & in, std::complex<double>* & out)
325327
{
326-
int npy = this->nplane * this-> ny;
327-
328-
for (int i=0; i<=this->lix;++i)
328+
int bignpy = this->nplane * this-> bigny;
329+
//x-direction
330+
for (int i=0; i<=this->liy;++i)
331+
{
332+
fftw_execute_dft( this->planxbac, (fftw_complex *)&in[i*nplane], (fftw_complex *)&out[i*nplane]);
333+
}
334+
for (int i=this->riy; i<this->ny;++i)
329335
{
330-
fftw_execute_dft( this->planybac, (fftw_complex*)&in[i*npy], (fftw_complex*)&out[i*npy] );
336+
fftw_execute_dft( this->planxbac, (fftw_complex *)&in[i*nplane], (fftw_complex *)&out[i*nplane]);
331337
}
332-
for (int i=this->rix; i<this->nx;++i)
338+
339+
////y-direction
340+
for (int i=0; i<this->nx;++i)
333341
{
334-
fftw_execute_dft( this->planybac, (fftw_complex*)&in[i*npy], (fftw_complex*)&out[i*npy] );
342+
fftw_execute_dft( this->planybac, (fftw_complex*)&in[i*bignpy], (fftw_complex*)&out[i*bignpy] );
335343
}
336-
fftw_execute_dft( this->planxbac, (fftw_complex *)in, (fftw_complex *)out);
337344
return;
338345
}
339346

340347
void FFT::fftxyr2c(double* &in, std::complex<double>* & out)
341348
{
342-
//int npy = this->nplane * this-> ny;
343349
int bignpy = this->nplane * this-> bigny;
344-
// int padnpy = this->nplane * this-> ny * 2;
350+
345351
for (int i=0; i<this->nx;++i)
346352
{
347353
fftw_execute_dft_r2c( this->planyr2c, &in[i*bignpy*2], (fftw_complex*)&out[i*bignpy] );
348-
if((double*)&out[i*bignpy] != &in[i*bignpy*2] ) std::cout<<i<<"wrond\n";
349-
// fftw_execute_dft_r2c( this->planyfor, &r_rspace[4*i*padnpy], (fftw_complex*)&aux2[i*padnpy] );
350354
}
351-
fftw_execute_dft( this->planxfor, (fftw_complex *)out, (fftw_complex *)out);
355+
356+
for (int i=0; i<=this->liy;++i)
357+
{
358+
fftw_execute_dft( this->planxfor, (fftw_complex *)&out[i*nplane], (fftw_complex *)&out[i*nplane]);
359+
}
360+
for (int i=this->riy; i<this->ny;++i)
361+
{
362+
fftw_execute_dft( this->planxfor, (fftw_complex *)&out[i*nplane], (fftw_complex *)&out[i*nplane]);
363+
}
352364
return;
353365
}
354366

355367

356368
void FFT::fftxyc2r(std::complex<double>* & in, double* & out)
357369
{
358-
//int npy = this->nplane * this-> ny;
359370
int bignpy = this->nplane * this-> bigny;
360-
// int padnpy = this->nplane * this-> ny * 2;
361-
fftw_execute_dft( this->planxbac, (fftw_complex *)in, (fftw_complex *)in);
371+
for (int i=0; i<=this->liy;++i)
372+
{
373+
fftw_execute_dft( this->planxbac, (fftw_complex *)&in[i*nplane], (fftw_complex *)&in[i*nplane]);
374+
}
375+
for (int i=this->riy; i<this->ny;++i)
376+
{
377+
fftw_execute_dft( this->planxbac, (fftw_complex *)&in[i*nplane], (fftw_complex *)&in[i*nplane]);
378+
}
379+
362380
for (int i=0; i<this->nx;++i)
363381
{
364382
fftw_execute_dft_c2r( this->planyc2r, (fftw_complex*)&in[i*bignpy], &out[i*bignpy*2] );
365-
// fftw_execute_dft_c2r( this->planybac, (fftw_complex*)&aux2[i*padnpy], &r_rspace[4*i*padnpy] );
366383
}
367384
return;
368385
}

source/module_pw/fft.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class FFT
2727

2828
FFT();
2929
~FFT();
30-
void initfft(int nx_in, int bigny_in, int nz_in, int lix_in, int rix_in, int ns_in, int nplane_in,
30+
void initfft(int nx_in, int bigny_in, int nz_in, int liy_in, int riy_in, int ns_in, int nplane_in,
3131
int nproc_in, bool gamma_only_in, bool mpifft_in = false);
3232
void setupFFT();
3333
void cleanFFT();
@@ -56,7 +56,7 @@ class FFT
5656
int nxy;
5757
int bigny;
5858
int bignxy;
59-
int lix,rix;// lix: the left edge of the pw ball; rix: the right edge of the pw ball
59+
int liy,riy;// liy: the left edge of the pw ball in the y direction; riy: the right edge of the pw ball in the y direction
6060
int ns; //number of sticks
6161
int nplane; //number of x-y planes
6262
int maxgrids; // max between nz * ns and bignxy * nplane

source/module_pw/pw_basis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ void PW_Basis::setuptransform()
5454
this->distribute_r();
5555
this->distribute_g();
5656
this->getstartgr();
57-
this->ft.initfft(this->nx,this->bigny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only);
57+
this->ft.initfft(this->nx,this->bigny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only);
5858
this->ft.setupFFT();
5959
}
6060

source/module_pw/pw_basis.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ class PW_Basis
184184
// FFT dimensions for wave functions.
185185
int nx, ny, nz, nxyz, nxy;
186186
int bigny, bignxyz, bignxy; // Gamma_only: ny = int(bigny/2)-1 , others: ny = bigny
187-
int lix,rix;// lix: the left edge of the pw ball; rix: the right edge of the pw ball
187+
int liy,riy;// liy: the left edge of the pw ball; riy: the right edge of the pw ball
188188
int maxgrids; // max between nz * ns and bignxy * nplane
189189
FFT ft;
190190

source/module_pw/pw_distributeg.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ void PW_Basis::count_pw_st(
5454
iy_start = 0;
5555
iy_end = this->ny - 1;
5656
}
57-
this->lix = this->rix = 0;
57+
this->liy = this->riy = 0;
5858
for (int ix = -ibox[0]; ix <= ibox[0]; ++ix)
5959
{
6060
for (int iy = iy_start; iy <= iy_end; ++iy)
@@ -82,8 +82,8 @@ void PW_Basis::count_pw_st(
8282
if (length == 0) st_bottom2D[index] = iz; // length == 0 means this point is the bottom of stick (x, y).
8383
++tot_npw;
8484
++length;
85-
if(ix < this->rix) this->rix = ix;
86-
if(ix > this->lix) this->lix = ix;
85+
if(iy < this->riy) this->riy = iy;
86+
if(iy > this->liy) this->liy = iy;
8787
}
8888
}
8989
if (length > 0)
@@ -93,8 +93,8 @@ void PW_Basis::count_pw_st(
9393
}
9494
}
9595
}
96-
if(rix <= 0) rix += this->nx;
97-
std::cout<<"lix "<<lix<<" ; rix "<<rix<<std::endl;
96+
riy += this->ny;
97+
std::cout<<"liy "<<liy<<" ; riy "<<riy<<std::endl;
9898
return;
9999
}
100100

source/module_pw/pw_distributeg_method1.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ void PW_Basis::distribution_method1()
156156
#ifdef __MPI
157157
MPI_Bcast(&tot_npw, 1, MPI_INT, 0, POOL_WORLD);
158158
MPI_Bcast(&this->nstot, 1, MPI_INT, 0, POOL_WORLD);
159-
MPI_Bcast(&lix, 1, MPI_INT, 0, POOL_WORLD);
160-
MPI_Bcast(&rix, 1, MPI_INT, 0, POOL_WORLD);
159+
MPI_Bcast(&liy, 1, MPI_INT, 0, POOL_WORLD);
160+
MPI_Bcast(&riy, 1, MPI_INT, 0, POOL_WORLD);
161161
if (this->poolrank != 0)
162162
{
163163
st_bottom2D = new int[this->nxy]; // minimum z of stick.

source/module_pw/pw_distributeg_method2.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ void PW_Basis::distribution_method2()
105105
#ifdef __MPI
106106
MPI_Bcast(&tot_npw, 1, MPI_INT, 0, POOL_WORLD);
107107
MPI_Bcast(&this->nstot, 1, MPI_INT, 0, POOL_WORLD);
108-
MPI_Bcast(&lix, 1, MPI_INT, 0, POOL_WORLD);
109-
MPI_Bcast(&rix, 1, MPI_INT, 0, POOL_WORLD);
108+
MPI_Bcast(&liy, 1, MPI_INT, 0, POOL_WORLD);
109+
MPI_Bcast(&riy, 1, MPI_INT, 0, POOL_WORLD);
110110
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0 , POOL_WORLD);
111111
if (this->poolrank != 0)
112112
{

0 commit comments

Comments
 (0)