diff --git a/source/source_base/module_fft/fft_base.h b/source/source_base/module_fft/fft_base.h index 1fcbc51412..254a725e3e 100644 --- a/source/source_base/module_fft/fft_base.h +++ b/source/source_base/module_fft/fft_base.h @@ -15,6 +15,7 @@ class FFT_BASE * @brief Initialize the fft parameters as virtual function. * * The function is used to initialize the fft parameters. + * Only FFT on GPU supports batch FFT. So only the second function has the batch_size parameter. */ virtual __attribute__((weak)) void initfft(int nx_in, int ny_in, @@ -27,7 +28,7 @@ class FFT_BASE bool gamma_only_in, bool xprime_in = true); - virtual __attribute__((weak)) void initfft(int nx_in, int ny_in, int nz_in); + virtual __attribute__((weak)) void initfft(int nx_in, int ny_in, int nz_in, int batch_size = 0); /** * @brief Setup the fft plan and data as pure virtual function. diff --git a/source/source_base/module_fft/fft_bundle.cpp b/source/source_base/module_fft/fft_bundle.cpp index 67b38364b2..5ca90d936e 100644 --- a/source/source_base/module_fft/fft_bundle.cpp +++ b/source/source_base/module_fft/fft_bundle.cpp @@ -42,6 +42,7 @@ void FFT_Bundle::initfft(int nx_in, int nproc_in, bool gamma_only_in, bool xprime_in, + int batch_size, bool mpifft_in) { assert(this->device == "cpu" || this->device == "gpu" || this->device == "dsp"); diff --git a/source/source_base/module_fft/fft_bundle.h b/source/source_base/module_fft/fft_bundle.h index af82119201..a718f199a7 100644 --- a/source/source_base/module_fft/fft_bundle.h +++ b/source/source_base/module_fft/fft_bundle.h @@ -61,6 +61,7 @@ class FFT_Bundle int nproc_in, bool gamma_only_in, bool xprime_in = true, + int batch_size = 0, bool mpifft_in = false); /** diff --git a/source/source_base/module_fft/fft_cpu.cpp b/source/source_base/module_fft/fft_cpu.cpp index f50f6e9e86..d560c62dac 100644 --- a/source/source_base/module_fft/fft_cpu.cpp +++ b/source/source_base/module_fft/fft_cpu.cpp @@ -12,7 +12,7 @@ void FFT_CPU::initfft(int nx_in, int ns_in, int nplane_in, int nproc_in, - bool gamma_only_in, + bool gamma_only_in, bool xprime_in) { this->gamma_only = gamma_only_in; diff --git a/source/source_base/module_fft/fft_cpu.h b/source/source_base/module_fft/fft_cpu.h index f33fecd74b..fbc3dfaf8b 100644 --- a/source/source_base/module_fft/fft_cpu.h +++ b/source/source_base/module_fft/fft_cpu.h @@ -37,7 +37,7 @@ class FFT_CPU : public FFT_BASE int ns_in, int nplane_in, int nproc_in, - bool gamma_only_in, + bool gamma_only_in, bool xprime_in = true) override; __attribute__((weak)) diff --git a/source/source_base/module_fft/fft_cuda.cpp b/source/source_base/module_fft/fft_cuda.cpp index bd5669f822..28ee4faaca 100644 --- a/source/source_base/module_fft/fft_cuda.cpp +++ b/source/source_base/module_fft/fft_cuda.cpp @@ -6,23 +6,57 @@ namespace ModuleBase { template -void FFT_CUDA::initfft(int nx_in, int ny_in, int nz_in) +void FFT_CUDA::initfft(int nx_in, int ny_in, int nz_in, int batch_size) { this->nx = nx_in; this->ny = ny_in; this->nz = nz_in; + this->batch_size = batch_size; } template <> void FFT_CUDA::setupFFT() { - cufftPlan3d(&c_handle, this->nx, this->ny, this->nz, CUFFT_C2C); - resmem_cd_op()(this->c_auxr_3d, this->nx * this->ny * this->nz); + if (this->batch_size){ + int rank = 3; // this means the dimension is 3 + int n[3] = {this->nx, this->ny, this->nz}; + int inembed[3] = {this->nx, this->ny, this->nz}; + int onembed[3] = {this->nx, this->ny, this->nz}; + int istride = 1, ostride = 1; + size_t N = static_cast(this->nx) * this->ny * this->nz; + int idist = N; + int odist = N; + cufftPlanMany(&c_handle, rank, n, + inembed, istride, idist, + onembed, ostride, odist, + CUFFT_C2C, this->batch_size); + } + else{ + cufftPlan3d(&c_handle, this->nx, this->ny, this->nz, CUFFT_C2C); + resmem_cd_op()(this->c_auxr_3d, this->nx * this->ny * this->nz); + } + } template <> void FFT_CUDA::setupFFT() { - cufftPlan3d(&z_handle, this->nx, this->ny, this->nz, CUFFT_Z2Z); - resmem_zd_op()(this->z_auxr_3d, this->nx * this->ny * this->nz); + if (this->batch_size){ + int rank = 3; // this means the dimension is 3 + int n[3] = {this->nx, this->ny, this->nz}; + int inembed[3] = {this->nx, this->ny, this->nz}; + int onembed[3] = {this->nx, this->ny, this->nz}; + int istride = 1, ostride = 1; + size_t N = static_cast(this->nx) * this->ny * this->nz; + int idist = N; + int odist = N; + cufftPlanMany(&z_handle, rank, n, + inembed, istride, idist, + onembed, ostride, odist, + CUFFT_Z2Z, this->batch_size); + } + else{ + cufftPlan3d(&z_handle, this->nx, this->ny, this->nz, CUFFT_Z2Z); + resmem_zd_op()(this->z_auxr_3d, this->nx * this->ny * this->nz); + } } template <> void FFT_CUDA::cleanFFT() diff --git a/source/source_base/module_fft/fft_cuda.h b/source/source_base/module_fft/fft_cuda.h index 7734caffa9..f6e41d96be 100644 --- a/source/source_base/module_fft/fft_cuda.h +++ b/source/source_base/module_fft/fft_cuda.h @@ -24,11 +24,13 @@ class FFT_CUDA : public FFT_BASE * @param nx_in number of grid points in x direction * @param ny_in number of grid points in y direction * @param nz_in number of grid points in z direction + * @param batch_size number of batches. Please set to zero if batch FFT is not needed. * */ void initfft(int nx_in, int ny_in, - int nz_in) override; + int nz_in, + int batch_size) override; /** * @brief Get the real space data @@ -61,6 +63,8 @@ class FFT_CUDA : public FFT_BASE std::complex* c_auxr_3d = nullptr; // fft space std::complex* z_auxr_3d = nullptr; // fft space + int batch_size = 0; + }; } // namespace ModuleBase diff --git a/source/source_io/module_parameter/input_parameter.h b/source/source_io/module_parameter/input_parameter.h index f76b48e182..4c80d8e8bd 100644 --- a/source/source_io/module_parameter/input_parameter.h +++ b/source/source_io/module_parameter/input_parameter.h @@ -37,6 +37,7 @@ struct Input_para double ecutrho = 0; ///< energy cutoff for charge/potential int nx = 0, ny = 0, nz = 0; ///< three dimension of FFT wavefunc + int fft_batch = 0; ///< the batch size of FFT on GPU. Set to zero if don't need to use. int ndx = 0, ndy = 0, ndz = 0; ///< three dimension of FFT smooth charge density double cell_factor = 1.2; ///< LiuXh add 20180619 diff --git a/source/source_io/read_input_item_system.cpp b/source/source_io/read_input_item_system.cpp index c9854c9ece..db780703d0 100644 --- a/source/source_io/read_input_item_system.cpp +++ b/source/source_io/read_input_item_system.cpp @@ -386,6 +386,21 @@ void ReadInput::item_system() sync_int(input.nz); this->add_item(item); } + { + Input_Item item("fft_batch"); + item.annotation = "the batch size of FFT on GPU, probably makes cuFFT faster"; + item.read_value = [](const Input_Item& item, Parameter& para) { + para.input.fft_batch = intvalue; + }; + item.check_value = [](const Input_Item& item, const Parameter& para) { + if (para.input.fft_batch < 0) + { + ModuleBase::WARNING_QUIT("ReadInput", "fft_batch should be set to no less than zero"); + } + }; + sync_int(input.fft_batch); + this->add_item(item); + } { Input_Item item("ndx"); item.annotation = "number of points along x axis for FFT smooth grid";