@@ -16,14 +16,47 @@ void FFT_CUDA<FPTYPE>::initfft(int nx_in, int ny_in, int nz_in, int batch_size)
1616template <>
1717void FFT_CUDA<float >::setupFFT()
1818{
19- cufftPlan3d (&c_handle, this ->nx , this ->ny , this ->nz , CUFFT_C2C);
20- resmem_cd_op ()(this ->c_auxr_3d , this ->nx * this ->ny * this ->nz );
19+ if (this ->batch_size ){
20+ int rank = 3 ; // this means the dimension is 3
21+ int n[3 ] = {this ->nx , this ->ny , this ->nz };
22+ int inembed[3 ] = {this ->nx , this ->ny , this ->nz };
23+ int onembed[3 ] = {this ->nx , this ->ny , this ->nz };
24+ int istride = 1 , ostride = 1 ;
25+ size_t N = static_cast <size_t >(this ->nx ) * this ->ny * this ->nz ;
26+ int idist = N;
27+ int odist = N;
28+ cufftPlanMany (&c_handle, rank, n,
29+ inembed, istride, idist,
30+ onembed, ostride, odist,
31+ CUFFT_C2C, this ->batch_size )
32+ }
33+ else {
34+ cufftPlan3d (&c_handle, this ->nx , this ->ny , this ->nz , CUFFT_C2C);
35+ resmem_cd_op ()(this ->c_auxr_3d , this ->nx * this ->ny * this ->nz );
36+ }
37+
2138}
2239template <>
2340void FFT_CUDA<double >::setupFFT()
2441{
25- cufftPlan3d (&z_handle, this ->nx , this ->ny , this ->nz , CUFFT_Z2Z);
26- resmem_zd_op ()(this ->z_auxr_3d , this ->nx * this ->ny * this ->nz );
42+ if (this ->batch_size ){
43+ int rank = 3 ; // this means the dimension is 3
44+ int n[3 ] = {this ->nx , this ->ny , this ->nz };
45+ int inembed[3 ] = {this ->nx , this ->ny , this ->nz };
46+ int onembed[3 ] = {this ->nx , this ->ny , this ->nz };
47+ int istride = 1 , ostride = 1 ;
48+ size_t N = static_cast <size_t >(this ->nx ) * this ->ny * this ->nz ;
49+ int idist = N;
50+ int odist = N;
51+ cufftPlanMany (&z_handle, rank, n,
52+ inembed, istride, idist,
53+ onembed, ostride, odist,
54+ CUFFT_Z2Z, this ->batch_size )
55+ }
56+ else {
57+ cufftPlan3d (&z_handle, this ->nx , this ->ny , this ->nz , CUFFT_Z2Z);
58+ resmem_zd_op ()(this ->z_auxr_3d , this ->nx * this ->ny * this ->nz );
59+ }
2760}
2861template <>
2962void FFT_CUDA<float >::cleanFFT()
0 commit comments