Skip to content

Commit 6785bdd

Browse files
committed
Add batch function in FFT module
1 parent bc5fd88 commit 6785bdd

File tree

1 file changed

+37
-4
lines changed

1 file changed

+37
-4
lines changed

source/source_base/module_fft/fft_cuda.cpp

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,47 @@ void FFT_CUDA<FPTYPE>::initfft(int nx_in, int ny_in, int nz_in, int batch_size)
1616
template <>
1717
void FFT_CUDA<float>::setupFFT()
1818
{
19-
cufftPlan3d(&c_handle, this->nx, this->ny, this->nz, CUFFT_C2C);
20-
resmem_cd_op()(this->c_auxr_3d, this->nx * this->ny * this->nz);
19+
if (this->batch_size){
20+
int rank = 3; // this means the dimension is 3
21+
int n[3] = {this->nx, this->ny, this->nz};
22+
int inembed[3] = {this->nx, this->ny, this->nz};
23+
int onembed[3] = {this->nx, this->ny, this->nz};
24+
int istride = 1, ostride = 1;
25+
size_t N = static_cast<size_t>(this->nx) * this->ny * this->nz;
26+
int idist = N;
27+
int odist = N;
28+
cufftPlanMany(&c_handle, rank, n,
29+
inembed, istride, idist,
30+
onembed, ostride, odist,
31+
CUFFT_C2C, this->batch_size)
32+
}
33+
else{
34+
cufftPlan3d(&c_handle, this->nx, this->ny, this->nz, CUFFT_C2C);
35+
resmem_cd_op()(this->c_auxr_3d, this->nx * this->ny * this->nz);
36+
}
37+
2138
}
2239
template <>
2340
void FFT_CUDA<double>::setupFFT()
2441
{
25-
cufftPlan3d(&z_handle, this->nx, this->ny, this->nz, CUFFT_Z2Z);
26-
resmem_zd_op()(this->z_auxr_3d, this->nx * this->ny * this->nz);
42+
if (this->batch_size){
43+
int rank = 3; // this means the dimension is 3
44+
int n[3] = {this->nx, this->ny, this->nz};
45+
int inembed[3] = {this->nx, this->ny, this->nz};
46+
int onembed[3] = {this->nx, this->ny, this->nz};
47+
int istride = 1, ostride = 1;
48+
size_t N = static_cast<size_t>(this->nx) * this->ny * this->nz;
49+
int idist = N;
50+
int odist = N;
51+
cufftPlanMany(&z_handle, rank, n,
52+
inembed, istride, idist,
53+
onembed, ostride, odist,
54+
CUFFT_Z2Z, this->batch_size)
55+
}
56+
else{
57+
cufftPlan3d(&z_handle, this->nx, this->ny, this->nz, CUFFT_Z2Z);
58+
resmem_zd_op()(this->z_auxr_3d, this->nx * this->ny * this->nz);
59+
}
2760
}
2861
template <>
2962
void FFT_CUDA<float>::cleanFFT()

0 commit comments

Comments
 (0)