Skip to content

Commit fd6fff7

Browse files
A-006, Qianruipku, mohanchen
authored
Add Gtest for GPU pw_basis and pw_basis_k (#6087)
* first use of the fft * fix the cuda test * add gtest * change the test * change recip_to_real func * set C2C * set C2C and C2R * set C2C and C2R * add pw_basis_k C2C * add CMakeLists.txt * remove compile * add the define * add delete * modify the include file * add change for the POOL_WORLD * format document * add the mpi_flag_ * fix bug in the mpi set * modify the mpi_flag_ as false * set pw_test * Revert "set pw_test" This reverts commit 85785a1. * revert setup * add comment * update offset * reset the for * change the mpi_flag * change the compile error * separate two lines * add change * use template instead of change * remove temp file --------- Co-authored-by: Qianrui Liu <[email protected]> Co-authored-by: Mohan Chen <[email protected]>
1 parent 5f97360 commit fd6fff7

23 files changed

+1079
-149
lines changed

source/module_base/module_device/memory_op.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ struct synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVIC
133133
void operator()(FPTYPE* arr_out,
134134
const FPTYPE* arr_in,
135135
const size_t size);
136+
136137
};
137138

138139
template <typename FPTYPE>

source/module_basis/module_pw/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,6 @@ if(BUILD_TESTING)
5858
add_subdirectory(test)
5959
add_subdirectory(test_serial)
6060
add_subdirectory(kernels/test)
61+
add_subdirectory(test_gpu)
6162
endif()
6263
endif()

source/module_basis/module_pw/module_fft/fft_cuda.cpp

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
#include "fft_cuda.h"
2+
23
#include "module_base/module_device/memory_op.h"
34
#include "module_hamilt_pw/hamilt_pwdft/global.h"
45

56
namespace ModulePW
67
{
78
template <typename FPTYPE>
8-
void FFT_CUDA<FPTYPE>::initfft(int nx_in,
9-
int ny_in,
10-
int nz_in)
9+
void FFT_CUDA<FPTYPE>::initfft(int nx_in, int ny_in, int nz_in)
1110
{
1211
this->nx = nx_in;
1312
this->ny = ny_in;
@@ -18,9 +17,8 @@ void FFT_CUDA<float>::setupFFT()
1817
{
1918
cufftPlan3d(&c_handle, this->nx, this->ny, this->nz, CUFFT_C2C);
2019
resmem_cd_op()(this->c_auxr_3d, this->nx * this->ny * this->nz);
21-
2220
}
23-
template <>
21+
template <>
2422
void FFT_CUDA<double>::setupFFT()
2523
{
2624
cufftPlan3d(&z_handle, this->nx, this->ny, this->nz, CUFFT_Z2Z);
@@ -66,49 +64,51 @@ void FFT_CUDA<double>::clear()
6664
}
6765

6866
template <>
69-
void FFT_CUDA<float>::fft3D_forward(std::complex<float>* in,
70-
std::complex<float>* out) const
67+
void FFT_CUDA<float>::fft3D_forward(std::complex<float>* in, std::complex<float>* out) const
7168
{
72-
CHECK_CUFFT(cufftExecC2C(this->c_handle,
73-
reinterpret_cast<cufftComplex*>(in),
69+
CHECK_CUFFT(cufftExecC2C(this->c_handle,
70+
reinterpret_cast<cufftComplex*>(in),
7471
reinterpret_cast<cufftComplex*>(out),
7572
CUFFT_FORWARD));
7673
}
7774
template <>
78-
void FFT_CUDA<double>::fft3D_forward(std::complex<double>* in,
79-
std::complex<double>* out) const
75+
void FFT_CUDA<double>::fft3D_forward(std::complex<double>* in, std::complex<double>* out) const
8076
{
81-
CHECK_CUFFT(cufftExecZ2Z(this->z_handle,
77+
CHECK_CUFFT(cufftExecZ2Z(this->z_handle,
8278
reinterpret_cast<cufftDoubleComplex*>(in),
83-
reinterpret_cast<cufftDoubleComplex*>(out),
79+
reinterpret_cast<cufftDoubleComplex*>(out),
8480
CUFFT_FORWARD));
8581
}
8682
template <>
87-
void FFT_CUDA<float>::fft3D_backward(std::complex<float>* in,
88-
std::complex<float>* out) const
83+
void FFT_CUDA<float>::fft3D_backward(std::complex<float>* in, std::complex<float>* out) const
8984
{
90-
CHECK_CUFFT(cufftExecC2C(this->c_handle,
91-
reinterpret_cast<cufftComplex*>(in),
85+
CHECK_CUFFT(cufftExecC2C(this->c_handle,
86+
reinterpret_cast<cufftComplex*>(in),
9287
reinterpret_cast<cufftComplex*>(out),
9388
CUFFT_INVERSE));
9489
}
9590

9691
template <>
97-
void FFT_CUDA<double>::fft3D_backward(std::complex<double>* in,
98-
std::complex<double>* out) const
92+
void FFT_CUDA<double>::fft3D_backward(std::complex<double>* in, std::complex<double>* out) const
9993
{
100-
CHECK_CUFFT(cufftExecZ2Z(this->z_handle,
94+
CHECK_CUFFT(cufftExecZ2Z(this->z_handle,
10195
reinterpret_cast<cufftDoubleComplex*>(in),
102-
reinterpret_cast<cufftDoubleComplex*>(out),
96+
reinterpret_cast<cufftDoubleComplex*>(out),
10397
CUFFT_INVERSE));
10498
}
105-
template <> std::complex<float>*
106-
FFT_CUDA<float>::get_auxr_3d_data() const {return this->c_auxr_3d;}
107-
template <> std::complex<double>*
108-
FFT_CUDA<double>::get_auxr_3d_data() const {return this->z_auxr_3d;}
99+
template <>
100+
std::complex<float>* FFT_CUDA<float>::get_auxr_3d_data() const
101+
{
102+
return this->c_auxr_3d;
103+
}
104+
template <>
105+
std::complex<double>* FFT_CUDA<double>::get_auxr_3d_data() const
106+
{
107+
return this->z_auxr_3d;
108+
}
109109

110110
template FFT_CUDA<float>::FFT_CUDA();
111111
template FFT_CUDA<float>::~FFT_CUDA();
112112
template FFT_CUDA<double>::FFT_CUDA();
113113
template FFT_CUDA<double>::~FFT_CUDA();
114-
}// namespace ModulePW
114+
} // namespace ModulePW

source/module_basis/module_pw/pw_basis.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ PW_Basis:: ~PW_Basis()
4343
if (this->device == "gpu")
4444
{
4545
delmem_int_op()(this->d_is2fftixy);
46+
delmem_int_op()(this->ig2ixyz_gpu);
4647
}
4748
#endif
4849
}
@@ -59,6 +60,7 @@ void PW_Basis::setuptransform()
5960
this->distribute_g();
6061
this->getstartgr();
6162
this->fft_bundle.clear();
63+
6264
if(this->xprime)
6365
{
6466
this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime);

source/module_basis/module_pw/pw_basis.h

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,11 @@ class PW_Basis
100100
//===============================================
101101
public:
102102
#ifdef __MPI
103-
MPI_Comm pool_world;
103+
MPI_Comm pool_world=MPI_COMM_NULL;
104104
#endif
105105

106106
int *ig2isz=nullptr; // map ig to (is, iz).
107+
int *ig2ixyz_gpu = nullptr;
107108
int *istot2ixy=nullptr; // istot2ixy[is]: iy + ix * ny of is^th stick among all sticks.
108109
int *is2fftixy=nullptr, * d_is2fftixy = nullptr; // is2fftixy[is]: iy + ix * ny of is^th stick among sticks on current proc.
109110
int *fftixy2ip=nullptr; // fftixy2ip[iy + ix * fftny]: ip of proc which contains stick on (ix, iy). if no stick: -1
@@ -352,7 +353,10 @@ class PW_Basis
352353
void recip_to_real(TK* in,
353354
TR* out,
354355
const bool add = false,
355-
const typename GetTypeReal<TK>::type factor = 1.0) const;
356+
const typename GetTypeReal<TK>::type factor = 1.0) const
357+
{
358+
this->recip2real_gpu(in,out,add,factor);
359+
};
356360

357361
// template <typename FPTYPE,
358362
// typename Device,
@@ -383,9 +387,7 @@ class PW_Basis
383387
* values in the output array.
384388
* @param factor Optional scaling factor, default value 1.0, applied to the output values.
385389
*/
386-
template <typename TK,
387-
typename TR,
388-
typename Device,
390+
template <typename TR,typename TK,typename Device,
389391
typename std::enable_if<!std::is_same<TK, typename GetTypeReal<TK>::type>::value
390392
&& (std::is_same<TR, typename GetTypeReal<TK>::type>::value || std::is_same<TR, TK>::value)
391393
&& std::is_same<Device, base_device::DEVICE_CPU>::value ,int>::type = 0>
@@ -397,14 +399,17 @@ class PW_Basis
397399
this->real2recip(in, out, add, factor);
398400
}
399401

400-
template <typename TK,typename TR, typename Device,
402+
template <typename TR, typename TK, typename Device,
401403
typename std::enable_if<!std::is_same<TK, typename GetTypeReal<TK>::type>::value
402404
&& (std::is_same<TR, typename GetTypeReal<TK>::type>::value || std::is_same<TR, TK>::value)
403-
&& !std::is_same<Device, base_device::DEVICE_CPU>::value ,int>::type = 0>
405+
&& std::is_same<Device, base_device::DEVICE_GPU>::value ,int>::type = 0>
404406
void real_to_recip(TR* in,
405407
TK* out,
406408
const bool add = false,
407-
const typename GetTypeReal<TK>::type factor = 1.0) const;
409+
const typename GetTypeReal<TK>::type factor = 1.0) const
410+
{
411+
this->real2recip_gpu(in,out,add,factor);
412+
};
408413

409414
protected:
410415
//gather planes and scatter sticks of all processors

source/module_basis/module_pw/pw_basis_big.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,6 @@ class PW_Basis_Big : public PW_Basis_Sup
170170
MPI_Allreduce(MPI_IN_PLACE, ibox, 3, MPI_INT, MPI_MAX , this->pool_world);
171171
#endif
172172

173-
174173
// Find the minimal FFT box size the factors into the primes (2,3,5,7).
175174
for (int i = 0; i < 3; i++)
176175
{

source/module_basis/module_pw/pw_basis_k_big.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ class PW_Basis_K_Big: public PW_Basis_K
5656
for(int ip = 0 ; ip < this->poolnproc ; ++ip)
5757
{
5858
this->numz[ip] = npbz*this->bz;
59-
if(ip < modbz) { this->numz[ip]+=this->bz;
60-
}
59+
if(ip < modbz) { this->numz[ip]+=this->bz;}
6160
if(ip < this->poolnproc - 1) this->startz[ip+1] = this->startz[ip] + numz[ip];
6261
if(ip == this->poolrank)
6362
{

source/module_basis/module_pw/pw_basis_sup.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,12 @@ void PW_Basis_Sup::distribution_method3(const ModulePW::PW_Basis* pw_rho)
114114
this->count_pw_st(st_length2D, st_bottom2D);
115115
}
116116
#ifdef __MPI
117-
MPI_Bcast(&this->npwtot, 1, MPI_INT, 0, this->pool_world);
118-
MPI_Bcast(&this->nstot, 1, MPI_INT, 0, this->pool_world);
119-
MPI_Bcast(&liy, 1, MPI_INT, 0, this->pool_world);
120-
MPI_Bcast(&riy, 1, MPI_INT, 0, this->pool_world);
121-
MPI_Bcast(&lix, 1, MPI_INT, 0, this->pool_world);
122-
MPI_Bcast(&rix, 1, MPI_INT, 0, this->pool_world);
117+
MPI_Bcast(&this->npwtot, 1, MPI_INT, 0, this->pool_world);
118+
MPI_Bcast(&this->nstot, 1, MPI_INT, 0, this->pool_world);
119+
MPI_Bcast(&liy, 1, MPI_INT, 0, this->pool_world);
120+
MPI_Bcast(&riy, 1, MPI_INT, 0, this->pool_world);
121+
MPI_Bcast(&lix, 1, MPI_INT, 0, this->pool_world);
122+
MPI_Bcast(&rix, 1, MPI_INT, 0, this->pool_world);
123123
#endif
124124
delete[] this->istot2ixy;
125125
this->istot2ixy = new int[this->nstot];
@@ -164,14 +164,14 @@ void PW_Basis_Sup::distribution_method3(const ModulePW::PW_Basis* pw_rho)
164164
}
165165
#endif
166166
}
167-
168167
#ifdef __MPI
169-
MPI_Bcast(st_length2D, this->fftnxy, MPI_INT, 0, this->pool_world);
170-
MPI_Bcast(st_bottom2D, this->fftnxy, MPI_INT, 0, this->pool_world);
171-
MPI_Bcast(this->fftixy2ip, this->fftnxy, MPI_INT, 0, this->pool_world);
172-
MPI_Bcast(this->istot2ixy, this->nstot, MPI_INT, 0, this->pool_world);
173-
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0, this->pool_world);
174-
MPI_Bcast(this->npw_per, this->poolnproc, MPI_INT, 0, this->pool_world);
168+
169+
MPI_Bcast(st_length2D, this->fftnxy, MPI_INT, 0, this->pool_world);
170+
MPI_Bcast(st_bottom2D, this->fftnxy, MPI_INT, 0, this->pool_world);
171+
MPI_Bcast(this->fftixy2ip, this->fftnxy, MPI_INT, 0, this->pool_world);
172+
MPI_Bcast(this->istot2ixy, this->nstot, MPI_INT, 0, this->pool_world);
173+
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0, this->pool_world);
174+
MPI_Bcast(this->npw_per, this->poolnproc, MPI_INT, 0, this->pool_world);
175175
#endif
176176
this->npw = this->npw_per[this->poolrank];
177177
this->nst = this->nst_per[this->poolrank];

source/module_basis/module_pw/pw_distributeg.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,23 +189,42 @@ void PW_Basis::get_ig2isz_is2fftixy(
189189
for (int iz = zstart; iz < zstart + st_length2D[ixy]; ++iz)
190190
{
191191
int z = iz;
192-
if (z < 0) { z += this->nz;
193-
}
192+
if (z < 0)
193+
{
194+
z += this->nz;
195+
}
194196
this->ig2isz[pw_filled] = st_move * this->nz + z;
195197
pw_filled++;
196198
}
197199
this->is2fftixy[st_move] = ixy;
198200
st_move++;
199-
if(xprime && ixy/fftny == 0) { ng_xeq0 = pw_filled;
200-
}
201+
if (xprime && ixy / fftny == 0)
202+
{
203+
ng_xeq0 = pw_filled;
204+
}
201205
}
202-
if (st_move == this->nst && pw_filled == this->npw) { break;
203-
}
206+
if (st_move == this->nst && pw_filled == this->npw)
207+
{
208+
break;
209+
}
210+
}
211+
std::vector<int> ig2ixyz(this->npw);
212+
for (int igl = 0; igl < this->npw; ++igl)
213+
{
214+
int isz = this->ig2isz[igl];
215+
int iz = isz % this->nz;
216+
int is = isz / this->nz;
217+
int ixy = this->is2fftixy[is];
218+
int iy = ixy % this->ny;
219+
int ix = ixy / this->ny;
220+
ig2ixyz[igl] = iz + iy * nz + ix * ny * nz;
204221
}
205222
#if defined(__CUDA) || defined(__ROCM)
206223
if (this->device == "gpu") {
207224
resmem_int_op()(d_is2fftixy, this->nst);
208225
syncmem_int_h2d_op()(this->d_is2fftixy, this->is2fftixy, this->nst);
226+
resmem_int_op()(ig2ixyz_gpu,this->npw);
227+
syncmem_int_h2d_op()(ig2ixyz_gpu, ig2ixyz.data(), this->npw);
209228
}
210229
#endif
211230
return;

source/module_basis/module_pw/pw_distributeg_method1.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,15 @@ void PW_Basis::distribution_method1()
4747
this->count_pw_st(st_length2D, st_bottom2D);
4848
}
4949
#ifdef __MPI
50-
MPI_Bcast(&this->npwtot, 1, MPI_INT, 0, this->pool_world);
51-
MPI_Bcast(&this->nstot, 1, MPI_INT, 0, this->pool_world);
52-
MPI_Bcast(&liy, 1, MPI_INT, 0, this->pool_world);
53-
MPI_Bcast(&riy, 1, MPI_INT, 0, this->pool_world);
54-
MPI_Bcast(&lix, 1, MPI_INT, 0, this->pool_world);
55-
MPI_Bcast(&rix, 1, MPI_INT, 0, this->pool_world);
50+
MPI_Bcast(&this->npwtot, 1, MPI_INT, 0, this->pool_world);
51+
MPI_Bcast(&this->nstot, 1, MPI_INT, 0, this->pool_world);
52+
MPI_Bcast(&liy, 1, MPI_INT, 0, this->pool_world);
53+
MPI_Bcast(&riy, 1, MPI_INT, 0, this->pool_world);
54+
MPI_Bcast(&lix, 1, MPI_INT, 0, this->pool_world);
55+
MPI_Bcast(&rix, 1, MPI_INT, 0, this->pool_world);
5656
#endif
57-
delete[] this->istot2ixy; this->istot2ixy = new int[this->nstot];
57+
delete[] this->istot2ixy;
58+
this->istot2ixy = new int[this->nstot];
5859

5960
if(poolrank == 0)
6061
{
@@ -78,7 +79,7 @@ void PW_Basis::distribution_method1()
7879
delete[] st_j;
7980
//We do not need startnsz_per after it.
8081
delete[] this->startnsz_per;
81-
this->startnsz_per=nullptr;
82+
this->startnsz_per=nullptr;
8283
#else
8384
// Serial line
8485
// get nst_per, npw_per, fftixy2ip, and istot2ixy
@@ -96,14 +97,13 @@ void PW_Basis::distribution_method1()
9697
}
9798
#endif
9899
}
99-
100100
#ifdef __MPI
101-
MPI_Bcast(st_length2D, this->fftnxy, MPI_INT, 0, this->pool_world);
102-
MPI_Bcast(st_bottom2D, this->fftnxy, MPI_INT, 0, this->pool_world);
103-
MPI_Bcast(this->fftixy2ip, this->fftnxy, MPI_INT, 0, this->pool_world);
104-
MPI_Bcast(this->istot2ixy, this->nstot, MPI_INT, 0, this->pool_world);
105-
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0 , this->pool_world);
106-
MPI_Bcast(this->npw_per, this->poolnproc, MPI_INT, 0 , this->pool_world);
101+
MPI_Bcast(st_length2D, this->fftnxy, MPI_INT, 0, this->pool_world);
102+
MPI_Bcast(st_bottom2D, this->fftnxy, MPI_INT, 0, this->pool_world);
103+
MPI_Bcast(this->fftixy2ip, this->fftnxy, MPI_INT, 0, this->pool_world);
104+
MPI_Bcast(this->istot2ixy, this->nstot, MPI_INT, 0, this->pool_world);
105+
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0 , this->pool_world);
106+
MPI_Bcast(this->npw_per, this->poolnproc, MPI_INT, 0 , this->pool_world);
107107
#endif
108108
this->npw = this->npw_per[this->poolrank];
109109
this->nst = this->nst_per[this->poolrank];

0 commit comments

Comments
 (0)