Skip to content

Commit ecc7f23

Browse files
committed
change recip_to_real func
1 parent 150e63d commit ecc7f23

File tree

1 file changed

+36
-41
lines changed

1 file changed

+36
-41
lines changed

source/module_basis/module_pw/test_gpu/recip_to_real.cpp

Lines changed: 36 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -144,21 +144,19 @@ TEST_F(PWTEST,recip_to_real_double)
144144
delete [] h_rhogrout;
145145
}
146146

147-
TEST_F(PWTEST,recip_to_real_double)
147+
TEST_F(PWTEST,recip_to_real_float)
148148
{
149149
cout<<"dividemthd 1, gamma_only: off, check fft between double and complex"<<endl;
150150
ModulePW::PW_Basis pwtest("gpu", precision_flag);
151-
pwtest.fft_bundle.setfft("gpu","double");
151+
pwtest.fft_bundle.setfft("gpu","single");
152152
ModuleBase::Matrix3 latvec(1, 1, 0, 0, 1, 1, 0, 0, 2);
153-
double wfcecut;
153+
double wfcecut = 18;
154154
double lat0 = 2.2;
155155
bool gamma_only=false;
156-
wfcecut = 18;
157156
gamma_only = false;
158157
int distribution_type = 1;
159158
bool xprime = false;
160159

161-
//init
162160
#ifdef __MPI
163161
pwtest.initmpi(nproc_in_pool, rank_in_pool, POOL_WORLD);
164162
#endif
@@ -173,16 +171,13 @@ TEST_F(PWTEST,recip_to_real_double)
173171
const int nx = pwtest.nx;
174172
const int ny = pwtest.ny;
175173
const int nz = pwtest.nz;
176-
printf("the nx is %d,the ny is %d\n,the nz is %d\n",nx,ny,nz);
177174
const int nplane = pwtest.nplane;
178-
179175
const double tpiba2 = ModuleBase::TWO_PI * ModuleBase::TWO_PI / lat0 / lat0;
180176
const double ggecut = wfcecut / tpiba2;
181-
ModuleBase::Matrix3 GT,G,GGT;
182-
GT = latvec.Inverse();
183-
G = GT.Transpose();
184-
GGT = G * GT;
185-
complex<double> *tmp = new complex<double> [nx*ny*nz];
177+
ModuleBase::Matrix3 GT = latvec.Inverse();
178+
ModuleBase::Matrix3 G = GT.Transpose();
179+
ModuleBase::Matrix3 GGT = G * GT;
180+
complex<float> *tmp = new complex<float> [nx*ny*nz];
186181
if(rank_in_pool == 0)
187182
{
188183
for(int ix = 0 ; ix < nx ; ++ix)
@@ -192,34 +187,34 @@ TEST_F(PWTEST,recip_to_real_double)
192187
for(int iz = 0 ; iz < nz ; ++iz)
193188
{
194189
tmp[ix*ny*nz + iy*nz + iz]=0.0;
195-
double vx = ix - int(nx/2);
196-
double vy = iy - int(ny/2);
197-
double vz = iz - int(nz/2);
190+
float vx = ix - int(nx/2);
191+
float vy = iy - int(ny/2);
192+
float vz = iz - int(nz/2);
198193
ModuleBase::Vector3<double> v(vx,vy,vz);
199-
double modulus = v * (GGT * v);
194+
float modulus = v * (GGT * v);
200195
if (modulus <= ggecut)
201196
{
202197
tmp[ix*ny*nz + iy*nz + iz]=1.0/(modulus+1);
203-
if(vy > 0) tmp[ix*ny*nz + iy*nz + iz]+=ModuleBase::IMAG_UNIT / (std::abs(v.x+1) + 1);
204-
else if(vy < 0) tmp[ix*ny*nz + iy*nz + iz]-=ModuleBase::IMAG_UNIT / (std::abs(-v.x+1) + 1);
198+
if(vy > 0) tmp[ix*ny*nz + iy*nz + iz]+=std::complex<float>(0,1.0) / (std::abs(vx+1) + 1);
199+
else if(vy < 0) tmp[ix*ny*nz + iy*nz + iz]-=std::complex<float>(0,1.0) / (std::abs(-vx+1) + 1);
205200
}
206201
}
207202
}
208203
}
209-
fftw_plan pp = fftw_plan_dft_3d(nx,ny,nz,(fftw_complex *) tmp, (fftw_complex *) tmp, FFTW_BACKWARD, FFTW_ESTIMATE);
210-
fftw_execute(pp);
211-
fftw_destroy_plan(pp);
204+
fftwf_plan pp = fftwf_plan_dft_3d(nx,ny,nz,(fftwf_complex *) tmp, (fftwf_complex *) tmp, FFTW_BACKWARD, FFTW_ESTIMATE);
205+
fftwf_execute(pp);
206+
fftwf_destroy_plan(pp);
212207

213-
ModuleBase::Vector3<double> delta_g(double(int(nx/2))/nx, double(int(ny/2))/ny, double(int(nz/2))/nz);
208+
ModuleBase::Vector3<float> delta_g(float(int(nx/2))/nx, float(int(ny/2))/ny, float(int(nz/2))/nz);
214209
for(int ixy = 0 ; ixy < nx * ny ; ++ixy)
215210
{
216211
for(int iz = 0 ; iz < nz ; ++iz)
217212
{
218213
int ix = ixy / ny;
219214
int iy = ixy % ny;
220-
ModuleBase::Vector3<double> real_r(ix, iy, iz);
221-
double phase_im = -delta_g * real_r;
222-
complex<double> phase(0,ModuleBase::TWO_PI * phase_im);
215+
ModuleBase::Vector3<float> real_r(ix, iy, iz);
216+
float phase_im = -delta_g * real_r;
217+
complex<float> phase(0,ModuleBase::TWO_PI * phase_im);
223218
tmp[ixy * nz + iz] *= exp(phase);
224219
}
225220
}
@@ -228,14 +223,14 @@ TEST_F(PWTEST,recip_to_real_double)
228223
MPI_Bcast(tmp,2*nx*ny*nz,MPI_DOUBLE,0,POOL_WORLD);
229224
#endif
230225
// const int size = nx * ny * nz;
231-
complex<double> * h_rhog = new complex<double> [npw];
232-
complex<double> * h_rhogout = new complex<double> [npw];
233-
complex<double> * d_rhog;
234-
complex<double> * d_rhogr;
235-
complex<double> * d_rhogout;
236-
cudaMalloc((void**)&d_rhog,npw * sizeof(complex<double>));
237-
cudaMalloc((void**)&d_rhogr,npw*sizeof(complex<double>));
238-
cudaMalloc((void**)&d_rhogout,npw*sizeof(complex<double>));
226+
complex<float> * h_rhog = new complex<float> [npw];
227+
complex<float> * h_rhogout = new complex<float> [npw];
228+
complex<float> * d_rhog;
229+
complex<float> * d_rhogr;
230+
complex<float> * d_rhogout;
231+
cudaMalloc((void**)&d_rhog,npw * sizeof(complex<float>));
232+
cudaMalloc((void**)&d_rhogr,npw*sizeof(complex<float>));
233+
cudaMalloc((void**)&d_rhogout,npw*sizeof(complex<float>));
239234

240235
for(int ig = 0 ; ig < npw ; ++ig)
241236
{
@@ -249,15 +244,15 @@ TEST_F(PWTEST,recip_to_real_double)
249244
h_rhog[ig]-=ModuleBase::IMAG_UNIT / (std::abs(-pwtest.gdirect[ig].x+1) + 1);
250245
}
251246
}
252-
cudaMemcpy(d_rhog,h_rhog,npw * sizeof(complex<double>),cudaMemcpyHostToDevice);
253-
cudaMemcpy(d_rhogout,h_rhogout,npw * sizeof(complex<double>),cudaMemcpyHostToDevice);
247+
cudaMemcpy(d_rhog,h_rhog,npw * sizeof(complex<float>),cudaMemcpyHostToDevice);
248+
cudaMemcpy(d_rhogout,h_rhogout,npw * sizeof(complex<float>),cudaMemcpyHostToDevice);
254249

255-
double * h_rhor = new double [nrxx];
256-
double * h_rhogrout = new double [nrxx];
257-
double * d_rhor;
258-
cudaMalloc((void**)&d_rhor,nrxx * sizeof(double));
259-
pwtest.recip_to_real<std::complex<double>,double,base_device::DEVICE_GPU>(d_rhog,d_rhor);
260-
cudaMemcpy(h_rhor,d_rhor,nrxx*sizeof(double),cudaMemcpyDeviceToHost);
250+
float * h_rhor = new float [nrxx];
251+
float * h_rhogrout = new float [nrxx];
252+
float * d_rhor;
253+
cudaMalloc((void**)&d_rhor,nrxx * sizeof(float));
254+
pwtest.recip_to_real<std::complex<float>,float,base_device::DEVICE_GPU>(d_rhog,d_rhor);
255+
cudaMemcpy(h_rhor,d_rhor,nrxx*sizeof(float),cudaMemcpyDeviceToHost);
261256

262257
int startiz = pwtest.startz_current;
263258
for(int ixy = 0 ; ixy < nx * ny ; ++ixy)

0 commit comments

Comments
 (0)