diff --git a/source/module_basis/module_pw/CMakeLists.txt b/source/module_basis/module_pw/CMakeLists.txt index b4ece143ff..78f9824d8b 100644 --- a/source/module_basis/module_pw/CMakeLists.txt +++ b/source/module_basis/module_pw/CMakeLists.txt @@ -3,6 +3,17 @@ if (ENABLE_FLOAT_FFTW) module_fft/fft_cpu_float.cpp ) endif() +if (USE_CUDA) + list (APPEND FFT_SRC + module_fft/fft_cuda.cpp + ) +endif() +if (USE_ROCM) + list (APPEND FFT_SRC + module_fft/fft_rcom.cpp + ) +endif() + list(APPEND objects fft.cpp pw_basis.cpp diff --git a/source/module_basis/module_pw/fft.cpp b/source/module_basis/module_pw/fft.cpp index fa94bd6442..89495baad6 100644 --- a/source/module_basis/module_pw/fft.cpp +++ b/source/module_basis/module_pw/fft.cpp @@ -91,34 +91,6 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int const int nrxx = this->nxy * this->nplane; const int nsz = this->nz * this->ns; int maxgrids = (nsz > nrxx) ? nsz : nrxx; - if (!this->mpifft) - { - // z_auxg = (std::complex*)fftw_malloc(sizeof(fftw_complex) * maxgrids); - // z_auxr = (std::complex*)fftw_malloc(sizeof(fftw_complex) * maxgrids); - // ModuleBase::Memory::record("FFT::grid", 2 * sizeof(fftw_complex) * maxgrids); - // d_rspace = (double*)z_auxg; - // auxr_3d = static_cast *>( - // fftw_malloc(sizeof(fftw_complex) * (this->nx * this->ny * this->nz))); -#if defined(__CUDA) || defined(__ROCM) - if (this->device == "gpu") - { - resmem_cd_op()(gpu_ctx, this->c_auxr_3d, this->nx * this->ny * this->nz); - resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz); - } -#endif // defined(__CUDA) || defined(__ROCM) -// #if defined(__ENABLE_FLOAT_FFTW) -// if (this->precision == "single") -// { -// c_auxg = (std::complex*)fftw_malloc(sizeof(fftwf_complex) * maxgrids); -// c_auxr = (std::complex*)fftw_malloc(sizeof(fftwf_complex) * maxgrids); -// ModuleBase::Memory::record("FFT::grid_s", 2 * sizeof(fftwf_complex) * maxgrids); -// s_rspace = (float*)c_auxg; -// } -// #endif // defined(__ENABLE_FLOAT_FFTW) - } - else - { - } } void FFT::setupFFT() diff --git a/source/module_basis/module_pw/module_fft/fft_base.cpp b/source/module_basis/module_pw/module_fft/fft_base.cpp index 4c91d4d7b4..a32a9c0e99 100644 --- a/source/module_basis/module_pw/module_fft/fft_base.cpp +++ b/source/module_basis/module_pw/module_fft/fft_base.cpp @@ -1,8 +1,4 @@ #include "fft_base.h" namespace ModulePW { -template FFT_BASE::FFT_BASE(); -template FFT_BASE::FFT_BASE(); -template FFT_BASE::~FFT_BASE(); -template FFT_BASE::~FFT_BASE(); } \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_base.h b/source/module_basis/module_pw/module_fft/fft_base.h index a8f4b246aa..b811189971 100644 --- a/source/module_basis/module_pw/module_fft/fft_base.h +++ b/source/module_basis/module_pw/module_fft/fft_base.h @@ -30,6 +30,10 @@ class FFT_BASE bool gamma_only_in, bool xprime_in = true); + virtual __attribute__((weak)) + void initfft(int nx_in, + int ny_in, + int nz_in); /** * @brief Setup the fft Plan and data As pure virtual function. * @@ -159,5 +163,9 @@ class FFT_BASE int ny=0; int nz=0; }; +template FFT_BASE::FFT_BASE(); +template FFT_BASE::FFT_BASE(); +template FFT_BASE::~FFT_BASE(); +template FFT_BASE::~FFT_BASE(); } #endif // FFT_BASE_H diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index 1e82e0c595..16ba6481be 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -2,12 +2,13 @@ #include "fft_bundle.h" #include "fft_cpu.h" #include "module_base/module_device/device.h" -// #if defined(__CUDA) -// #include "fft_cuda.h" -// #endif -// #if defined(__ROCM) -// #include "fft_rcom.h" -// #endif + +#if defined(__CUDA) +#include "fft_cuda.h" +#endif +#if defined(__ROCM) +#include "fft_rcom.h" +#endif template std::unique_ptr make_unique(Args &&... args) @@ -21,7 +22,10 @@ void FFT_Bundle::setfft(std::string device_in,std::string precision_in) this->device = device_in; this->precision = precision_in; } - +FFT_Bundle::~FFT_Bundle() +{ + this->clear(); +} void FFT_Bundle::initfft(int nx_in, int ny_in, int nz_in, @@ -81,15 +85,21 @@ void FFT_Bundle::initfft(int nx_in, xprime_in); } } - if (device=="gpu") + else if (device=="gpu") { - // #if defined(__ROCM) - // fft_float = new FFT_RCOM(); - // fft_double = new FFT_RCOM(); - // #elif defined(__CUDA) - // fft_float = make_unique>(); - // fft_double = make_unique>(); - // #endif + float_flag=true; + double_flag=true; + #if defined(__ROCM) + fft_float = make_unique>; + fft_float->initfft(nx_in,ny_in,nz_in); + fft_double = make_unique>; + fft_double->initfft(nx_in,ny_in,nz_in); + #elif defined(__CUDA) + fft_float = make_unique>(); + fft_float->initfft(nx_in,ny_in,nz_in); + fft_double = make_unique>(); + fft_double->initfft(nx_in,ny_in,nz_in); + #endif } } diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.h b/source/module_basis/module_pw/module_fft/fft_bundle.h index 8321badb4b..6da2419245 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.h +++ b/source/module_basis/module_pw/module_fft/fft_bundle.h @@ -9,7 +9,7 @@ class FFT_Bundle { public: FFT_Bundle(){}; - ~FFT_Bundle(){}; + ~FFT_Bundle(); /** * @brief Constructor with device and precision. * @param device_in device type, cpu or gpu. diff --git a/source/module_basis/module_pw/module_fft/fft_cpu.cpp b/source/module_basis/module_pw/module_fft/fft_cpu.cpp index be920d4ae2..1aed3e734c 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cpu.cpp @@ -457,9 +457,4 @@ template <> std::complex* FFT_CPU::get_auxr_data() const {return z_auxr;} template <> std::complex* FFT_CPU::get_auxg_data() const {return z_auxg;} - -template FFT_CPU::FFT_CPU(); -template FFT_CPU::~FFT_CPU(); -template FFT_CPU::FFT_CPU(); -template FFT_CPU::~FFT_CPU(); } \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_cpu.h b/source/module_basis/module_pw/module_fft/fft_cpu.h index 27c7e862a2..f04266c14c 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu.h +++ b/source/module_basis/module_pw/module_fft/fft_cpu.h @@ -33,7 +33,7 @@ class FFT_CPU : public FFT_BASE * @param gamma_only_in whether only gamma point is used. * @param xprime_in whether xprime is used. */ - __attribute__((weak)) + void initfft(int nx_in, int ny_in, int nz_in, @@ -44,10 +44,10 @@ class FFT_CPU : public FFT_BASE int nproc_in, bool gamma_only_in, bool xprime_in = true) override; + __attribute__((weak)) void setupFFT() override; - // void initplan(const unsigned int& flag = 0); __attribute__((weak)) void cleanFFT() override; @@ -106,31 +106,31 @@ class FFT_CPU : public FFT_BASE void clearfft(fftw_plan& plan); void clearfft(fftwf_plan& plan); - fftw_plan planzfor = NULL; - fftw_plan planzbac = NULL; - fftw_plan planxfor1 = NULL; - fftw_plan planxbac1 = NULL; - fftw_plan planxfor2 = NULL; - fftw_plan planxbac2 = NULL; - fftw_plan planyfor = NULL; - fftw_plan planybac = NULL; - fftw_plan planxr2c = NULL; - fftw_plan planxc2r = NULL; - fftw_plan planyr2c = NULL; - fftw_plan planyc2r = NULL; - - fftwf_plan planfzfor = NULL; - fftwf_plan planfzbac = NULL; - fftwf_plan planfxfor1= NULL; - fftwf_plan planfxbac1= NULL; - fftwf_plan planfxfor2= NULL; - fftwf_plan planfxbac2= NULL; - fftwf_plan planfyfor = NULL; - fftwf_plan planfybac = NULL; - fftwf_plan planfxr2c = NULL; - fftwf_plan planfxc2r = NULL; - fftwf_plan planfyr2c = NULL; - fftwf_plan planfyc2r = NULL; + fftw_plan planzfor = nullptr; + fftw_plan planzbac = nullptr; + fftw_plan planxfor1 = nullptr; + fftw_plan planxbac1 = nullptr; + fftw_plan planxfor2 = nullptr; + fftw_plan planxbac2 = nullptr; + fftw_plan planyfor = nullptr; + fftw_plan planybac = nullptr; + fftw_plan planxr2c = nullptr; + fftw_plan planxc2r = nullptr; + fftw_plan planyr2c = nullptr; + fftw_plan planyc2r = nullptr; + + fftwf_plan planfzfor = nullptr; + fftwf_plan planfzbac = nullptr; + fftwf_plan planfxfor1= nullptr; + fftwf_plan planfxbac1= nullptr; + fftwf_plan planfxfor2= nullptr; + fftwf_plan planfxbac2= nullptr; + fftwf_plan planfyfor = nullptr; + fftwf_plan planfybac = nullptr; + fftwf_plan planfxr2c = nullptr; + fftwf_plan planfxc2r = nullptr; + fftwf_plan planfyr2c = nullptr; + fftwf_plan planfyc2r = nullptr; std::complex*c_auxg = nullptr; std::complex*c_auxr = nullptr; // fft space @@ -169,5 +169,9 @@ class FFT_CPU : public FFT_BASE */ int fft_mode = 0; }; +template FFT_CPU::FFT_CPU(); +template FFT_CPU::~FFT_CPU(); +template FFT_CPU::FFT_CPU(); +template FFT_CPU::~FFT_CPU(); } #endif // FFT_CPU_H \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp b/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp index f84b45bf09..b3e8d7d572 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp @@ -267,11 +267,11 @@ void FFT_CPU::setupFFT() } template <> -void FFT_CPU::clearfft(fftw_plan& plan) +void FFT_CPU::clearfft(fftwf_plan& plan) { if (plan) { - fftw_destroy_plan(plan); + fftwf_destroy_plan(plan); plan = nullptr; } } @@ -279,18 +279,18 @@ void FFT_CPU::clearfft(fftw_plan& plan) template <> void FFT_CPU::cleanFFT() { - clearfft(planzfor); - clearfft(planzbac); - clearfft(planxfor1); - clearfft(planxbac1); - clearfft(planxfor2); - clearfft(planxbac2); - clearfft(planyfor); - clearfft(planybac); - clearfft(planxr2c); - clearfft(planxc2r); - clearfft(planyr2c); - clearfft(planyc2r); + clearfft(planfzfor); + clearfft(planfzbac); + clearfft(planfxfor1); + clearfft(planfxbac1); + clearfft(planfxfor2); + clearfft(planfxbac2); + clearfft(planfyfor); + clearfft(planfybac); + clearfft(planfxr2c); + clearfft(planfxc2r); + clearfft(planfyr2c); + clearfft(planfyc2r); } @@ -303,7 +303,7 @@ void FFT_CPU::clear() fftw_free(c_auxg); c_auxg = nullptr; } - if (z_auxr != nullptr) + if (c_auxr != nullptr) { fftw_free(c_auxr); c_auxr = nullptr; diff --git a/source/module_basis/module_pw/module_fft/fft_cuda.cpp b/source/module_basis/module_pw/module_fft/fft_cuda.cpp new file mode 100644 index 0000000000..f9fc5df74b --- /dev/null +++ b/source/module_basis/module_pw/module_fft/fft_cuda.cpp @@ -0,0 +1,108 @@ +#include "fft_cuda.h" +#include "module_base/module_device/memory_op.h" +#include "module_hamilt_pw/hamilt_pwdft/global.h" +namespace ModulePW +{ +template +void FFT_CUDA::initfft(int nx_in, + int ny_in, + int nz_in) +{ + this->nx = nx_in; + this->ny = ny_in; + this->nz = nz_in; +} +template <> +void FFT_CUDA::setupFFT() +{ + cufftPlan3d(&c_handle, this->nx, this->ny, this->nz, CUFFT_C2C); + resmem_cd_op()(gpu_ctx, this->c_auxr_3d, this->nx * this->ny * this->nz); + +} +template <> +void FFT_CUDA::setupFFT() +{ + cufftPlan3d(&z_handle, this->nx, this->ny, this->nz, CUFFT_Z2Z); + resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz); +} +template <> +void FFT_CUDA::cleanFFT() +{ + if (c_handle) + { + cufftDestroy(c_handle); + c_handle = {}; + } +} +template <> +void FFT_CUDA::cleanFFT() +{ + if (z_handle) + { + cufftDestroy(z_handle); + z_handle = {}; + } +} +template <> +void FFT_CUDA::clear() +{ + this->cleanFFT(); + if (c_auxr_3d != nullptr) + { + delmem_cd_op()(gpu_ctx, c_auxr_3d); + c_auxr_3d = nullptr; + } +} +template <> +void FFT_CUDA::clear() +{ + this->cleanFFT(); + if (z_auxr_3d != nullptr) + { + delmem_zd_op()(gpu_ctx, z_auxr_3d); + z_auxr_3d = nullptr; + } +} + +template <> +void FFT_CUDA::fft3D_forward(std::complex* in, + std::complex* out) const +{ + CHECK_CUFFT(cufftExecC2C(this->c_handle, + reinterpret_cast(in), + reinterpret_cast(out), + CUFFT_FORWARD)); +} +template <> +void FFT_CUDA::fft3D_forward(std::complex* in, + std::complex* out) const +{ + CHECK_CUFFT(cufftExecZ2Z(this->z_handle, + reinterpret_cast(in), + reinterpret_cast(out), + CUFFT_FORWARD)); +} +template <> +void FFT_CUDA::fft3D_backward(std::complex* in, + std::complex* out) const +{ + CHECK_CUFFT(cufftExecC2C(this->c_handle, + reinterpret_cast(in), + reinterpret_cast(out), + CUFFT_INVERSE)); +} + +template <> +void FFT_CUDA::fft3D_backward(std::complex* in, + std::complex* out) const +{ + CHECK_CUFFT(cufftExecZ2Z(this->z_handle, + reinterpret_cast(in), + reinterpret_cast(out), + CUFFT_INVERSE)); +} +template <> std::complex* +FFT_CUDA::get_auxr_3d_data() const {return this->c_auxr_3d;} +template <> std::complex* +FFT_CUDA::get_auxr_3d_data() const {return this->z_auxr_3d;} +}// namespace ModulePW \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_cuda.h b/source/module_basis/module_pw/module_fft/fft_cuda.h new file mode 100644 index 0000000000..90192d24dc --- /dev/null +++ b/source/module_basis/module_pw/module_fft/fft_cuda.h @@ -0,0 +1,70 @@ +#include "fft_base.h" +#include "cufft.h" +#include "cuda_runtime.h" + +#ifndef FFT_CUDA_H +#define FFT_CUDA_H +namespace ModulePW +{ +template +class FFT_CUDA : public FFT_BASE +{ + public: + FFT_CUDA(){}; + ~FFT_CUDA(){}; + + void setupFFT() override; + + void clear() override; + + void cleanFFT() override; + + /** + * @brief Initialize the fft parameters + * @param nx_in number of grid points in x direction + * @param ny_in number of grid points in y direction + * @param nz_in number of grid points in z direction + * + */ + void initfft(int nx_in, + int ny_in, + int nz_in) override; + + /** + * @brief Get the real space data + * @return real space data + */ + std::complex* get_auxr_3d_data() const override; + + /** + * @brief Forward FFT in 3D + * @param in input data, complex FPTYPE + * @param out output data, complex FPTYPE + * + * This function performs the forward FFT in 3D. + */ + void fft3D_forward(std::complex* in, + std::complex* out) const override; + /** + * @brief Backward FFT in 3D + * @param in input data, complex FPTYPE + * @param out output data, complex FPTYPE + * + * This function performs the backward FFT in 3D. + */ + void fft3D_backward(std::complex* in, + std::complex* out) const override; + private: + cufftHandle c_handle = {}; + cufftHandle z_handle = {}; + + std::complex* c_auxr_3d = nullptr; // fft space + std::complex* z_auxr_3d = nullptr; // fft space + +}; +template FFT_CUDA::FFT_CUDA(); +template FFT_CUDA::~FFT_CUDA(); +template FFT_CUDA::FFT_CUDA(); +template FFT_CUDA::~FFT_CUDA(); +} // namespace ModulePW +#endif \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_rcom.cpp b/source/module_basis/module_pw/module_fft/fft_rcom.cpp new file mode 100644 index 0000000000..225a6aae7a --- /dev/null +++ b/source/module_basis/module_pw/module_fft/fft_rcom.cpp @@ -0,0 +1,106 @@ +#include "fft_rcom.h" +#include "module_base/module_device/memory_op.h" +#include "module_hamilt_pw/hamilt_pwdft/global.h" +namespace ModulePW +{ +template +void FFT_RCOM::initfft(int nx_in, + int ny_in, + int nz_in) +{ + this->nx = nx_in; + this->ny = ny_in; + this->nz = nz_in; +} +template <> +void FFT_RCOM::setupFFT() +{ + hipfftPlan3d(&c_handle, this->nx, this->ny, this->nz, HIPFFT_C2C); + resmem_cd_op()(gpu_ctx, this->c_auxr_3d, this->nx * this->ny * this->nz); + +} +template <> +void FFT_RCOM::setupFFT() +{ + hipfftPlan3d(&z_handle, this->nx, this->ny, this->nz, HIPFFT_Z2Z); + resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz); +} +template <> +void FFT_RCOM::cleanFFT() +{ + if (c_handle) + { + hipfftDestroy(c_handle); + c_handle = {}; + } +} +template <> +void FFT_RCOM::cleanFFT() +{ + if (z_handle) + { + hipfftDestroy(z_handle); + z_handle = {}; + } +} +template <> +void FFT_RCOM::clear() +{ + this->cleanFFT(); + if (c_auxr_3d != nullptr) + { + delmem_cd_op()(gpu_ctx, c_auxr_3d); + c_auxr_3d = nullptr; + } +} +template <> +void FFT_RCOM::clear() +{ + this->cleanFFT(); + if (z_auxr_3d != nullptr) + { + delmem_zd_op()(gpu_ctx, z_auxr_3d); + z_auxr_3d = nullptr; + } +} +template <> +void FFT_RCOM::fft3D_forward(std::complex* in, + std::complex* out) const +{ + CHECK_CUFFT(hipfftExecC2C(this->c_handle, + reinterpret_cast(in), + reinterpret_cast(out), + HIPFFT_FORWARD)); +} +template <> +void FFT_RCOM::fft3D_forward(std::complex* in, + std::complex* out) const +{ + CHECK_CUFFT(hipfftExecZ2Z(this->z_handle, + reinterpret_cast(in), + reinterpret_cast(out), + HIPFFT_FORWARD)); +} +template <> +void FFT_RCOM::fft3D_backward(std::complex* in, + std::complex* out) const +{ + CHECK_CUFFT(hipfftExecC2C(this->c_handle, + reinterpret_cast(in), + reinterpret_cast(out), + HIPFFT_BACKWARD)); +} +template <> +void FFT_RCOM::fft3D_backward(std::complex* in, + std::complex* out) const +{ + CHECK_CUFFT(hipfftExecZ2Z(this->z_handle, + reinterpret_cast(in), + reinterpret_cast(out), + HIPFFT_BACKWARD)); +} +template <> std::complex* +FFT_RCOM::get_auxr_3d_data() const {return this->c_auxr_3d;} +template <> std::complex* +FFT_RCOM::get_auxr_3d_data() const {return this->z_auxr_3d;} +}// namespace ModulePW \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_rcom.h b/source/module_basis/module_pw/module_fft/fft_rcom.h new file mode 100644 index 0000000000..64ad13329e --- /dev/null +++ b/source/module_basis/module_pw/module_fft/fft_rcom.h @@ -0,0 +1,65 @@ +#include "fft_base.h" +#include +#include +#ifndef FFT_ROCM_H +#define FFT_ROCM_H +namespace ModulePW +{ +template +class FFT_ROCM : public FFT_BASE +{ + public: + FFT_ROCM(){}; + ~FFT_ROCM(){}; + + void setupFFT() override; + + void clear() override; + + void cleanFFT() override; + + /** + * @brief Initialize the fft parameters for ROCM + * @param nx_in number of grid points in x direction + * @param ny_in number of grid points in y direction + * @param nz_in number of grid points in z direction + * + */ + void initfft(int nx_in, + int ny_in, + int nz_in) override; + + /** + * @brief Get the real space data + * @return real space data + */ + std::complex* get_auxr_3d_data() const override; + + /** + * @brief Forward FFT in 3D for ROCM + * @param in input data, complex FPTYPE + * @param out output data, complex FPTYPE + */ + void fft3D_forward(std::complex* in, + std::complex* out) const override; + + /** + * @brief Backward FFT in 3D for ROCM + * @param in input data, complex FPTYPE + * @param out output data, complex FPTYPE + */ + void fft3D_backward(std::complex* in, + std::complex* out) const override; + private: + hipfftHandle c_handle = {}; + hipfftHandle z_handle = {}; + mutable std::complex* c_auxr_3d = nullptr; // fft space + mutable std::complex* z_auxr_3d = nullptr; // fft space + +}; +template FFT_RCOM::FFT_RCOM(); +template FFT_ROCM::~FFT_ROCM(); +template FFT_RCOM::FFT_RCOM(); +template FFT_ROCM::~FFT_ROCM(); +}// namespace ModulePW +#endif diff --git a/source/module_basis/module_pw/pw_basis.cpp b/source/module_basis/module_pw/pw_basis.cpp index ac02c45763..7219b6e126 100644 --- a/source/module_basis/module_pw/pw_basis.cpp +++ b/source/module_basis/module_pw/pw_basis.cpp @@ -15,8 +15,6 @@ PW_Basis::PW_Basis() PW_Basis::PW_Basis(std::string device_, std::string precision_) : device(std::move(device_)), precision(std::move(precision_)) { classname="PW_Basis"; - this->ft.set_device(this->device); - this->ft.set_precision(this->precision); this->fft_bundle.setfft("cpu",this->precision); } @@ -57,27 +55,24 @@ void PW_Basis::setuptransform() this->distribute_r(); this->distribute_g(); this->getstartgr(); - this->ft.clear(); this->fft_bundle.clear(); if(this->xprime) { - this->ft.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); } else { - this->ft.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); } - this->ft.setupFFT(); this->fft_bundle.setupFFT(); ModuleBase::timer::tick(this->classname, "setuptransform"); } void PW_Basis::getstartgr() { - if(this->gamma_only) this->nmaxgr = ( this->npw > (this->nrxx+1)/2 ) ? this->npw : (this->nrxx+1)/2; - else this->nmaxgr = ( this->npw > this->nrxx ) ? this->npw : this->nrxx; + if(this->gamma_only) { this->nmaxgr = ( this->npw > (this->nrxx+1)/2 ) ? this->npw : (this->nrxx+1)/2; + } else { this->nmaxgr = ( this->npw > this->nrxx ) ? this->npw : this->nrxx; +} this->nmaxgr = (this->nz * this->nst > this->nxy * nplane) ? this->nz * this->nst : this->nxy * nplane; //--------------------------------------------- @@ -91,23 +86,27 @@ void PW_Basis::getstartgr() // Each processor has a set of full sticks, // 'rank_use' processor send a piece(npps[ip]) of these sticks(nst_per[rank_use]) // to all the other processors in this pool - for (int ip = 0;ip < poolnproc; ++ip) this->numg[ip] = this->nst_per[poolrank] * this->numz[ip]; + for (int ip = 0;ip < poolnproc; ++ip) { this->numg[ip] = this->nst_per[poolrank] * this->numz[ip]; +} // Each processor in a pool send a piece of each stick(nst_per[ip]) to // other processors in this pool // rank_use processor receive datas in npps[rank_p] planes. - for (int ip = 0;ip < poolnproc; ++ip) this->numr[ip] = this->nst_per[ip] * this->numz[poolrank]; + for (int ip = 0;ip < poolnproc; ++ip) { this->numr[ip] = this->nst_per[ip] * this->numz[poolrank]; +} // startg record the starting 'numg' position in each processor. this->startg[0] = 0; - for (int ip = 1;ip < poolnproc; ++ip) this->startg[ip] = this->startg[ip-1] + this->numg[ip-1]; + for (int ip = 1;ip < poolnproc; ++ip) { this->startg[ip] = this->startg[ip-1] + this->numg[ip-1]; +} // startr record the starting 'numr' position this->startr[0] = 0; - for (int ip = 1;ip < poolnproc; ++ip) this->startr[ip] = this->startr[ip-1] + this->numr[ip-1]; + for (int ip = 1;ip < poolnproc; ++ip) { this->startr[ip] = this->startr[ip-1] + this->numr[ip-1]; +} return; } @@ -118,7 +117,8 @@ void PW_Basis::getstartgr() /// void PW_Basis::collect_local_pw() { - if(this->npw <= 0) return; + if(this->npw <= 0) { return; +} this->ig_gge0 = -1; delete[] this->gg; this->gg = new double[this->npw]; delete[] this->gdirect; this->gdirect = new ModuleBase::Vector3[this->npw]; @@ -133,16 +133,20 @@ void PW_Basis::collect_local_pw() int ixy = this->is2fftixy[is]; int ix = ixy / this->fftny; int iy = ixy % this->fftny; - if (ix >= int(this->nx/2) + 1) ix -= this->nx; - if (iy >= int(this->ny/2) + 1) iy -= this->ny; - if (iz >= int(this->nz/2) + 1) iz -= this->nz; + if (ix >= int(this->nx/2) + 1) { ix -= this->nx; +} + if (iy >= int(this->ny/2) + 1) { iy -= this->ny; +} + if (iz >= int(this->nz/2) + 1) { iz -= this->nz; +} f.x = ix; f.y = iy; f.z = iz; this->gg[ig] = f * (this->GGT * f); this->gdirect[ig] = f; this->gcar[ig] = f * this->G; - if(this->gg[ig] < 1e-8) this->ig_gge0 = ig; + if(this->gg[ig] < 1e-8) { this->ig_gge0 = ig; +} } return; } @@ -154,7 +158,8 @@ void PW_Basis::collect_local_pw() /// void PW_Basis::collect_uniqgg() { - if(this->npw <= 0) return; + if(this->npw <= 0) { return; +} this->ig_gge0 = -1; delete[] this->ig2igg; this->ig2igg = new int [this->npw]; //add by A.s 202406 @@ -170,14 +175,18 @@ void PW_Basis::collect_uniqgg() int ixy = this->is2fftixy[is]; int ix = ixy / this->fftny; int iy = ixy % this->fftny; - if (ix >= int(this->nx/2) + 1) ix -= this->nx; - if (iy >= int(this->ny/2) + 1) iy -= this->ny; - if (iz >= int(this->nz/2) + 1) iz -= this->nz; + if (ix >= int(this->nx/2) + 1) { ix -= this->nx; +} + if (iy >= int(this->ny/2) + 1) { iy -= this->ny; +} + if (iz >= int(this->nz/2) + 1) { iz -= this->nz; +} f.x = ix; f.y = iy; f.z = iz; tmpgg[ig] = f * (this->GGT * f); - if(tmpgg[ig] < 1e-8) this->ig_gge0 = ig; + if(tmpgg[ig] < 1e-8) { this->ig_gge0 = ig; +} } ModuleBase::GlobalFunc::ZEROS(sortindex, this->npw); @@ -221,7 +230,8 @@ void PW_Basis::collect_uniqgg() void PW_Basis::getfftixy2is(int * fftixy2is) const { //Note: please assert when is1 >= is2, fftixy2is[is1] >= fftixy2is[is2]! - for(int ixy = 0 ; ixy < this->fftnxy ; ++ixy) fftixy2is[ixy] = -1; + for(int ixy = 0 ; ixy < this->fftnxy ; ++ixy) { fftixy2is[ixy] = -1; +} int ixy = 0; for(int is = 0; is < this->nst; ++is) { diff --git a/source/module_basis/module_pw/pw_basis.h b/source/module_basis/module_pw/pw_basis.h index 66f5ff6301..20e101ad8e 100644 --- a/source/module_basis/module_pw/pw_basis.h +++ b/source/module_basis/module_pw/pw_basis.h @@ -242,7 +242,7 @@ class PW_Basis int ng_xeq0 = 0; //only used when xprime = true, number of g whose gx = 0 int nmaxgr=0; // Gamma_only: max between npw and (nrxx+1)/2, others: max between npw and nrxx // Thus complex[nmaxgr] is able to contain either reciprocal or real data - FFT ft; + // FFT ft; FFT_Bundle fft_bundle; //The position of pointer in and out can be equal(in-place transform) or different(out-of-place transform). @@ -282,8 +282,7 @@ class PW_Basis using resmem_int_op = base_device::memory::resize_memory_op; using delmem_int_op = base_device::memory::delete_memory_op; - using syncmem_int_h2d_op - = base_device::memory::synchronize_memory_op; + using syncmem_int_h2d_op = base_device::memory::synchronize_memory_op; void set_device(std::string device_); void set_precision(std::string precision_); diff --git a/source/module_basis/module_pw/pw_basis_k.cpp b/source/module_basis/module_pw/pw_basis_k.cpp index 2361404d84..079eeaf119 100644 --- a/source/module_basis/module_pw/pw_basis_k.cpp +++ b/source/module_basis/module_pw/pw_basis_k.cpp @@ -12,7 +12,7 @@ namespace ModulePW PW_Basis_K::PW_Basis_K() { classname="PW_Basis_K"; - this->fft_bundle.setfft("cpu",this->precision); + this->fft_bundle.setfft(this->device,this->precision); } PW_Basis_K::~PW_Basis_K() { @@ -184,16 +184,13 @@ void PW_Basis_K::setuptransform() this->distribute_g(); this->getstartgr(); this->setupIndGk(); - this->ft.clear(); this->fft_bundle.clear(); + this->fft_bundle.setfft(this->device,this->precision); if(this->xprime){ - this->ft.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); }else{ - this->ft.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); } - this->ft.setupFFT(); this->fft_bundle.setupFFT(); ModuleBase::timer::tick(this->classname, "setuptransform"); } diff --git a/source/module_basis/module_pw/pw_basis_sup.cpp b/source/module_basis/module_pw/pw_basis_sup.cpp index 80c7e87f57..f7a3caa13b 100644 --- a/source/module_basis/module_pw/pw_basis_sup.cpp +++ b/source/module_basis/module_pw/pw_basis_sup.cpp @@ -19,20 +19,9 @@ void PW_Basis_Sup::setuptransform(const ModulePW::PW_Basis* pw_rho) this->distribute_r(); this->distribute_g(pw_rho); this->getstartgr(); - this->ft.clear(); this->fft_bundle.clear(); if (this->xprime) { - this->ft.initfft(this->nx, - this->ny, - this->nz, - this->lix, - this->rix, - this->nst, - this->nplane, - this->poolnproc, - this->gamma_only, - this->xprime); this->fft_bundle.initfft(this->nx, this->ny, this->nz, @@ -46,16 +35,6 @@ void PW_Basis_Sup::setuptransform(const ModulePW::PW_Basis* pw_rho) } else { - this->ft.initfft(this->nx, - this->ny, - this->nz, - this->liy, - this->riy, - this->nst, - this->nplane, - this->poolnproc, - this->gamma_only, - this->xprime); this->fft_bundle.initfft(this->nx, this->ny, this->nz, @@ -67,7 +46,6 @@ void PW_Basis_Sup::setuptransform(const ModulePW::PW_Basis* pw_rho) this->gamma_only, this->xprime); } - this->ft.setupFFT(); this->fft_bundle.setupFFT(); ModuleBase::timer::tick(this->classname, "setuptransform"); } @@ -122,8 +100,9 @@ void PW_Basis_Sup::distribution_method3(const ModulePW::PW_Basis* pw_rho) this->npw_per = new int[this->poolnproc]; // number of planewaves on each core. delete[] this->fftixy2ip; this->fftixy2ip = new int[this->fftnxy]; // ip of core which contains the stick on (x, y). - for (int ixy = 0; ixy < this->fftnxy; ++ixy) + for (int ixy = 0; ixy < this->fftnxy; ++ixy) { this->fftixy2ip[ixy] = -1; // meaning this stick has not been distributed or there is no stick on (x, y). +} if (poolrank == 0) { // (1) Count the total number of planewaves (tot_npw) and sticks (this->nstot). @@ -234,10 +213,11 @@ void PW_Basis_Sup::divide_sticks_3( int fftnx_s = nx_s; if (this->gamma_only) { - if (this->xprime) + if (this->xprime) { fftnx_s = int(nx_s / 2) + 1; - else + } else { fftny_s = int(ny_s / 2) + 1; +} } int fftnxy_s = fftnx_s * fftny_s; @@ -247,15 +227,19 @@ void PW_Basis_Sup::divide_sticks_3( { int ix = ixy / fftny_s; int iy = ixy % fftny_s; - if (ix >= int(nx_s / 2) + 1) + if (ix >= int(nx_s / 2) + 1) { ix -= nx_s; - if (iy >= int(ny_s / 2) + 1) +} + if (iy >= int(ny_s / 2) + 1) { iy -= ny_s; +} - if (ix < 0) + if (ix < 0) { ix += nx; - if (iy < 0) +} + if (iy < 0) { iy += ny; +} int index = ix * this->fftny + iy; int ip = fftixy2ip_s[ixy]; if (ip >= 0) @@ -371,8 +355,9 @@ void PW_Basis_Sup::get_ig2isz_is2fftixy( fftixy2is[ixy] = st_move; st_move++; } - if (st_move == this->nst) + if (st_move == this->nst) { break; +} } // distribute planewaves in the same order as smooth grids first. @@ -385,19 +370,25 @@ void PW_Basis_Sup::get_ig2isz_is2fftixy( int ixy = pw_rho->is2fftixy[is]; int ix = ixy / pw_rho->fftny; int iy = ixy % pw_rho->fftny; - if (ix >= int(pw_rho->nx / 2) + 1) + if (ix >= int(pw_rho->nx / 2) + 1) { ix -= pw_rho->nx; - if (iy >= int(pw_rho->ny / 2) + 1) +} + if (iy >= int(pw_rho->ny / 2) + 1) { iy -= pw_rho->ny; - if (iz >= int(pw_rho->nz / 2) + 1) +} + if (iz >= int(pw_rho->nz / 2) + 1) { iz -= pw_rho->nz; +} - if (ix < 0) + if (ix < 0) { ix += this->nx; - if (iy < 0) +} + if (iy < 0) { iy += this->ny; - if (iz < 0) +} + if (iz < 0) { iz += this->nz; +} int ixy_now = ix * this->fftny + iy; int index = ixy_now * this->nz + iz; int is_now = fftixy2is[ixy_now]; @@ -405,8 +396,9 @@ void PW_Basis_Sup::get_ig2isz_is2fftixy( this->ig2isz[ig] = isz_now; pw_filled++; found[index] = true; - if (xprime && ix == 0) + if (xprime && ix == 0) { ng_xeq0++; +} } assert(pw_filled == pw_rho->npw); @@ -419,21 +411,24 @@ void PW_Basis_Sup::get_ig2isz_is2fftixy( for (int iz = zstart; iz < zstart + st_length2D[ixy]; ++iz) { int z = iz; - if (z < 0) + if (z < 0) { z += this->nz; +} if (!found[ixy * this->nz + z]) { found[ixy * this->nz + z] = true; int is = fftixy2is[ixy]; this->ig2isz[pw_filled] = is * this->nz + z; pw_filled++; - if (xprime && ixy / fftny == 0) + if (xprime && ixy / fftny == 0) { ng_xeq0++; +} } } } - if (pw_filled == this->npw) + if (pw_filled == this->npw) { break; +} } delete[] fftixy2is; diff --git a/source/module_basis/module_pw/pw_transform_k.cpp b/source/module_basis/module_pw/pw_transform_k.cpp index 88285df119..5e3780eef4 100644 --- a/source/module_basis/module_pw/pw_transform_k.cpp +++ b/source/module_basis/module_pw/pw_transform_k.cpp @@ -347,11 +347,11 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx, base_device::memory::synchronize_memory_op, base_device::DEVICE_GPU, base_device::DEVICE_GPU>()( ctx, ctx, - this->ft.get_auxr_3d_data(), + this->fft_bundle.get_auxr_3d_data(), in, this->nrxx); - this->ft.fft3D_forward(ctx, this->ft.get_auxr_3d_data(), this->ft.get_auxr_3d_data()); + this->fft_bundle.fft3D_forward(ctx, this->fft_bundle.get_auxr_3d_data(), this->fft_bundle.get_auxr_3d_data()); const int startig = ik * this->npwk_max; const int npw_k = this->npwk[ik]; @@ -361,7 +361,7 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx, add, factor, this->ig2ixyz_k + startig, - this->ft.get_auxr_3d_data(), + this->fft_bundle.get_auxr_3d_data(), out); ModuleBase::timer::tick(this->classname, "real_to_recip gpu"); } @@ -381,11 +381,11 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx, base_device::DEVICE_GPU, base_device::DEVICE_GPU>()(ctx, ctx, - this->ft.get_auxr_3d_data(), + this->fft_bundle.get_auxr_3d_data(), in, this->nrxx); - this->ft.fft3D_forward(ctx, this->ft.get_auxr_3d_data(), this->ft.get_auxr_3d_data()); + this->fft_bundle.fft3D_forward(ctx, this->fft_bundle.get_auxr_3d_data(), this->fft_bundle.get_auxr_3d_data()); const int startig = ik * this->npwk_max; const int npw_k = this->npwk[ik]; @@ -395,7 +395,7 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx, add, factor, this->ig2ixyz_k + startig, - this->ft.get_auxr_3d_data(), + this->fft_bundle.get_auxr_3d_data(), out); ModuleBase::timer::tick(this->classname, "real_to_recip gpu"); } @@ -411,10 +411,10 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx, ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); assert(this->gamma_only == false); assert(this->poolnproc == 1); - // ModuleBase::GlobalFunc::ZEROS(ft.get_auxr_3d_data(), this->nxyz); + // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data(), this->nxyz); base_device::memory::set_memory_op, base_device::DEVICE_GPU>()( ctx, - this->ft.get_auxr_3d_data(), + this->fft_bundle.get_auxr_3d_data(), 0, this->nxyz); @@ -425,14 +425,14 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx, npw_k, this->ig2ixyz_k + startig, in, - this->ft.get_auxr_3d_data()); - this->ft.fft3D_backward(ctx, this->ft.get_auxr_3d_data(), this->ft.get_auxr_3d_data()); + this->fft_bundle.get_auxr_3d_data()); + this->fft_bundle.fft3D_backward(ctx, this->fft_bundle.get_auxr_3d_data(), this->fft_bundle.get_auxr_3d_data()); set_recip_to_real_output_op()(ctx, this->nrxx, add, factor, - this->ft.get_auxr_3d_data(), + this->fft_bundle.get_auxr_3d_data(), out); ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); @@ -448,10 +448,10 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx, ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); assert(this->gamma_only == false); assert(this->poolnproc == 1); - // ModuleBase::GlobalFunc::ZEROS(ft.get_auxr_3d_data(), this->nxyz); + // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data(), this->nxyz); base_device::memory::set_memory_op, base_device::DEVICE_GPU>()( ctx, - this->ft.get_auxr_3d_data(), + this->fft_bundle.get_auxr_3d_data(), 0, this->nxyz); @@ -462,14 +462,14 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx, npw_k, this->ig2ixyz_k + startig, in, - this->ft.get_auxr_3d_data()); - this->ft.fft3D_backward(ctx, this->ft.get_auxr_3d_data(), this->ft.get_auxr_3d_data()); + this->fft_bundle.get_auxr_3d_data()); + this->fft_bundle.fft3D_backward(ctx, this->fft_bundle.get_auxr_3d_data(), this->fft_bundle.get_auxr_3d_data()); set_recip_to_real_output_op()(ctx, this->nrxx, add, factor, - this->ft.get_auxr_3d_data(), + this->fft_bundle.get_auxr_3d_data(), out); ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); diff --git a/source/module_basis/module_pw/test/Makefile b/source/module_basis/module_pw/test/Makefile index 884f0f74c0..6f13e0e0ec 100644 --- a/source/module_basis/module_pw/test/Makefile +++ b/source/module_basis/module_pw/test/Makefile @@ -27,7 +27,7 @@ GTEST_DIR = /home/qianrui/gnucompile/g_gtest HONG = -D__NORMAL INCLUDES = -I. -I../../../ -I../../../module_base/module_container LIBS = -OPTS = -Ofast -march=native -std=c++11 -m64 ${INCLUDES} +OPTS = -Ofast -march=native -std=c++11 -m64 ${INCLUDES} -w -g OBJ_DIR = obj LIBNAME = libpw.a GEN = OFF @@ -106,6 +106,7 @@ VPATH=../../../module_base\ ../../../module_base/module_container/ATen/core\ ../../../module_base/module_container/ATen\ ../../../module_parameter\ +../module_fft\ ../\ MATH_OBJS0=matrix.o\ diff --git a/source/module_basis/module_pw/test_serial/pw_basis_k_test.cpp b/source/module_basis/module_pw/test_serial/pw_basis_k_test.cpp index 86f3db923b..e5fac0ef4c 100644 --- a/source/module_basis/module_pw/test_serial/pw_basis_k_test.cpp +++ b/source/module_basis/module_pw/test_serial/pw_basis_k_test.cpp @@ -46,11 +46,11 @@ TEST_F(PWBasisKTEST,Constructor) EXPECT_EQ(basis_k1.classname,"PW_Basis_K"); EXPECT_EQ(basis_k2.classname,"PW_Basis_K"); EXPECT_EQ(basis_k2.device,"cpu"); - EXPECT_EQ(basis_k2.ft.device,"cpu"); + EXPECT_EQ(basis_k2.fft_bundle.device,"cpu"); EXPECT_EQ(basis_k2.precision,"double"); - EXPECT_EQ(basis_k2.ft.precision,"double"); + EXPECT_EQ(basis_k2.fft_bundle.precision,"double"); ModulePW::PW_Basis_K basis_k3(device_flag, precision_single); - EXPECT_EQ(basis_k3.ft.precision,"single"); + EXPECT_EQ(basis_k3.fft_bundle.precision,"single"); } TEST_F(PWBasisKTEST,Initgrids1) diff --git a/source/module_basis/module_pw/test_serial/pw_basis_test.cpp b/source/module_basis/module_pw/test_serial/pw_basis_test.cpp index 4c118fa769..89a84c43b3 100644 --- a/source/module_basis/module_pw/test_serial/pw_basis_test.cpp +++ b/source/module_basis/module_pw/test_serial/pw_basis_test.cpp @@ -58,8 +58,8 @@ TEST_F(PWBasisTEST,Constructor) EXPECT_EQ(pwb2.classname,"PW_Basis"); EXPECT_EQ(pwb2.device,"cpu"); EXPECT_EQ(pwb2.precision,"double"); - EXPECT_EQ(pwb2.ft.device,"cpu"); - EXPECT_EQ(pwb2.ft.precision,"double"); + EXPECT_EQ(pwb2.fft_bundle.device,"cpu"); + EXPECT_EQ(pwb2.fft_bundle.precision,"double"); } TEST_F(PWBasisTEST,Initgrids1) diff --git a/source/module_elecstate/test/charge_extra_test.cpp b/source/module_elecstate/test/charge_extra_test.cpp index fadacdb327..c1022be8f8 100644 --- a/source/module_elecstate/test/charge_extra_test.cpp +++ b/source/module_elecstate/test/charge_extra_test.cpp @@ -70,6 +70,7 @@ FFT::FFT() FFT::~FFT() { } +FFT_Bundle::~FFT_Bundle(){} void PW_Basis::initgrids(const double lat0_in, const ModuleBase::Matrix3 latvec_in, const double gridecut) { } diff --git a/source/module_elecstate/test/elecstate_base_test.cpp b/source/module_elecstate/test/elecstate_base_test.cpp index ea69f172df..39ba4851fd 100644 --- a/source/module_elecstate/test/elecstate_base_test.cpp +++ b/source/module_elecstate/test/elecstate_base_test.cpp @@ -56,7 +56,7 @@ ModulePW::FFT::FFT() ModulePW::FFT::~FFT() { } - +ModulePW::FFT_Bundle::~FFT_Bundle(){} void ModulePW::PW_Basis::initgrids(double, ModuleBase::Matrix3, double) { } diff --git a/source/module_esolver/esolver_fp.cpp b/source/module_esolver/esolver_fp.cpp index 4a19c1d917..02c87facca 100644 --- a/source/module_esolver/esolver_fp.cpp +++ b/source/module_esolver/esolver_fp.cpp @@ -83,7 +83,6 @@ void ESolver_FP::before_all_runners(const Input_para& inp, UnitCell& cell) } this->pw_rho->initparameters(false, 4.0 * inp.ecutwfc); - this->pw_rho->ft.fft_mode = inp.fft_mode; this->pw_rho->fft_bundle.initfftmode(inp.fft_mode); this->pw_rho->setuptransform(); this->pw_rho->collect_local_pw(); @@ -109,7 +108,6 @@ void ESolver_FP::before_all_runners(const Input_para& inp, UnitCell& cell) this->pw_rhod->initgrids(inp.ref_cell_factor * cell.lat0, cell.latvec, inp.ndx, inp.ndy, inp.ndz); } this->pw_rhod->initparameters(false, inp.ecutrho); - this->pw_rhod->ft.fft_mode = inp.fft_mode; this->pw_rhod->fft_bundle.initfftmode(inp.fft_mode); pw_rhod_sup->setuptransform(this->pw_rho); this->pw_rhod->collect_local_pw(); diff --git a/source/module_esolver/esolver_ks.cpp b/source/module_esolver/esolver_ks.cpp index d4560d2a7c..7d2bb0b3b4 100644 --- a/source/module_esolver/esolver_ks.cpp +++ b/source/module_esolver/esolver_ks.cpp @@ -246,7 +246,6 @@ void ESolver_KS::before_all_runners(const Input_para& inp, UnitCell& // results #endif - this->pw_wfc->ft.fft_mode = inp.fft_mode; this->pw_wfc->fft_bundle.initfftmode(inp.fft_mode); this->pw_wfc->setuptransform(); diff --git a/source/module_hamilt_general/module_xc/test/CMakeLists.txt b/source/module_hamilt_general/module_xc/test/CMakeLists.txt index 66cf5f9cb0..b93e3a6ddb 100644 --- a/source/module_hamilt_general/module_xc/test/CMakeLists.txt +++ b/source/module_hamilt_general/module_xc/test/CMakeLists.txt @@ -21,7 +21,12 @@ AddTest( ../xc_functional_libxc_wrapper_gcxc.cpp ../xc_functional_libxc_wrapper_xc.cpp ../xc_functional_libxc.cpp ) - +if (USE_CUDA) +list(APPEND FFT_SRC ../../../module_basis/module_pw/module_fft/fft_cuda.cpp) +endif() +if (USE_ROCM) +list(APPEND FFT_SRC ../../../module_basis/module_pw/module_fft/fft_rocm.cpp) +endif() AddTest( TARGET XCTest_GRADCORR LIBS parameter MPI::MPI_CXX Libxc::xc ${math_libs} psi device container @@ -41,6 +46,7 @@ AddTest( ../../../module_basis/module_pw/module_fft/fft_base.cpp ../../../module_basis/module_pw/module_fft/fft_bundle.cpp ../../../module_basis/module_pw/module_fft/fft_cpu.cpp + ${FFT_SRC} ) AddTest( @@ -79,4 +85,5 @@ AddTest( ../../../module_basis/module_pw/module_fft/fft_base.cpp ../../../module_basis/module_pw/module_fft/fft_bundle.cpp ../../../module_basis/module_pw/module_fft/fft_cpu.cpp + ${FFT_SRC} ) \ No newline at end of file diff --git a/source/module_hsolver/test/hsolver_pw_sup.h b/source/module_hsolver/test/hsolver_pw_sup.h index c70025a2c2..6b9d872e34 100644 --- a/source/module_hsolver/test/hsolver_pw_sup.h +++ b/source/module_hsolver/test/hsolver_pw_sup.h @@ -4,6 +4,7 @@ namespace ModulePW { PW_Basis::PW_Basis(){}; PW_Basis::~PW_Basis(){}; +FFT_Bundle::~FFT_Bundle(){}; void PW_Basis::initgrids( const double lat0_in, // unit length (unit in bohr) const ModuleBase::Matrix3