diff --git a/source/Makefile.Objects b/source/Makefile.Objects index dbd695e696..6534d7f268 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -27,6 +27,7 @@ VPATH=./src_global:\ ./module_base/module_mixing:\ ./module_md:\ ./module_basis/module_pw:\ +./module_basis/module_pw/module_fft:\ ./module_esolver:\ ./module_hsolver:\ ./module_hsolver/kernels:\ @@ -168,7 +169,6 @@ OBJS_BASE=abfs-vector3_order.o\ memory_op.o\ device.o\ - OBJS_CELL=atom_pseudo.o\ atom_spec.o\ pseudo.o\ @@ -414,6 +414,9 @@ OBJS_PSI_INITIALIZER=psi_initializer.o\ psi_initializer_nao_random.o\ OBJS_PW=fft.o\ + fft_bundle.o\ + fft_base.o\ + fft_cpu.o\ pw_basis.o\ pw_basis_k.o\ pw_basis_sup.o\ diff --git a/source/module_base/CMakeLists.txt b/source/module_base/CMakeLists.txt index 14deb8213a..712cb902ba 100644 --- a/source/module_base/CMakeLists.txt +++ b/source/module_base/CMakeLists.txt @@ -6,7 +6,6 @@ list (APPEND LIBM_SRC libm/sincos.cpp ) endif() - add_library( base OBJECT diff --git a/source/module_basis/module_pw/CMakeLists.txt b/source/module_basis/module_pw/CMakeLists.txt index 2b2d897206..b4ece143ff 100644 --- a/source/module_basis/module_pw/CMakeLists.txt +++ b/source/module_basis/module_pw/CMakeLists.txt @@ -1,3 +1,8 @@ +if (ENABLE_FLOAT_FFTW) + list (APPEND FFT_SRC + module_fft/fft_cpu_float.cpp + ) +endif() list(APPEND objects fft.cpp pw_basis.cpp @@ -10,6 +15,10 @@ list(APPEND objects pw_init.cpp pw_transform.cpp pw_transform_k.cpp + module_fft/fft_base.cpp + module_fft/fft_bundle.cpp + module_fft/fft_cpu.cpp + ${FFT_SRC} ) add_library( diff --git a/source/module_basis/module_pw/fft.cpp b/source/module_basis/module_pw/fft.cpp index 1c56f9b5af..fa94bd6442 100644 --- a/source/module_basis/module_pw/fft.cpp +++ b/source/module_basis/module_pw/fft.cpp @@ -72,10 +72,11 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int this->fftny = this->ny = ny_in; if (this->gamma_only) { - if (xprime) + if (xprime) { this->fftnx = int(nx / 2) + 1; - else + } else { this->fftny = int(ny / 2) + 1; +} } this->nz = nz_in; this->ns = ns_in; @@ -92,10 +93,10 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int int maxgrids = (nsz > nrxx) ? nsz : nrxx; if (!this->mpifft) { - z_auxg = (std::complex*)fftw_malloc(sizeof(fftw_complex) * maxgrids); - z_auxr = (std::complex*)fftw_malloc(sizeof(fftw_complex) * maxgrids); - ModuleBase::Memory::record("FFT::grid", 2 * sizeof(fftw_complex) * maxgrids); - d_rspace = (double*)z_auxg; + // z_auxg = (std::complex*)fftw_malloc(sizeof(fftw_complex) * maxgrids); + // z_auxr = (std::complex*)fftw_malloc(sizeof(fftw_complex) * maxgrids); + // ModuleBase::Memory::record("FFT::grid", 2 * sizeof(fftw_complex) * maxgrids); + // d_rspace = (double*)z_auxg; // auxr_3d = static_cast *>( // fftw_malloc(sizeof(fftw_complex) * (this->nx * this->ny * this->nz))); #if defined(__CUDA) || defined(__ROCM) @@ -105,15 +106,15 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz); } #endif // defined(__CUDA) || defined(__ROCM) -#if defined(__ENABLE_FLOAT_FFTW) - if (this->precision == "single") - { - c_auxg = (std::complex*)fftw_malloc(sizeof(fftwf_complex) * maxgrids); - c_auxr = (std::complex*)fftw_malloc(sizeof(fftwf_complex) * maxgrids); - ModuleBase::Memory::record("FFT::grid_s", 2 * sizeof(fftwf_complex) * maxgrids); - s_rspace = (float*)c_auxg; - } -#endif // defined(__ENABLE_FLOAT_FFTW) +// #if defined(__ENABLE_FLOAT_FFTW) +// if (this->precision == "single") +// { +// c_auxg = (std::complex*)fftw_malloc(sizeof(fftwf_complex) * maxgrids); +// c_auxr = (std::complex*)fftw_malloc(sizeof(fftwf_complex) * maxgrids); +// ModuleBase::Memory::record("FFT::grid_s", 2 * sizeof(fftwf_complex) * maxgrids); +// s_rspace = (float*)c_auxg; +// } +// #endif // defined(__ENABLE_FLOAT_FFTW) } else { @@ -353,62 +354,62 @@ void FFT::cleanFFT() if (planzfor) { fftw_destroy_plan(planzfor); - planzfor = NULL; + planzfor = nullptr; } if (planzbac) { fftw_destroy_plan(planzbac); - planzbac = NULL; + planzbac = nullptr; } if (planxfor1) { fftw_destroy_plan(planxfor1); - planxfor1 = NULL; + planxfor1 = nullptr; } if (planxbac1) { fftw_destroy_plan(planxbac1); - planxbac1 = NULL; + planxbac1 = nullptr; } if (planxfor2) { fftw_destroy_plan(planxfor2); - planxfor2 = NULL; + planxfor2 = nullptr; } if (planxbac2) { fftw_destroy_plan(planxbac2); - planxbac2 = NULL; + planxbac2 = nullptr; } if (planyfor) { fftw_destroy_plan(planyfor); - planyfor = NULL; + planyfor = nullptr; } if (planybac) { fftw_destroy_plan(planybac); - planybac = NULL; + planybac = nullptr; } if (planxr2c) { fftw_destroy_plan(planxr2c); - planxr2c = NULL; + planxr2c = nullptr; } if (planxc2r) { fftw_destroy_plan(planxc2r); - planxc2r = NULL; + planxc2r = nullptr; } if (planyr2c) { fftw_destroy_plan(planyr2c); - planyr2c = NULL; + planyr2c = nullptr; } if (planyc2r) { fftw_destroy_plan(planyc2r); - planyc2r = NULL; + planyc2r = nullptr; } // fftw_destroy_plan(this->plan3dforward); diff --git a/source/module_basis/module_pw/module_fft/fft_base.cpp b/source/module_basis/module_pw/module_fft/fft_base.cpp new file mode 100644 index 0000000000..4c91d4d7b4 --- /dev/null +++ b/source/module_basis/module_pw/module_fft/fft_base.cpp @@ -0,0 +1,8 @@ +#include "fft_base.h" +namespace ModulePW +{ +template FFT_BASE::FFT_BASE(); +template FFT_BASE::FFT_BASE(); +template FFT_BASE::~FFT_BASE(); +template FFT_BASE::~FFT_BASE(); +} \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_base.h b/source/module_basis/module_pw/module_fft/fft_base.h new file mode 100644 index 0000000000..a8f4b246aa --- /dev/null +++ b/source/module_basis/module_pw/module_fft/fft_base.h @@ -0,0 +1,163 @@ +#include +#include +#include "fftw3.h" +#ifndef FFT_BASE_H +#define FFT_BASE_H +namespace ModulePW +{ +template +class FFT_BASE +{ +public: + + FFT_BASE(){}; + virtual ~FFT_BASE(){}; + + /** + * @brief Initialize the fft parameters As virtual function. + * + * The function is used to initialize the fft parameters. + */ + virtual __attribute__((weak)) + void initfft(int nx_in, + int ny_in, + int nz_in, + int lixy_in, + int rixy_in, + int ns_in, + int nplane_in, + int nproc_in, + bool gamma_only_in, + bool xprime_in = true); + + /** + * @brief Setup the fft Plan and data As pure virtual function. + * + * The function is set as pure virtual function.In order to + * override the function in the derived class.In the derived + * class, the function is used to setup the fft Plan and data. + */ + virtual void setupFFT()=0; + + /** + * @brief Clean the fft Plan As pure virtual function. + * + * The function is set as pure virtual function.In order to + * override the function in the derived class.In the derived + * class, the function is used to clean the fft Plan. + */ + virtual void cleanFFT()=0; + + /** + * @brief Clear the fft data As pure virtual function. + * + * The function is set as pure virtual function.In order to + * override the function in the derived class.In the derived + * class, the function is used to clear the fft data. + */ + virtual void clear()=0; + + /** + * @brief Get the real space data in cpu-like fft + * + * The function is used to get the real space data.While the + * FFT_BASE is an abstract class,the function will be override, + * The attribute weak is used to avoid define the function. + */ + virtual __attribute__((weak)) + FPTYPE* get_rspace_data() const; + + virtual __attribute__((weak)) + std::complex* get_auxr_data() const; + + virtual __attribute__((weak)) + std::complex* get_auxg_data() const; + + /** + * @brief Get the auxiliary real space data in 3D + * + * The function is used to get the auxiliary real space data in 3D. + * While the FFT_BASE is an abstract class,the function will be override, + * The attribute weak is used to avoid define the function. + */ + virtual __attribute__((weak)) + std::complex* get_auxr_3d_data() const; + + //forward fft in x-y direction + + /** + * @brief Forward FFT in x-y direction + * @param in input data + * @param out output data + * + * This function performs the forward FFT in the x-y direction. + * It involves two axes, x and y. The FFT is applied multiple times + * along the left and right boundaries in the primary direction(which is + * determined by the xprime flag).Notably, the Y axis operates in + * "many-many-FFT" mode. + */ + virtual __attribute__((weak)) + void fftxyfor(std::complex* in, + std::complex* out) const; + + virtual __attribute__((weak)) + void fftxybac(std::complex* in, + std::complex* out) const; + + /** + * @brief Forward FFT in z direction + * @param in input data + * @param out output data + * + * This function performs the forward FFT in the z direction. + * It involves only one axis, z. The FFT is applied only once. + * Notably, the Z axis operates in many FFT with nz*ns. + */ + virtual __attribute__((weak)) + void fftzfor(std::complex* in, + std::complex* out) const; + + virtual __attribute__((weak)) + void fftzbac(std::complex* in, + std::complex* out) const; + + /** + * @brief Forward FFT in x-y direction with real to complex + * @param in input data, real type + * @param out output data, complex type + * + * This function performs the forward FFT in the x-y direction + * with real to complex.There is no difference between fftxyfor. + */ + virtual __attribute__((weak)) + void fftxyr2c(FPTYPE* in, + std::complex* out) const; + + virtual __attribute__((weak)) + void fftxyc2r(std::complex* in, + FPTYPE* out) const; + + /** + * @brief Forward FFT in 3D + * @param in input data + * @param out output data + * + * This function performs the forward FFT for gpu-like fft. + * It involves three axes, x, y, and z. The FFT is applied multiple times + * for fft3D_forward. + */ + virtual __attribute__((weak)) + void fft3D_forward(std::complex* in, + std::complex* out) const; + + virtual __attribute__((weak)) + void fft3D_backward(std::complex* in, + std::complex* out) const; + +protected: + int nx=0; + int ny=0; + int nz=0; +}; +} +#endif // FFT_BASE_H diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp new file mode 100644 index 0000000000..1e82e0c595 --- /dev/null +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -0,0 +1,212 @@ +#include +#include "fft_bundle.h" +#include "fft_cpu.h" +#include "module_base/module_device/device.h" +// #if defined(__CUDA) +// #include "fft_cuda.h" +// #endif +// #if defined(__ROCM) +// #include "fft_rcom.h" +// #endif + +template +std::unique_ptr make_unique(Args &&... args) +{ + return std::unique_ptr(new FFT_BASE(std::forward(args)...)); +} +namespace ModulePW +{ +void FFT_Bundle::setfft(std::string device_in,std::string precision_in) +{ + this->device = device_in; + this->precision = precision_in; +} + +void FFT_Bundle::initfft(int nx_in, + int ny_in, + int nz_in, + int lixy_in, + int rixy_in, + int ns_in, + int nplane_in, + int nproc_in, + bool gamma_only_in, + bool xprime_in , + bool mpifft_in) +{ + assert(this->device=="cpu" || this->device=="gpu"); + assert(this->precision=="single" || this->precision=="double" || this->precision=="mixing"); + + if (this->precision=="single") + { + #ifndef __ENABLE_FLOAT_FFTW + float_define = false; + #endif + float_flag = float_define; + double_flag = true; + } + if (this->precision=="double") + { + double_flag = true; + } + + if (device=="cpu") + { + fft_float = make_unique>(this->fft_mode); + fft_double = make_unique>(this->fft_mode); + if (float_flag) + { + fft_float->initfft(nx_in, + ny_in, + nz_in, + lixy_in, + rixy_in, + ns_in, + nplane_in, + nproc_in, + gamma_only_in, + xprime_in); + } + if (double_flag) + { + fft_double->initfft(nx_in, + ny_in, + nz_in, + lixy_in, + rixy_in, + ns_in, + nplane_in, + nproc_in, + gamma_only_in, + xprime_in); + } + } + if (device=="gpu") + { + // #if defined(__ROCM) + // fft_float = new FFT_RCOM(); + // fft_double = new FFT_RCOM(); + // #elif defined(__CUDA) + // fft_float = make_unique>(); + // fft_double = make_unique>(); + // #endif + } + +} + +void FFT_Bundle::setupFFT() +{ + if (double_flag){fft_double->setupFFT();} + if (float_flag) {fft_float->setupFFT();} +} + +void FFT_Bundle::clearFFT() +{ + if (double_flag){fft_double->cleanFFT();} + if (float_flag) {fft_float->cleanFFT();} +} +void FFT_Bundle::clear() +{ + this->clearFFT(); + if (double_flag){fft_double->clear();} + if (float_flag) {fft_float->clear();} +} + +template <> void +FFT_Bundle::fftxyfor(std::complex* in, + std::complex* out) +const {fft_float->fftxyfor(in,out);} +template <> void +FFT_Bundle::fftxyfor(std::complex* in, + std::complex* out) +const {fft_double->fftxyfor(in,out);} + + +template <> void +FFT_Bundle::fftzfor(std::complex* in, + std::complex* out) +const {fft_float->fftzfor(in,out);} +template <> void +FFT_Bundle::fftzfor(std::complex* in, + std::complex* out) +const {fft_double->fftzfor(in,out);} + +template <> void +FFT_Bundle::fftxybac(std::complex* in, + std::complex* out) +const {fft_float->fftxybac(in,out);} +template <> void +FFT_Bundle::fftxybac(std::complex* in, + std::complex* out) +const {fft_double->fftxybac(in,out);} + +template <> void +FFT_Bundle::fftzbac(std::complex* in, + std::complex* out) +const {fft_float->fftzbac(in,out);} +template <> void +FFT_Bundle::fftzbac(std::complex* in, + std::complex* out) +const {fft_double->fftzbac(in,out);} + +template <> void +FFT_Bundle::fftxyr2c(float* in, + std::complex* out) +const {fft_float->fftxyr2c(in,out);} +template <> void +FFT_Bundle::fftxyr2c(double* in, + std::complex* out) +const {fft_double->fftxyr2c(in,out);} + +template <> void +FFT_Bundle::fftxyc2r(std::complex* in, + float* out) +const {fft_float->fftxyc2r(in,out);} +template <> void +FFT_Bundle::fftxyc2r(std::complex* in, + double* out) +const {fft_double->fftxyc2r(in,out);} + +template <> void +FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, + std::complex* in, + std::complex* out) +const {fft_float->fft3D_forward(in, out);} +template <> void +FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, + std::complex* in, + std::complex* out) +const {fft_double->fft3D_forward(in, out);} + +template <> void +FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, + std::complex* in, + std::complex* out) +const {fft_float->fft3D_backward(in, out);} +template <> void +FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, + std::complex* in, + std::complex* out) +const {fft_double->fft3D_backward(in, out);} + +// access the real space data +template <> float* +FFT_Bundle::get_rspace_data() const {return fft_float->get_rspace_data();} +template <> double* +FFT_Bundle::get_rspace_data() const {return fft_double->get_rspace_data();} + +template <> std::complex* +FFT_Bundle::get_auxr_data() const {return fft_float->get_auxr_data();} +template <> std::complex* +FFT_Bundle::get_auxr_data() const {return fft_double->get_auxr_data();} + +template <> std::complex* +FFT_Bundle::get_auxg_data() const {return fft_float->get_auxg_data();} +template <> std::complex* +FFT_Bundle::get_auxg_data() const {return fft_double->get_auxg_data();} + +template <> std::complex* +FFT_Bundle::get_auxr_3d_data() const {return fft_float->get_auxr_3d_data();} +template <> std::complex* +FFT_Bundle::get_auxr_3d_data() const {return fft_double->get_auxr_3d_data();} +} \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.h b/source/module_basis/module_pw/module_fft/fft_bundle.h new file mode 100644 index 0000000000..8321badb4b --- /dev/null +++ b/source/module_basis/module_pw/module_fft/fft_bundle.h @@ -0,0 +1,214 @@ +#include "fft_base.h" +#include +// #include "module_psi/psi.h" +#ifndef FFT_TEMP_H +#define FFT_TEMP_H +namespace ModulePW +{ +class FFT_Bundle +{ + public: + FFT_Bundle(){}; + ~FFT_Bundle(){}; + /** + * @brief Constructor with device and precision. + * @param device_in device type, cpu or gpu. + * @param precision_in precision type, single or double. + * + * the function will check the input device and precision, + * and set the device and precision. + */ + FFT_Bundle(std::string device_in,std::string precision_in) + :device(device_in),precision(precision_in){}; + + /** + * @brief Set device and precision. + * @param device_in device type, cpu or gpu. + * @param precision_in precision type, single or double. + * + * the function will check the input device and precision, + * and set the device and precision. + */ + void setfft(std::string device_in,std::string precision_in); + + /** + * @brief Initialize the fft parameters. + * @param nx_in number of grid points in x direction. + * @param ny_in number of grid points in y direction. + * @param nz_in number of grid points in z direction. + * @param lixy_in the position of the left boundary + * in the x-y plane. + * @param rixy_in the position of the right boundary + * in the x-y plane. + * @param ns_in number of stick whcih is used in the + * Z direction. + * @param nplane_in number of x-y planes. + * @param nproc_in number of processors. + * @param gamma_only_in whether only gamma point is used. + * @param xprime_in whether xprime is used. + * + * the function will initialize the many-fft parameters + * Wheatley in cpu or gpu device. + */ + void initfft(int nx_in, + int ny_in, + int nz_in, + int lixy_in, + int rixy_in, + int ns_in, + int nplane_in, + int nproc_in, + bool gamma_only_in, + bool xprime_in = true, + bool mpifft_in = false); + + /** + * @brief Initialize the fft mode. + * @param fft_mode_in fft mode. + * + * the function will initialize the fft mode. + */ + + void initfftmode(int fft_mode_in){this->fft_mode = fft_mode_in;} + + void setupFFT(); + + void clearFFT(); + + void clear(); + + /** + * @brief Get the real space data. + * @return FPTYPE* the real space data. + * + * the function will return the real space data, + * which is used in the cpu-like fft. + */ + template + FPTYPE* get_rspace_data() const; + /** + * @brief Get the auxr data. + * @return std::complex* the auxr data. + * + * the function will return the auxr data, + * which is used in the cpu-like fft. + */ + template + std::complex* get_auxr_data() const; + /** + * @brief Get the auxg data. + * @return std::complex* the auxg data. + * + * the function will return the auxg data, + * which is used in the cpu-like fft. + */ + template + std::complex* get_auxg_data() const; + /** + * @brief Get the auxr 3d data. + * @return std::complex* the auxr 3d data. + * + * the function will return the auxr 3d data, + * which is used in the gpu-like fft. + */ + template + std::complex* get_auxr_3d_data() const; + + /** + * @brief Forward fft in z direction. + * @param in input data. + * @param out output data. + * + * The function will do the forward many fft in z direction, + * As an interface, the function will call the fftzfor in the + * accurate fft class. + * which is used in the cpu-like fft. + */ + template + void fftzfor(std::complex* in, + std::complex* out) const; + /** + * @brief Forward fft in x-y direction. + * @param in input data. + * @param out output data. + * + * the function will do the forward fft in x and y direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftxyfor in the accurate fft class. + */ + template + void fftxyfor(std::complex* in, + std::complex* out) const; + /** + * @brief Backward fft in z direction. + * @param in input data. + * @param out output data. + * + * the function will do the backward many fft in z direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftzbac in the accurate fft class. + */ + template + void fftzbac(std::complex* in, + std::complex* out) const; + /** + * @brief Backward fft in x-y direction. + * @param in input data. + * @param out output data. + * + * the function will do the backward fft in x and y direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftxybac in the accurate fft class. + */ + template + void fftxybac(std::complex* in, + std::complex* out) const; + + /** + * @brief Real to complex fft in x-y direction. + * @param in input data. + * @param out output data. + * + * the function will do the real to complex fft in x and y direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftxyr2c in the accurate fft class. + */ + template + void fftxyr2c(FPTYPE* in, + std::complex* out) const; + /** + * @brief Complex to real fft in x-y direction. + * @param in input data. + * @param out output data. + * + * the function will do the complex to real fft in x and y direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftxyc2r in the accurate fft class. + */ + template + void fftxyc2r(std::complex* in, + FPTYPE* out) const; + + template + void fft3D_forward(const Device* ctx, + std::complex* in, + std::complex* out) const; + template + void fft3D_backward(const Device* ctx, + std::complex* in, + std::complex* out) const; + + private: + int fft_mode = 0; + bool float_flag=false; + bool float_define=true; + bool double_flag=false; + std::shared_ptr> fft_float=nullptr; + std::shared_ptr> fft_double=nullptr; + + std::string device = "cpu"; + std::string precision = "double"; +}; +} // namespace ModulePW +#endif // FFT_H + diff --git a/source/module_basis/module_pw/module_fft/fft_cpu.cpp b/source/module_basis/module_pw/module_fft/fft_cpu.cpp new file mode 100644 index 0000000000..be920d4ae2 --- /dev/null +++ b/source/module_basis/module_pw/module_fft/fft_cpu.cpp @@ -0,0 +1,465 @@ +#include "fft_cpu.h" +#include "fftw3.h" +namespace ModulePW +{ + +template +void FFT_CPU::initfft(int nx_in, + int ny_in, + int nz_in, + int lixy_in, + int rixy_in, + int ns_in, + int nplane_in, + int nproc_in, + bool gamma_only_in, + bool xprime_in) +{ + this->gamma_only = gamma_only_in; + this->xprime = xprime_in; + this->fftnx = this->nx = nx_in; + this->fftny = this->ny = ny_in; + if (this->gamma_only) + { + if (xprime) { + this->fftnx = int(this->nx / 2) + 1; + } else { + this->fftny = int(this->ny / 2) + 1; + } + } + this->nz = nz_in; + this->ns = ns_in; + this->lixy = lixy_in; + this->rixy = rixy_in; + this->nplane = nplane_in; + this->nproc = nproc_in; + this->nxy = this->nx * this->ny; + this->fftnxy = this->fftnx * this->fftny; + const int nrxx = this->nxy * this->nplane; + const int nsz = this->nz * this->ns; + this->maxgrids = (nsz > nrxx) ? nsz : nrxx; +} +template <> +void FFT_CPU::setupFFT() +{ + + unsigned int flag = FFTW_ESTIMATE; + switch (this->fft_mode) + { + case 0: + flag = FFTW_ESTIMATE; + break; + case 1: + flag = FFTW_MEASURE; + break; + case 2: + flag = FFTW_PATIENT; + break; + case 3: + flag = FFTW_EXHAUSTIVE; + break; + default: + break; + } + z_auxg = (std::complex*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids); + z_auxr = (std::complex*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids); + d_rspace = (double*)z_auxg; + this->planzfor = fftw_plan_many_dft(1, + &this->nz, + this->ns, + (fftw_complex*)z_auxg, + &this->nz, + 1, + this->nz, + (fftw_complex*)z_auxg, + &this->nz, + 1, + this->nz, + FFTW_FORWARD, + flag); + + this->planzbac = fftw_plan_many_dft(1, + &this->nz, + this->ns, + (fftw_complex*)z_auxg, + &this->nz, + 1, + this->nz, + (fftw_complex*)z_auxg, + &this->nz, + 1, + this->nz, + FFTW_BACKWARD, + flag); + + //--------------------------------------------------------- + // 2 D - XY + //--------------------------------------------------------- + // 1D+1D is much faster than 2D FFT! + // in-place fft is better for c2c and out-of-place fft is better for c2r + int* embed = nullptr; + int npy = this->nplane * this->ny; + if (this->xprime) + { + this->planyfor = fftw_plan_many_dft(1, + &this->ny, + this->nplane, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + FFTW_FORWARD, + flag); + this->planybac = fftw_plan_many_dft(1, + &this->ny, + this->nplane, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + FFTW_BACKWARD, + flag); + if (this->gamma_only) + { + this->planxr2c = fftw_plan_many_dft_r2c(1, + &this->nx, + npy, + d_rspace, + embed, + npy, + 1, + (fftw_complex*)z_auxr, + embed, + npy, + 1, + flag); + this->planxc2r = fftw_plan_many_dft_c2r(1, + &this->nx, + npy, + (fftw_complex*)z_auxr, + embed, + npy, + 1, + d_rspace, + embed, + npy, + 1, + flag); + } + else + { + this->planxfor1 = fftw_plan_many_dft(1, + &this->nx, + npy, + (fftw_complex*)z_auxr, + embed, + npy, + 1, + (fftw_complex*)z_auxr, + embed, + npy, + 1, + FFTW_FORWARD, + flag); + this->planxbac1 = fftw_plan_many_dft(1, + &this->nx, + npy, + (fftw_complex*)z_auxr, + embed, + npy, + 1, + (fftw_complex*)z_auxr, + embed, + npy, + 1, + FFTW_BACKWARD, + flag); + } + } + else + { + this->planxfor1 = fftw_plan_many_dft(1, + &this->nx, + this->nplane * (this->lixy + 1), + (fftw_complex*)z_auxr, + embed, + npy, + 1, + (fftw_complex*)z_auxr, + embed, + npy, + 1, + FFTW_FORWARD, + flag); + this->planxbac1 = fftw_plan_many_dft(1, + &this->nx, + this->nplane * (this->lixy + 1), + (fftw_complex*)z_auxr, + embed, + npy, + 1, + (fftw_complex*)z_auxr, + embed, + npy, + 1, + FFTW_BACKWARD, + flag); + if (this->gamma_only) + { + this->planyr2c = fftw_plan_many_dft_r2c(1, + &this->ny, + this->nplane, + d_rspace, + embed, + this->nplane, + 1, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + flag); + this->planyc2r = fftw_plan_many_dft_c2r(1, + &this->ny, + this->nplane, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + d_rspace, + embed, + this->nplane, + 1, + flag); + } + else + { + this->planxfor2 = fftw_plan_many_dft(1, + &this->nx, + this->nplane * (this->ny - this->rixy), + (fftw_complex*)z_auxr, + embed, + npy, + 1, (fftw_complex*)z_auxr, + embed, + npy, + 1, + FFTW_FORWARD, + flag); + this->planxbac2 = fftw_plan_many_dft(1, + &this->nx, + this->nplane * (this->ny - this->rixy), + (fftw_complex*)z_auxr, + embed, + npy, + 1, + (fftw_complex*)z_auxr, + embed, + npy, + 1, + FFTW_BACKWARD, + flag); + this->planyfor = fftw_plan_many_dft(1, + &this->ny, + this->nplane, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + FFTW_FORWARD, + flag); + this->planybac = fftw_plan_many_dft(1, + &this->ny, + this->nplane, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + FFTW_BACKWARD, + flag); + } + } + return; +} + +template <> +void FFT_CPU::clearfft(fftw_plan& plan) +{ + if (plan) + { + fftw_destroy_plan(plan); + plan = nullptr; + } +} + +template <> +void FFT_CPU::cleanFFT() +{ + clearfft(planzfor); + clearfft(planzbac); + clearfft(planxfor1); + clearfft(planxbac1); + clearfft(planxfor2); + clearfft(planxbac2); + clearfft(planyfor); + clearfft(planybac); + clearfft(planxr2c); + clearfft(planxc2r); + clearfft(planyr2c); + clearfft(planyc2r); +} + +template <> +void FFT_CPU::clear() +{ + this->cleanFFT(); + if (z_auxg != nullptr) + { + fftw_free(z_auxg); + z_auxg = nullptr; + } + if (z_auxr != nullptr) + { + fftw_free(z_auxr); + z_auxr = nullptr; + } + d_rspace = nullptr; +} + +template <> +void FFT_CPU::fftxyfor(std::complex* in, std::complex* out) const +{ + int npy = this->nplane * this->ny; + if (this->xprime) + { + fftw_execute_dft(this->planxfor1, (fftw_complex*)in, (fftw_complex*)out); + for (int i = 0; i < this->lixy + 1; ++i) + { + fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + } + for (int i = rixy; i < this->nx; ++i) + { + fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + } + } + else + { + for (int i = 0; i < this->nx; ++i) + { + fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + } + fftw_execute_dft(this->planxfor1, (fftw_complex*)in, (fftw_complex*)out); + fftw_execute_dft(this->planxfor2, (fftw_complex*)&in[rixy * nplane], (fftw_complex*)&out[rixy * nplane]); + } +} + +template <> +void FFT_CPU::fftxybac(std::complex* in,std::complex* out) const +{ + int npy = this->nplane * this->ny; + if (this->xprime) + { + for (int i = 0; i < this->lixy + 1; ++i) + { + fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + } + for (int i = rixy; i < this->nx; ++i) + { + fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + } + fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)out); + } + else + { + fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)out); + fftw_execute_dft(this->planxbac2, (fftw_complex*)&in[rixy * nplane], (fftw_complex*)&out[rixy * nplane]); + for (int i = 0; i < this->nx; ++i) + { + fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + } + } +} + +template <> +void FFT_CPU::fftzfor(std::complex* in, std::complex* out) const +{ + fftw_execute_dft(this->planzfor, (fftw_complex*)in, (fftw_complex*)out); +} + +template <> +void FFT_CPU::fftzbac(std::complex* in, std::complex* out) const +{ + fftw_execute_dft(this->planzbac, (fftw_complex*)in, (fftw_complex*)out); +} + +template <> +void FFT_CPU::fftxyr2c(double* in, std::complex* out) const +{ + int npy = this->nplane * this->ny; + if (this->xprime) + { + fftw_execute_dft_r2c(this->planxr2c, in, (fftw_complex*)out); + for (int i = 0; i < this->lixy + 1; ++i) + { + fftw_execute_dft(this->planyfor, (fftw_complex*)&out[i * npy], (fftw_complex*)&out[i * npy]); + } + } + else + { + for (int i = 0; i < this->nx; ++i) + { + fftw_execute_dft_r2c(this->planyr2c, &in[i * npy], (fftw_complex*)&out[i * npy]); + } + fftw_execute_dft(this->planxfor1, (fftw_complex*)out, (fftw_complex*)out); + } +} + +template <> +void FFT_CPU::fftxyc2r(std::complex *in,double *out) const +{ + int npy = this->nplane * this->ny; + if (this->xprime) + { + for (int i = 0; i < this->lixy + 1; ++i) + { + fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&in[i * npy]); + } + fftw_execute_dft_c2r(this->planxc2r, (fftw_complex*)in, out); + } + else + { + fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)in); + for (int i = 0; i < this->nx; ++i) + { + fftw_execute_dft_c2r(this->planyc2r, (fftw_complex*)&in[i * npy], &out[i * npy]); + } + } +} + +template <> double* +FFT_CPU::get_rspace_data() const {return d_rspace;} +template <> std::complex* +FFT_CPU::get_auxr_data() const {return z_auxr;} +template <> std::complex* +FFT_CPU::get_auxg_data() const {return z_auxg;} + +template FFT_CPU::FFT_CPU(); +template FFT_CPU::~FFT_CPU(); +template FFT_CPU::FFT_CPU(); +template FFT_CPU::~FFT_CPU(); +} \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_cpu.h b/source/module_basis/module_pw/module_fft/fft_cpu.h new file mode 100644 index 0000000000..27c7e862a2 --- /dev/null +++ b/source/module_basis/module_pw/module_fft/fft_cpu.h @@ -0,0 +1,173 @@ +#include "fft_base.h" +#include "fftw3.h" + +// #ifdef __ENABLE_FLOAT_FFTW + +// #endif +// #endif +#ifndef FFT_CPU_H +#define FFT_CPU_H +namespace ModulePW +{ +template +class FFT_CPU : public FFT_BASE +{ + public: + FFT_CPU(){}; + FFT_CPU(const int fft_mode_in):fft_mode(fft_mode_in){}; + ~FFT_CPU(){}; + + /** + * @brief Initialize the fft parameters. + * @param nx_in number of grid points in x direction. + * @param ny_in number of grid points in y direction. + * @param nz_in number of grid points in z direction. + * @param lixy_in the position of the left boundary + * in the x-y plane. + * @param rixy_in the position of the right boundary + * in the x-y plane. + * @param ns_in number of stick whcih is used in the + * Z direction. + * @param nplane_in number of x-y planes. + * @param nproc_in number of processors. + * @param gamma_only_in whether only gamma point is used. + * @param xprime_in whether xprime is used. + */ + __attribute__((weak)) + void initfft(int nx_in, + int ny_in, + int nz_in, + int lixy_in, + int rixy_in, + int ns_in, + int nplane_in, + int nproc_in, + bool gamma_only_in, + bool xprime_in = true) override; + __attribute__((weak)) + void setupFFT() override; + + // void initplan(const unsigned int& flag = 0); + __attribute__((weak)) + void cleanFFT() override; + + __attribute__((weak)) + void clear() override; + + /** + * @brief Get the real space data the CPU FFT. + * @return FPTYPE* the real space data. + * + * the function will return the real space data, + * which is used in the CPU fft.Use the weak attribute + * to avoid defining float while without flag ENABLE_FLOAT_FFTW. + */ + __attribute__((weak)) + FPTYPE* get_rspace_data() const override; + + __attribute__((weak)) + std::complex* get_auxr_data() const override; + + __attribute__((weak)) + std::complex* get_auxg_data() const override; + + /** + * @brief Forward FFT in x-y direction + * @param in input data + * @param out output data + * + * The function details can be found in FFT_BASE, + * and the function interfaces can be found in FFT_BUNDLE. + */ + __attribute__((weak)) + void fftxyfor(std::complex* in, + std::complex* out) const override; + + __attribute__((weak)) + void fftxybac(std::complex* in, + std::complex* out) const override; + + __attribute__((weak)) + void fftzfor(std::complex* in, + std::complex* out) const override; + + __attribute__((weak)) + void fftzbac(std::complex* in, + std::complex* out) const override; + + __attribute__((weak)) + void fftxyr2c(FPTYPE* in, + std::complex* out) const override; + + __attribute__((weak)) + void fftxyc2r(std::complex* in, + FPTYPE* out) const override; + private: + void clearfft(fftw_plan& plan); + void clearfft(fftwf_plan& plan); + + fftw_plan planzfor = NULL; + fftw_plan planzbac = NULL; + fftw_plan planxfor1 = NULL; + fftw_plan planxbac1 = NULL; + fftw_plan planxfor2 = NULL; + fftw_plan planxbac2 = NULL; + fftw_plan planyfor = NULL; + fftw_plan planybac = NULL; + fftw_plan planxr2c = NULL; + fftw_plan planxc2r = NULL; + fftw_plan planyr2c = NULL; + fftw_plan planyc2r = NULL; + + fftwf_plan planfzfor = NULL; + fftwf_plan planfzbac = NULL; + fftwf_plan planfxfor1= NULL; + fftwf_plan planfxbac1= NULL; + fftwf_plan planfxfor2= NULL; + fftwf_plan planfxbac2= NULL; + fftwf_plan planfyfor = NULL; + fftwf_plan planfybac = NULL; + fftwf_plan planfxr2c = NULL; + fftwf_plan planfxc2r = NULL; + fftwf_plan planfyr2c = NULL; + fftwf_plan planfyc2r = NULL; + + std::complex*c_auxg = nullptr; + std::complex*c_auxr = nullptr; // fft space + std::complex*z_auxg = nullptr; + std::complex*z_auxr = nullptr; // fft space + + float* s_rspace = nullptr; // real number space for r, [nplane * nx *ny] + double* d_rspace = nullptr; // real number space for r, [nplane * nx *ny] + int fftnx=0; + int fftny=0; + int fftnxy=0; + int nxy=0; + int nplane=0; + int ns=0; //number of sticks + int nproc=1; // number of proc. + int maxgrids = 0; + bool gamma_only = false; + + /** + * @brief lixy: the left edge of the pw ball in the y direction + */ + int lixy=0; + + /** + * @brief rixy: the right edge of the pw ball in the x or y direction + */ + int rixy=0; + /** + * @brief xprime: whether xprime is used,when do recip2real, x-fft will + * be done last and when doing real2recip, x-fft will be done first; + * false: y-fft For gamma_only, true: we use half x; false: we use half y + */ + bool xprime = true; + /** + * @brief fft_mode: fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive + */ + int fft_mode = 0; +}; +} +#endif // FFT_CPU_H \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp b/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp new file mode 100644 index 0000000000..f84b45bf09 --- /dev/null +++ b/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp @@ -0,0 +1,433 @@ +#include "fft_cpu.h" + +namespace ModulePW +{ +template <> +void FFT_CPU::setupFFT() +{ + unsigned int flag = FFTW_ESTIMATE; + switch (this->fft_mode) + { + case 0: + flag = FFTW_ESTIMATE; + break; + case 1: + flag = FFTW_MEASURE; + break; + case 2: + flag = FFTW_PATIENT; + break; + case 3: + flag = FFTW_EXHAUSTIVE; + break; + default: + break; + } + c_auxg = (std::complex*)fftwf_malloc(sizeof(fftwf_complex) * this->maxgrids); + c_auxr = (std::complex*)fftwf_malloc(sizeof(fftwf_complex) * this->maxgrids); + s_rspace = (float*)c_auxg; + //--------------------------------------------------------- + // 1 D + //--------------------------------------------------------- + + // fftw_plan_many_dft(int rank, + // const int *n, int howmany, + // fftw_complex *in, const int *inembed, int istride, int idist, + // fftw_complex *out, const int *onembed, int ostride, int odist, int sign, unsigned + //flags); + + this->planfzfor = fftwf_plan_many_dft(1, + &this->nz, + this->ns, + (fftwf_complex*)c_auxg, + &this->nz, + 1, + this->nz, + (fftwf_complex*)c_auxg, + &this->nz, + 1, + this->nz, + FFTW_FORWARD, + flag); + + this->planfzbac = fftwf_plan_many_dft(1, + &this->nz, + this->ns, + (fftwf_complex*)c_auxg, + &this->nz, + 1, + this->nz, + (fftwf_complex*)c_auxg, + &this->nz, + 1, + this->nz, + FFTW_BACKWARD, + flag); + //--------------------------------------------------------- + // 2 D + //--------------------------------------------------------- + + int* embed = nullptr; + int npy = this->nplane * this->ny; + if (this->xprime) + { + this->planfyfor = fftwf_plan_many_dft(1, + &this->ny, + this->nplane, + (fftwf_complex*)c_auxr, + embed, + nplane, + 1, + (fftwf_complex*)c_auxr, + embed, + nplane, + 1, + FFTW_FORWARD, + flag); + this->planfybac = fftwf_plan_many_dft(1, + &this->ny, + this->nplane, + (fftwf_complex*)c_auxr, + embed, + nplane, + 1, + (fftwf_complex*)c_auxr, + embed, nplane, + 1, + FFTW_BACKWARD, + flag); + if (this->gamma_only) + { + this->planfxr2c = fftwf_plan_many_dft_r2c(1, + &this->nx, + npy, + s_rspace, + embed, + npy, + 1, + (fftwf_complex*)c_auxr, + embed, npy, + 1, + flag); + this->planfxc2r = fftwf_plan_many_dft_c2r(1, + &this->nx, + npy, + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + s_rspace, + embed, + npy, + 1, + flag); + } + else + { + this->planfxfor1 = fftwf_plan_many_dft(1, + &this->nx, + npy, + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + FFTW_FORWARD, + flag); + this->planfxbac1 = fftwf_plan_many_dft(1, + &this->nx, + npy, + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + FFTW_BACKWARD, + flag); + } + } + else + { + this->planfxfor1 = fftwf_plan_many_dft(1, + &this->nx, + this->nplane * (lixy + 1), + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + FFTW_FORWARD, + flag); + this->planfxbac1 = fftwf_plan_many_dft(1, + &this->nx, + this->nplane * (lixy + 1), + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + FFTW_BACKWARD, + flag); + if (this->gamma_only) + { + this->planfyr2c = fftwf_plan_many_dft_r2c(1, + &this->ny, + this->nplane, + s_rspace, + embed, + this->nplane, + 1, + (fftwf_complex*)c_auxr, + embed, + this->nplane, + 1, + flag); + this->planfyc2r = fftwf_plan_many_dft_c2r(1, + &this->ny, + this->nplane, + (fftwf_complex*)c_auxr, + embed, + this->nplane, + 1, + s_rspace, + embed, + this->nplane, + 1, + flag); + } + else + { + this->planfxfor2 = fftwf_plan_many_dft(1, + &this->nx, + this->nplane * (this->ny - rixy), + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + FFTW_FORWARD, + flag); + this->planfxbac2 = fftwf_plan_many_dft(1, + &this->nx, + this->nplane * (this->ny - rixy), + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + FFTW_BACKWARD, + flag); + this->planfyfor = fftwf_plan_many_dft(1, + &this->ny, + this->nplane, + (fftwf_complex*)c_auxr, + embed, + this->nplane, + 1, + (fftwf_complex*)c_auxr, + embed, + this->nplane, + 1, + FFTW_FORWARD, + flag); + this->planfybac = fftwf_plan_many_dft(1, + &this->ny, + this->nplane, + (fftwf_complex*)c_auxr, + embed, + this->nplane, + 1, + (fftwf_complex*)c_auxr, + embed, + this->nplane, + 1, + FFTW_BACKWARD, + flag); + } + } + return; +} + +template <> +void FFT_CPU::clearfft(fftw_plan& plan) +{ + if (plan) + { + fftw_destroy_plan(plan); + plan = nullptr; + } +} + +template <> +void FFT_CPU::cleanFFT() +{ + clearfft(planzfor); + clearfft(planzbac); + clearfft(planxfor1); + clearfft(planxbac1); + clearfft(planxfor2); + clearfft(planxbac2); + clearfft(planyfor); + clearfft(planybac); + clearfft(planxr2c); + clearfft(planxc2r); + clearfft(planyr2c); + clearfft(planyc2r); +} + + +template <> +void FFT_CPU::clear() +{ + this->cleanFFT(); + if (c_auxg != nullptr) + { + fftw_free(c_auxg); + c_auxg = nullptr; + } + if (z_auxr != nullptr) + { + fftw_free(c_auxr); + c_auxr = nullptr; + } + s_rspace = nullptr; +} + + +template <> +void FFT_CPU::fftxyfor(std::complex* in, std::complex* out) const +{ + int npy = this->nplane * this->ny; + if (this->xprime) + { + fftwf_execute_dft(this->planfxfor1, (fftwf_complex*)in, (fftwf_complex*)out); + + for (int i = 0; i < this->lixy + 1; ++i) + { + fftwf_execute_dft(this->planfyfor, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&out[i * npy]); + } + for (int i = rixy; i < this->nx; ++i) + { + fftwf_execute_dft(this->planfyfor, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&out[i * npy]); + } + } + else + { + for (int i = 0; i < this->nx; ++i) + { + fftwf_execute_dft(this->planfyfor, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&out[i * npy]); + } + + fftwf_execute_dft(this->planfxfor1, (fftwf_complex*)in, (fftwf_complex*)out); + fftwf_execute_dft(this->planfxfor2, (fftwf_complex*)&in[rixy * nplane], (fftwf_complex*)&out[rixy * nplane]); + } +} +template <> +void FFT_CPU::fftxybac(std::complex* in,std::complex * out) const +{ + int npy = this->nplane * this->ny; + if (this->xprime) + { + for (int i = 0; i < this->lixy + 1; ++i) + { + fftwf_execute_dft(this->planfybac, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&out[i * npy]); + } + for (int i = rixy; i < this->nx; ++i) + { + fftwf_execute_dft(this->planfybac, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&out[i * npy]); + } + + fftwf_execute_dft(this->planfxbac1, (fftwf_complex*)in, (fftwf_complex*)out); + } + else + { + fftwf_execute_dft(this->planfxbac1, (fftwf_complex*)in, (fftwf_complex*)out); + fftwf_execute_dft(this->planfxbac2, (fftwf_complex*)&in[rixy * nplane], (fftwf_complex*)&out[rixy * nplane]); + + for (int i = 0; i < this->nx; ++i) + { + fftwf_execute_dft(this->planfybac, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&out[i * npy]); + } + } +} +template <> +void FFT_CPU::fftzfor(std::complex* in, std::complex* out) const +{ + fftwf_execute_dft(this->planfzfor, (fftwf_complex*)in, (fftwf_complex*)out); +} +template <> +void FFT_CPU::fftzbac(std::complex* in, std::complex* out) const +{ + fftwf_execute_dft(this->planfzbac, (fftwf_complex*)in, (fftwf_complex*)out); +} +template <> +void FFT_CPU::fftxyr2c(float* in, std::complex* out) const +{ + int npy = this->nplane * this->ny; + if (this->xprime) + { + fftwf_execute_dft_r2c(this->planfxr2c, in, (fftwf_complex*)out); + + for (int i = 0; i < this->lixy + 1; ++i) + { + fftwf_execute_dft(this->planfyfor, (fftwf_complex*)&out[i * npy], (fftwf_complex*)&out[i * npy]); + } + } + else + { + for (int i = 0; i < this->nx; ++i) + { + fftwf_execute_dft_r2c(this->planfyr2c, &in[i * npy], (fftwf_complex*)&out[i * npy]); + } + + fftwf_execute_dft(this->planfxfor1, (fftwf_complex*)out, (fftwf_complex*)out); + } +} +template <> +void FFT_CPU::fftxyc2r(std::complex* in, float* out) const +{ + int npy = this->nplane * this->ny; + if (this->xprime) + { + for (int i = 0; i < this->lixy + 1; ++i) + { + fftwf_execute_dft(this->planfybac, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&in[i * npy]); + } + + fftwf_execute_dft_c2r(this->planfxc2r, (fftwf_complex*)in, out); + } + else + { + fftwf_execute_dft(this->planfxbac1, (fftwf_complex*)in, (fftwf_complex*)in); + + for (int i = 0; i < this->nx; ++i) + { + fftwf_execute_dft_c2r(this->planfyc2r, (fftwf_complex*)&in[i * npy], &out[i * npy]); + } + } +} +template <> float* +FFT_CPU::get_rspace_data() const {return s_rspace;} +template <> std::complex* +FFT_CPU::get_auxr_data() const {return c_auxr;} +template <> std::complex* +FFT_CPU::get_auxg_data() const {return c_auxg;} +} \ No newline at end of file diff --git a/source/module_basis/module_pw/pw_basis.cpp b/source/module_basis/module_pw/pw_basis.cpp index 0121eef9e4..ac02c45763 100644 --- a/source/module_basis/module_pw/pw_basis.cpp +++ b/source/module_basis/module_pw/pw_basis.cpp @@ -17,6 +17,7 @@ PW_Basis::PW_Basis(std::string device_, std::string precision_) : device(std::mo classname="PW_Basis"; this->ft.set_device(this->device); this->ft.set_precision(this->precision); + this->fft_bundle.setfft("cpu",this->precision); } PW_Basis:: ~PW_Basis() @@ -57,9 +58,19 @@ void PW_Basis::setuptransform() this->distribute_g(); this->getstartgr(); this->ft.clear(); - if(this->xprime) this->ft.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); - else this->ft.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + this->fft_bundle.clear(); + if(this->xprime) + { + this->ft.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + } + else + { + this->ft.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + } this->ft.setupFFT(); + this->fft_bundle.setupFFT(); ModuleBase::timer::tick(this->classname, "setuptransform"); } diff --git a/source/module_basis/module_pw/pw_basis.h b/source/module_basis/module_pw/pw_basis.h index 6f95343b1a..66f5ff6301 100644 --- a/source/module_basis/module_pw/pw_basis.h +++ b/source/module_basis/module_pw/pw_basis.h @@ -6,6 +6,7 @@ #include "module_base/vector3.h" #include #include "fft.h" +#include "module_fft/fft_bundle.h" #include #ifdef __MPI #include "mpi.h" @@ -242,6 +243,7 @@ class PW_Basis int nmaxgr=0; // Gamma_only: max between npw and (nrxx+1)/2, others: max between npw and nrxx // Thus complex[nmaxgr] is able to contain either reciprocal or real data FFT ft; + FFT_Bundle fft_bundle; //The position of pointer in and out can be equal(in-place transform) or different(out-of-place transform). template diff --git a/source/module_basis/module_pw/pw_basis_k.cpp b/source/module_basis/module_pw/pw_basis_k.cpp index c71163eba2..2361404d84 100644 --- a/source/module_basis/module_pw/pw_basis_k.cpp +++ b/source/module_basis/module_pw/pw_basis_k.cpp @@ -12,6 +12,7 @@ namespace ModulePW PW_Basis_K::PW_Basis_K() { classname="PW_Basis_K"; + this->fft_bundle.setfft("cpu",this->precision); } PW_Basis_K::~PW_Basis_K() { @@ -69,7 +70,8 @@ void PW_Basis_K:: initparameters( this->kvec_d[ik] = kvec_d_in[ik]; this->kvec_c[ik] = this->kvec_d[ik] * this->G; double kmod = sqrt(this->kvec_c[ik] * this->kvec_c[ik]); - if(kmod > kmaxmod) kmaxmod = kmod; + if(kmod > kmaxmod) { kmaxmod = kmod; +} } this->gk_ecut = gk_ecut_in/this->tpiba2; this->ggecut = pow(sqrt(this->gk_ecut) + kmaxmod, 2); @@ -80,14 +82,16 @@ void PW_Basis_K:: initparameters( } this->gamma_only = gamma_only_in; - if(kmaxmod > 0) this->gamma_only = false; //if it is not the gamma point, we do not use gamma_only + if(kmaxmod > 0) { this->gamma_only = false; //if it is not the gamma point, we do not use gamma_only +} this->xprime = xprime_in; this->fftny = this->ny; this->fftnx = this->nx; if (this->gamma_only) { - if(this->xprime) this->fftnx = int(this->nx / 2) + 1; - else this->fftny = int(this->ny / 2) + 1; + if(this->xprime) { this->fftnx = int(this->nx / 2) + 1; + } else { this->fftny = int(this->ny / 2) + 1; +} } this->fftnz = this->nz; this->fftnxy = this->fftnx * this->fftny; @@ -141,7 +145,8 @@ void PW_Basis_K::setupIndGk() //get igl2isz_k and igl2ig_k - if(this->npwk_max <= 0) return; + if(this->npwk_max <= 0) { return; +} delete[] igl2isz_k; this->igl2isz_k = new int [this->nks * this->npwk_max]; delete[] igl2ig_k; this->igl2ig_k = new int [this->nks * this->npwk_max]; for (int ik = 0; ik < this->nks; ik++) @@ -180,9 +185,16 @@ void PW_Basis_K::setuptransform() this->getstartgr(); this->setupIndGk(); this->ft.clear(); - if(this->xprime) this->ft.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); - else this->ft.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + this->fft_bundle.clear(); + if(this->xprime){ + this->ft.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + }else{ + this->ft.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + } this->ft.setupFFT(); + this->fft_bundle.setupFFT(); ModuleBase::timer::tick(this->classname, "setuptransform"); } @@ -191,7 +203,8 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h this->erf_ecut = erf_ecut_in; this->erf_height = erf_height_in; this->erf_sigma = erf_sigma_in; - if(this->npwk_max <= 0) return; + if(this->npwk_max <= 0) { return; +} delete[] gk2; delete[] gcar; this->gk2 = new double[this->npwk_max * this->nks]; @@ -211,9 +224,12 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h int ixy = this->is2fftixy[is]; int ix = ixy / this->fftny; int iy = ixy % this->fftny; - if (ix >= int(this->nx/2) + 1) ix -= this->nx; - if (iy >= int(this->ny/2) + 1) iy -= this->ny; - if (iz >= int(this->nz/2) + 1) iz -= this->nz; + if (ix >= int(this->nx/2) + 1) { ix -= this->nx; +} + if (iy >= int(this->ny/2) + 1) { iy -= this->ny; +} + if (iz >= int(this->nz/2) + 1) { iz -= this->nz; +} f.x = ix; f.y = iy; f.z = iz; @@ -270,9 +286,12 @@ ModuleBase::Vector3 PW_Basis_K:: cal_GplusK_cartesian(const int ik, cons int is = isz / this->nz; int ix = this->is2fftixy[is] / this->fftny; int iy = this->is2fftixy[is] % this->fftny; - if (ix >= int(this->nx/2) + 1) ix -= this->nx; - if (iy >= int(this->ny/2) + 1) iy -= this->ny; - if (iz >= int(this->nz/2) + 1) iz -= this->nz; + if (ix >= int(this->nx/2) + 1) { ix -= this->nx; +} + if (iy >= int(this->ny/2) + 1) { iy -= this->ny; +} + if (iz >= int(this->nz/2) + 1) { iz -= this->nz; +} ModuleBase::Vector3 f; f.x = ix; f.y = iy; @@ -354,7 +373,8 @@ std::vector PW_Basis_K::get_ig2ix(const int ik) const int is = isz / this->nz; int ixy = this->is2fftixy[is]; int ix = ixy / this->ny; - if (ix < (nx / 2) + 1) ix += nx; + if (ix < (nx / 2) + 1) { ix += nx; +} ig_to_ix[ig] = ix; } return ig_to_ix; @@ -371,7 +391,8 @@ std::vector PW_Basis_K::get_ig2iy(const int ik) const int is = isz / this->nz; int ixy = this->is2fftixy[is]; int iy = ixy % this->ny; - if (iy < (ny / 2) + 1) iy += ny; + if (iy < (ny / 2) + 1) { iy += ny; +} ig_to_iy[ig] = iy; } return ig_to_iy; @@ -386,7 +407,8 @@ std::vector PW_Basis_K::get_ig2iz(const int ik) const { int isz = this->igl2isz_k[ig + ik * npwk_max]; int iz = isz % this->nz; - if (iz < (nz / 2) + 1) iz += nz; + if (iz < (nz / 2) + 1) { iz += nz; +} ig_to_iz[ig] = iz; } return ig_to_iz; diff --git a/source/module_basis/module_pw/pw_basis_sup.cpp b/source/module_basis/module_pw/pw_basis_sup.cpp index 6763f94c07..80c7e87f57 100644 --- a/source/module_basis/module_pw/pw_basis_sup.cpp +++ b/source/module_basis/module_pw/pw_basis_sup.cpp @@ -20,7 +20,9 @@ void PW_Basis_Sup::setuptransform(const ModulePW::PW_Basis* pw_rho) this->distribute_g(pw_rho); this->getstartgr(); this->ft.clear(); + this->fft_bundle.clear(); if (this->xprime) + { this->ft.initfft(this->nx, this->ny, this->nz, @@ -31,7 +33,19 @@ void PW_Basis_Sup::setuptransform(const ModulePW::PW_Basis* pw_rho) this->poolnproc, this->gamma_only, this->xprime); + this->fft_bundle.initfft(this->nx, + this->ny, + this->nz, + this->lix, + this->rix, + this->nst, + this->nplane, + this->poolnproc, + this->gamma_only, + this->xprime); + } else + { this->ft.initfft(this->nx, this->ny, this->nz, @@ -42,7 +56,19 @@ void PW_Basis_Sup::setuptransform(const ModulePW::PW_Basis* pw_rho) this->poolnproc, this->gamma_only, this->xprime); + this->fft_bundle.initfft(this->nx, + this->ny, + this->nz, + this->liy, + this->riy, + this->nst, + this->nplane, + this->poolnproc, + this->gamma_only, + this->xprime); + } this->ft.setupFFT(); + this->fft_bundle.setupFFT(); ModuleBase::timer::tick(this->classname, "setuptransform"); } diff --git a/source/module_basis/module_pw/pw_transform.cpp b/source/module_basis/module_pw/pw_transform.cpp index ca523ace85..d8534c7f0a 100644 --- a/source/module_basis/module_pw/pw_transform.cpp +++ b/source/module_basis/module_pw/pw_transform.cpp @@ -1,4 +1,5 @@ #include "fft.h" +#include "module_fft/fft_bundle.h" #include #include "pw_basis.h" #include @@ -28,13 +29,13 @@ void PW_Basis::real2recip(const std::complex* in, #endif for(int ir = 0 ; ir < this->nrxx ; ++ir) { - this->ft.get_auxr_data()[ir] = in[ir]; + this->fft_bundle.get_auxr_data()[ir] = in[ir]; } - this->ft.fftxyfor(ft.get_auxr_data(),ft.get_auxr_data()); + this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data(),fft_bundle.get_auxr_data()); - this->gatherp_scatters(this->ft.get_auxr_data(), this->ft.get_auxg_data()); + this->gatherp_scatters(this->fft_bundle.get_auxr_data(), this->fft_bundle.get_auxg_data()); - this->ft.fftzfor(ft.get_auxg_data(),ft.get_auxg_data()); + this->fft_bundle.fftzfor(fft_bundle.get_auxg_data(),fft_bundle.get_auxg_data()); if(add) { @@ -44,7 +45,7 @@ void PW_Basis::real2recip(const std::complex* in, #endif for(int ig = 0 ; ig < this->npw ; ++ig) { - out[ig] += tmpfac * this->ft.get_auxg_data()[this->ig2isz[ig]]; + out[ig] += tmpfac * this->fft_bundle.get_auxg_data()[this->ig2isz[ig]]; } } else @@ -55,7 +56,7 @@ void PW_Basis::real2recip(const std::complex* in, #endif for(int ig = 0 ; ig < this->npw ; ++ig) { - out[ig] = tmpfac * this->ft.get_auxg_data()[this->ig2isz[ig]]; + out[ig] = tmpfac * this->fft_bundle.get_auxg_data()[this->ig2isz[ig]]; } } ModuleBase::timer::tick(this->classname, "real2recip"); @@ -82,11 +83,11 @@ void PW_Basis::real2recip(const FPTYPE* in, std::complex* out, const boo { for(int ipy = 0 ; ipy < npy ; ++ipy) { - this->ft.get_rspace_data()[ix*npy + ipy] = in[ix*npy + ipy]; + this->fft_bundle.get_rspace_data()[ix*npy + ipy] = in[ix*npy + ipy]; } } - this->ft.fftxyr2c(ft.get_rspace_data(),ft.get_auxr_data()); + this->fft_bundle.fftxyr2c(fft_bundle.get_rspace_data(),fft_bundle.get_auxr_data()); } else { @@ -95,13 +96,13 @@ void PW_Basis::real2recip(const FPTYPE* in, std::complex* out, const boo #endif for(int ir = 0 ; ir < this->nrxx ; ++ir) { - this->ft.get_auxr_data()[ir] = std::complex(in[ir],0); + this->fft_bundle.get_auxr_data()[ir] = std::complex(in[ir],0); } - this->ft.fftxyfor(ft.get_auxr_data(),ft.get_auxr_data()); + this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data(),fft_bundle.get_auxr_data()); } - this->gatherp_scatters(this->ft.get_auxr_data(), this->ft.get_auxg_data()); + this->gatherp_scatters(this->fft_bundle.get_auxr_data(), this->fft_bundle.get_auxg_data()); - this->ft.fftzfor(ft.get_auxg_data(),ft.get_auxg_data()); + this->fft_bundle.fftzfor(fft_bundle.get_auxg_data(),fft_bundle.get_auxg_data()); if(add) { @@ -111,7 +112,7 @@ void PW_Basis::real2recip(const FPTYPE* in, std::complex* out, const boo #endif for(int ig = 0 ; ig < this->npw ; ++ig) { - out[ig] += tmpfac * this->ft.get_auxg_data()[this->ig2isz[ig]]; + out[ig] += tmpfac * this->fft_bundle.get_auxg_data()[this->ig2isz[ig]]; } } else @@ -122,7 +123,7 @@ void PW_Basis::real2recip(const FPTYPE* in, std::complex* out, const boo #endif for(int ig = 0 ; ig < this->npw ; ++ig) { - out[ig] = tmpfac * this->ft.get_auxg_data()[this->ig2isz[ig]]; + out[ig] = tmpfac * this->fft_bundle.get_auxg_data()[this->ig2isz[ig]]; } } ModuleBase::timer::tick(this->classname, "real2recip"); @@ -148,7 +149,7 @@ void PW_Basis::recip2real(const std::complex* in, #endif for(int i = 0 ; i < this->nst * this->nz ; ++i) { - ft.get_auxg_data()[i] = std::complex(0, 0); + fft_bundle.get_auxg_data()[i] = std::complex(0, 0); } #ifdef _OPENMP @@ -156,13 +157,13 @@ void PW_Basis::recip2real(const std::complex* in, #endif for(int ig = 0 ; ig < this->npw ; ++ig) { - this->ft.get_auxg_data()[this->ig2isz[ig]] = in[ig]; + this->fft_bundle.get_auxg_data()[this->ig2isz[ig]] = in[ig]; } - this->ft.fftzbac(ft.get_auxg_data(), ft.get_auxg_data()); + this->fft_bundle.fftzbac(fft_bundle.get_auxg_data(), fft_bundle.get_auxg_data()); - this->gathers_scatterp(this->ft.get_auxg_data(),this->ft.get_auxr_data()); + this->gathers_scatterp(this->fft_bundle.get_auxg_data(),this->fft_bundle.get_auxr_data()); - this->ft.fftxybac(ft.get_auxr_data(),ft.get_auxr_data()); + this->fft_bundle.fftxybac(fft_bundle.get_auxr_data(),fft_bundle.get_auxr_data()); if(add) { @@ -171,7 +172,7 @@ void PW_Basis::recip2real(const std::complex* in, #endif for(int ir = 0 ; ir < this->nrxx ; ++ir) { - out[ir] += factor * this->ft.get_auxr_data()[ir]; + out[ir] += factor * this->fft_bundle.get_auxr_data()[ir]; } } else @@ -181,7 +182,7 @@ void PW_Basis::recip2real(const std::complex* in, #endif for(int ir = 0 ; ir < this->nrxx ; ++ir) { - out[ir] = this->ft.get_auxr_data()[ir]; + out[ir] = this->fft_bundle.get_auxr_data()[ir]; } } ModuleBase::timer::tick(this->classname, "recip2real"); @@ -203,7 +204,7 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo #endif for(int i = 0 ; i < this->nst * this->nz ; ++i) { - ft.get_auxg_data()[i] = std::complex(0, 0); + fft_bundle.get_auxg_data()[i] = std::complex(0, 0); } #ifdef _OPENMP @@ -211,15 +212,15 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo #endif for(int ig = 0 ; ig < this->npw ; ++ig) { - this->ft.get_auxg_data()[this->ig2isz[ig]] = in[ig]; + this->fft_bundle.get_auxg_data()[this->ig2isz[ig]] = in[ig]; } - this->ft.fftzbac(ft.get_auxg_data(), ft.get_auxg_data()); + this->fft_bundle.fftzbac(fft_bundle.get_auxg_data(), fft_bundle.get_auxg_data()); - this->gathers_scatterp(this->ft.get_auxg_data(), this->ft.get_auxr_data()); + this->gathers_scatterp(this->fft_bundle.get_auxg_data(), this->fft_bundle.get_auxr_data()); if(this->gamma_only) { - this->ft.fftxyc2r(ft.get_auxr_data(),ft.get_rspace_data()); + this->fft_bundle.fftxyc2r(fft_bundle.get_auxr_data(),fft_bundle.get_rspace_data()); // r2c in place const int npy = this->ny * this->nplane; @@ -233,7 +234,7 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo { for(int ipy = 0 ; ipy < npy ; ++ipy) { - out[ix*npy + ipy] += factor * this->ft.get_rspace_data()[ix*npy + ipy]; + out[ix*npy + ipy] += factor * this->fft_bundle.get_rspace_data()[ix*npy + ipy]; } } } @@ -246,14 +247,14 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo { for(int ipy = 0 ; ipy < npy ; ++ipy) { - out[ix*npy + ipy] = this->ft.get_rspace_data()[ix*npy + ipy]; + out[ix*npy + ipy] = this->fft_bundle.get_rspace_data()[ix*npy + ipy]; } } } } else { - this->ft.fftxybac(ft.get_auxr_data(),ft.get_auxr_data()); + this->fft_bundle.fftxybac(fft_bundle.get_auxr_data(),fft_bundle.get_auxr_data()); if(add) { #ifdef _OPENMP @@ -261,7 +262,7 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo #endif for(int ir = 0 ; ir < this->nrxx ; ++ir) { - out[ir] += factor * this->ft.get_auxr_data()[ir].real(); + out[ir] += factor * this->fft_bundle.get_auxr_data()[ir].real(); } } else @@ -271,7 +272,7 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo #endif for(int ir = 0 ; ir < this->nrxx ; ++ir) { - out[ir] = this->ft.get_auxr_data()[ir].real(); + out[ir] = this->fft_bundle.get_auxr_data()[ir].real(); } } } diff --git a/source/module_basis/module_pw/pw_transform_k.cpp b/source/module_basis/module_pw/pw_transform_k.cpp index 0ea362825b..88285df119 100644 --- a/source/module_basis/module_pw/pw_transform_k.cpp +++ b/source/module_basis/module_pw/pw_transform_k.cpp @@ -32,7 +32,7 @@ void PW_Basis_K::real2recip(const std::complex* in, ModuleBase::timer::tick(this->classname, "real2recip"); assert(this->gamma_only == false); - auto* auxr = this->ft.get_auxr_data(); + auto* auxr = this->fft_bundle.get_auxr_data(); #ifdef _OPENMP #pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif @@ -40,15 +40,15 @@ void PW_Basis_K::real2recip(const std::complex* in, { auxr[ir] = in[ir]; } - this->ft.fftxyfor(ft.get_auxr_data(), ft.get_auxr_data()); + this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data(), fft_bundle.get_auxr_data()); - this->gatherp_scatters(this->ft.get_auxr_data(), this->ft.get_auxg_data()); + this->gatherp_scatters(this->fft_bundle.get_auxr_data(), this->fft_bundle.get_auxg_data()); - this->ft.fftzfor(ft.get_auxg_data(), ft.get_auxg_data()); + this->fft_bundle.fftzfor(fft_bundle.get_auxg_data(), fft_bundle.get_auxg_data()); const int startig = ik * this->npwk_max; const int npwk = this->npwk[ik]; - auto* auxg = this->ft.get_auxg_data(); + auto* auxg = this->fft_bundle.get_auxg_data(); if (add) { FPTYPE tmpfac = factor / FPTYPE(this->nxyz); @@ -98,7 +98,7 @@ void PW_Basis_K::real2recip(const FPTYPE* in, assert(this->gamma_only == true); // for(int ir = 0 ; ir < this->nrxx ; ++ir) // { - // this->ft.get_rspace_data()[ir] = in[ir]; + // this->fft_bundle.get_rspace_data()[ir] = in[ir]; // } // r2c in place const int npy = this->ny * this->nplane; @@ -109,19 +109,19 @@ void PW_Basis_K::real2recip(const FPTYPE* in, { for (int ipy = 0; ipy < npy; ++ipy) { - this->ft.get_rspace_data()[ix * npy + ipy] = in[ix * npy + ipy]; + this->fft_bundle.get_rspace_data()[ix * npy + ipy] = in[ix * npy + ipy]; } } - this->ft.fftxyr2c(ft.get_rspace_data(), ft.get_auxr_data()); + this->fft_bundle.fftxyr2c(fft_bundle.get_rspace_data(), fft_bundle.get_auxr_data()); - this->gatherp_scatters(this->ft.get_auxr_data(), this->ft.get_auxg_data()); + this->gatherp_scatters(this->fft_bundle.get_auxr_data(), this->fft_bundle.get_auxg_data()); - this->ft.fftzfor(ft.get_auxg_data(), ft.get_auxg_data()); + this->fft_bundle.fftzfor(fft_bundle.get_auxg_data(), fft_bundle.get_auxg_data()); const int startig = ik * this->npwk_max; const int npwk = this->npwk[ik]; - auto* auxg = this->ft.get_auxg_data(); + auto* auxg = this->fft_bundle.get_auxg_data(); if (add) { FPTYPE tmpfac = factor / FPTYPE(this->nxyz); @@ -170,11 +170,11 @@ void PW_Basis_K::recip2real(const std::complex* in, { ModuleBase::timer::tick(this->classname, "recip2real"); assert(this->gamma_only == false); - ModuleBase::GlobalFunc::ZEROS(ft.get_auxg_data(), this->nst * this->nz); + ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxg_data(), this->nst * this->nz); const int startig = ik * this->npwk_max; const int npwk = this->npwk[ik]; - auto* auxg = this->ft.get_auxg_data(); + auto* auxg = this->fft_bundle.get_auxg_data(); #ifdef _OPENMP #pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif @@ -182,13 +182,13 @@ void PW_Basis_K::recip2real(const std::complex* in, { auxg[this->igl2isz_k[igl + startig]] = in[igl]; } - this->ft.fftzbac(ft.get_auxg_data(), ft.get_auxg_data()); + this->fft_bundle.fftzbac(fft_bundle.get_auxg_data(), fft_bundle.get_auxg_data()); - this->gathers_scatterp(this->ft.get_auxg_data(), this->ft.get_auxr_data()); + this->gathers_scatterp(this->fft_bundle.get_auxg_data(), this->fft_bundle.get_auxr_data()); - this->ft.fftxybac(ft.get_auxr_data(), ft.get_auxr_data()); + this->fft_bundle.fftxybac(fft_bundle.get_auxr_data(), fft_bundle.get_auxr_data()); - auto* auxr = this->ft.get_auxr_data(); + auto* auxr = this->fft_bundle.get_auxr_data(); if (add) { #ifdef _OPENMP @@ -234,11 +234,11 @@ void PW_Basis_K::recip2real(const std::complex* in, { ModuleBase::timer::tick(this->classname, "recip2real"); assert(this->gamma_only == true); - ModuleBase::GlobalFunc::ZEROS(ft.get_auxg_data(), this->nst * this->nz); + ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxg_data(), this->nst * this->nz); const int startig = ik * this->npwk_max; const int npwk = this->npwk[ik]; - auto* auxg = this->ft.get_auxg_data(); + auto* auxg = this->fft_bundle.get_auxg_data(); #ifdef _OPENMP #pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif @@ -246,20 +246,20 @@ void PW_Basis_K::recip2real(const std::complex* in, { auxg[this->igl2isz_k[igl + startig]] = in[igl]; } - this->ft.fftzbac(ft.get_auxg_data(), ft.get_auxg_data()); + this->fft_bundle.fftzbac(fft_bundle.get_auxg_data(), fft_bundle.get_auxg_data()); - this->gathers_scatterp(this->ft.get_auxg_data(), this->ft.get_auxr_data()); + this->gathers_scatterp(this->fft_bundle.get_auxg_data(), this->fft_bundle.get_auxr_data()); - this->ft.fftxyc2r(ft.get_auxr_data(), ft.get_rspace_data()); + this->fft_bundle.fftxyc2r(fft_bundle.get_auxr_data(), fft_bundle.get_rspace_data()); // for(int ir = 0 ; ir < this->nrxx ; ++ir) // { - // out[ir] = this->ft.get_rspace_data()[ir] / this->nxyz; + // out[ir] = this->fft_bundle.get_rspace_data()[ir] / this->nxyz; // } // r2c in place const int npy = this->ny * this->nplane; - auto* rspace = this->ft.get_rspace_data(); + auto* rspace = this->fft_bundle.get_rspace_data(); if (add) { #ifdef _OPENMP diff --git a/source/module_basis/module_pw/test/CMakeLists.txt b/source/module_basis/module_pw/test/CMakeLists.txt index e1ce122d07..8b62dacab0 100644 --- a/source/module_basis/module_pw/test/CMakeLists.txt +++ b/source/module_basis/module_pw/test/CMakeLists.txt @@ -4,7 +4,7 @@ AddTest( LIBS parameter ${math_libs} planewave device SOURCES ../../../module_base/matrix.cpp ../../../module_base/complexmatrix.cpp ../../../module_base/matrix3.cpp ../../../module_base/tool_quit.cpp ../../../module_base/mymath.cpp ../../../module_base/timer.cpp ../../../module_base/memory.cpp ../../../module_base/blas_connector.cpp - ../../../module_base/libm/branred.cpp ../../../module_base/libm/sincos.cpp + ../../../module_base/libm/branred.cpp ../../../module_base/libm/sincos.cpp # ../../../module_psi/kernels/psi_memory_op.cpp ../../../module_base/module_device/memory_op.cpp depend_mock.cpp pw_test.cpp test1-1-1.cpp test1-1-2.cpp test1-2.cpp test1-3.cpp test1-4.cpp test1-5.cpp diff --git a/source/module_basis/module_pw/test/Makefile b/source/module_basis/module_pw/test/Makefile index d970a637db..884f0f74c0 100644 --- a/source/module_basis/module_pw/test/Makefile +++ b/source/module_basis/module_pw/test/Makefile @@ -2,7 +2,7 @@ # Please set # e.g. make CXX=mpiicpc or make CXX=icpc #====================================================================== -CXX = mpiicpc +CXX = mpiicpx # mpiicpc: compile intel parallel version # icpc: compile intel sequential version # mpicxx: compile gnu parallel version @@ -25,7 +25,7 @@ GTEST_DIR = /home/qianrui/gnucompile/g_gtest # Compiler information #========================== HONG = -D__NORMAL -INCLUDES = -I. -I../../../ +INCLUDES = -I. -I../../../ -I../../../module_base/module_container LIBS = OPTS = -Ofast -march=native -std=c++11 -m64 ${INCLUDES} OBJ_DIR = obj @@ -103,7 +103,10 @@ GTESTOPTS = -I${GTEST_DIR}/include -L${GTEST_DIR}/lib -lgtest -lpthread #========================== VPATH=../../../module_base\ ../../../module_base/module_device\ -:../ +../../../module_base/module_container/ATen/core\ +../../../module_base/module_container/ATen\ +../../../module_parameter\ +../\ MATH_OBJS0=matrix.o\ matrix3.o\ @@ -123,7 +126,12 @@ pw_basis_sup.o\ pw_transform_k.o\ memory.o\ memory_op.o\ -depend_mock.o +depend_mock.o\ +parameter.o\ +fft_base.o\ +fft_bundle.o\ +fft_cpu.o\ + OTHER_OBJS0= diff --git a/source/module_basis/module_pw/test_serial/CMakeLists.txt b/source/module_basis/module_pw/test_serial/CMakeLists.txt index df9ae6a962..028d5b3a0e 100644 --- a/source/module_basis/module_pw/test_serial/CMakeLists.txt +++ b/source/module_basis/module_pw/test_serial/CMakeLists.txt @@ -10,6 +10,9 @@ add_library( planewave_serial OBJECT ../fft.cpp + ../module_fft/fft_base.cpp + ../module_fft/fft_bundle.cpp + ../module_fft/fft_cpu.cpp ../pw_basis.cpp ../pw_basis_k.cpp ../pw_basis_sup.cpp diff --git a/source/module_elecstate/module_charge/charge.cpp b/source/module_elecstate/module_charge/charge.cpp index 9003844ca1..6a2405b579 100644 --- a/source/module_elecstate/module_charge/charge.cpp +++ b/source/module_elecstate/module_charge/charge.cpp @@ -644,10 +644,10 @@ void Charge::atomic_rho(const int spin_number_need, double sumrea = 0.0; for (int ir = 0; ir < this->rhopw->nrxx; ir++) { - rea = this->rhopw->ft.get_auxr_data()[ir].real(); + rea = this->rhopw->fft_bundle.get_auxr_data()[ir].real(); sumrea += rea; neg += std::min(0.0, rea); - ima += std::abs(this->rhopw->ft.get_auxr_data()[ir].imag()); + ima += std::abs(this->rhopw->fft_bundle.get_auxr_data()[ir].imag()); } #ifdef __MPI diff --git a/source/module_elecstate/module_charge/charge_init.cpp b/source/module_elecstate/module_charge/charge_init.cpp index c0806da937..57af45e3be 100644 --- a/source/module_elecstate/module_charge/charge_init.cpp +++ b/source/module_elecstate/module_charge/charge_init.cpp @@ -260,8 +260,8 @@ void Charge::set_rho_core( double rhoneg = 0.0; for (int ir = 0; ir < this->rhopw->nrxx; ir++) { - rhoneg += std::min(0.0, this->rhopw->ft.get_auxr_data()[ir].real()); - rhoima += std::abs(this->rhopw->ft.get_auxr_data()[ir].imag()); + rhoneg += std::min(0.0, this->rhopw->fft_bundle.get_auxr_data()[ir].real()); + rhoima += std::abs(this->rhopw->fft_bundle.get_auxr_data()[ir].imag()); // NOTE: Core charge is computed in reciprocal space and brought to real // space by FFT. For non smooth core charges (or insufficient cut-off) // this may result in negative values in some grid points. diff --git a/source/module_elecstate/test/elecstate_base_test.cpp b/source/module_elecstate/test/elecstate_base_test.cpp index 4a0950fb30..ea69f172df 100644 --- a/source/module_elecstate/test/elecstate_base_test.cpp +++ b/source/module_elecstate/test/elecstate_base_test.cpp @@ -56,6 +56,7 @@ ModulePW::FFT::FFT() ModulePW::FFT::~FFT() { } + void ModulePW::PW_Basis::initgrids(double, ModuleBase::Matrix3, double) { } diff --git a/source/module_esolver/esolver_fp.cpp b/source/module_esolver/esolver_fp.cpp index 453b14cd63..4a19c1d917 100644 --- a/source/module_esolver/esolver_fp.cpp +++ b/source/module_esolver/esolver_fp.cpp @@ -84,6 +84,7 @@ void ESolver_FP::before_all_runners(const Input_para& inp, UnitCell& cell) this->pw_rho->initparameters(false, 4.0 * inp.ecutwfc); this->pw_rho->ft.fft_mode = inp.fft_mode; + this->pw_rho->fft_bundle.initfftmode(inp.fft_mode); this->pw_rho->setuptransform(); this->pw_rho->collect_local_pw(); this->pw_rho->collect_uniqgg(); @@ -109,6 +110,7 @@ void ESolver_FP::before_all_runners(const Input_para& inp, UnitCell& cell) } this->pw_rhod->initparameters(false, inp.ecutrho); this->pw_rhod->ft.fft_mode = inp.fft_mode; + this->pw_rhod->fft_bundle.initfftmode(inp.fft_mode); pw_rhod_sup->setuptransform(this->pw_rho); this->pw_rhod->collect_local_pw(); this->pw_rhod->collect_uniqgg(); diff --git a/source/module_esolver/esolver_ks.cpp b/source/module_esolver/esolver_ks.cpp index 8f9a715d5a..7cf729958a 100644 --- a/source/module_esolver/esolver_ks.cpp +++ b/source/module_esolver/esolver_ks.cpp @@ -247,7 +247,7 @@ void ESolver_KS::before_all_runners(const Input_para& inp, UnitCell& #endif this->pw_wfc->ft.fft_mode = inp.fft_mode; - + this->pw_wfc->fft_bundle.initfftmode(inp.fft_mode); this->pw_wfc->setuptransform(); //! 9) initialize the number of plane waves for each k point diff --git a/source/module_hamilt_general/module_xc/test/CMakeLists.txt b/source/module_hamilt_general/module_xc/test/CMakeLists.txt index 7466f40a92..66cf5f9cb0 100644 --- a/source/module_hamilt_general/module_xc/test/CMakeLists.txt +++ b/source/module_hamilt_general/module_xc/test/CMakeLists.txt @@ -38,6 +38,9 @@ AddTest( ../../../module_base/libm/branred.cpp ../../../module_base/libm/sincos.cpp ../../../module_base/blas_connector.cpp + ../../../module_basis/module_pw/module_fft/fft_base.cpp + ../../../module_basis/module_pw/module_fft/fft_bundle.cpp + ../../../module_basis/module_pw/module_fft/fft_cpu.cpp ) AddTest( @@ -73,4 +76,7 @@ AddTest( ../../../module_base/timer.cpp ../../../module_base/libm/branred.cpp ../../../module_base/libm/sincos.cpp + ../../../module_basis/module_pw/module_fft/fft_base.cpp + ../../../module_basis/module_pw/module_fft/fft_bundle.cpp + ../../../module_basis/module_pw/module_fft/fft_cpu.cpp ) \ No newline at end of file diff --git a/source/module_hsolver/test/hsolver_pw_sup.h b/source/module_hsolver/test/hsolver_pw_sup.h index 0fc0e72eaa..c70025a2c2 100644 --- a/source/module_hsolver/test/hsolver_pw_sup.h +++ b/source/module_hsolver/test/hsolver_pw_sup.h @@ -4,7 +4,6 @@ namespace ModulePW { PW_Basis::PW_Basis(){}; PW_Basis::~PW_Basis(){}; - void PW_Basis::initgrids( const double lat0_in, // unit length (unit in bohr) const ModuleBase::Matrix3