From c73bd57b91021e021cd97a8518985ac837a20ecc Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Tue, 5 Nov 2024 11:46:33 +0800 Subject: [PATCH 01/27] add the basic func of the file --- source/module_base/fft/fft_term.cpp | 280 ++++++++++++++++++ source/module_basis/module_pw/CMakeLists.txt | 2 + source/module_basis/module_pw/fft_base.cpp | 43 +++ source/module_basis/module_pw/fft_base.h | 80 +++++ source/module_basis/module_pw/fft_temp.cpp | 280 ++++++++++++++++++ source/module_basis/module_pw/fft_temp.h | 66 +++++ source/module_basis/module_pw/pw_basis.h | 1 + .../module_basis/module_pw/pw_transform.cpp | 1 + 8 files changed, 753 insertions(+) create mode 100644 source/module_base/fft/fft_term.cpp create mode 100644 source/module_basis/module_pw/fft_base.cpp create mode 100644 source/module_basis/module_pw/fft_base.h create mode 100644 source/module_basis/module_pw/fft_temp.cpp create mode 100644 source/module_basis/module_pw/fft_temp.h diff --git a/source/module_base/fft/fft_term.cpp b/source/module_base/fft/fft_term.cpp new file mode 100644 index 0000000000..bce534e2b6 --- /dev/null +++ b/source/module_base/fft/fft_term.cpp @@ -0,0 +1,280 @@ +#include +#include "fft_temp.h" +// #include "fft_cpu.h" +#if defined(__CUDA) +#include "fft_cuda.h" +#endif +#if defined(__ROCM) +#include "fft_rcom.h" +#endif +#include "module_base/module_device/device.h" +// #include "fft_gpu.h" +FFT1::FFT1() +{ + fft_float = nullptr; + fft_double = nullptr; +} +FFT1::FFT1(std::string device_in,std::string precision_in) +{ + assert(device_in=="cpu" || device_in=="gpu"); + assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); + this->device = device_in; + this->precision = precision_in; + if (device=="cpu") + { + fft_float = new FFT_CPU(); + fft_double = new FFT_CPU(); + } + else if (device=="gpu") + { + #if defined(__ROCM) + fft_float = new FFT_RCOM(); + fft_double = new FFT_RCOM(); + #elif defined(__CUDA) + fft_float = new FFT_CUDA(); + fft_double = new FFT_CUDA(); + #endif + } +} + +FFT1::~FFT1() +{ + if (fft_float!=nullptr) + { + delete fft_float; + fft_float=nullptr; + } + if (fft_double!=nullptr) + { + delete fft_double; + fft_double=nullptr; + } +} + +void FFT1::set_device(std::string device_in) +{ + this->device = device_in; +} + +void FFT1::set_precision(std::string precision_in) +{ + this->precision = precision_in; +} +void FFT1::setfft(std::string device_in,std::string precision_in) +{ + assert(device_in=="cpu" || device_in=="gpu"); + assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); + this->device = device_in; + this->precision = precision_in; + if (device=="cpu") + { + fft_float = new FFT_CPU(); + fft_double = new FFT_CPU(); + } + else if (device=="gpu") + { + #if defined(__ROCM) + fft_float = new FFT_RCOM(); + fft_double = new FFT_RCOM(); + #elif defined(__CUDA) + fft_float = new FFT_CUDA(); + fft_double = new FFT_CUDA(); + #endif + } +} +void FFT1::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, + int nproc_in, bool gamma_only_in, bool xprime_in , bool mpifft_in) +{ + if (this->precision=="single") + { + float_flag = 1; + } + else if (this->precision=="double") + { + double_flag = 1; + } + else if (this->precision=="mixing") + { + float_flag = 1; + double_flag = 1; + } + if (float_flag) + { + fft_float->initfftmode(this->fft_mode); + fft_float->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); + } + if (double_flag) + { + fft_double->initfftmode(this->fft_mode); + fft_double->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); + } +} +void FFT1::initfftmode(int fft_mode_in) +{ + this->fft_mode = fft_mode_in; +} + +void FFT1::setupFFT() +{ + if (double_flag) + { + fft_double->setupFFT(); + } + if (float_flag) + { + fft_float->setupFFT(); + } +} + +void FFT1::clearFFT() +{ + if (double_flag) + { + fft_double->cleanFFT(); + } + if (float_flag) + { + fft_float->cleanFFT(); + } +} +void FFT1::clear() +{ + this->clearFFT(); + if (float_flag) + { + fft_float->clear(); + } + if (double_flag) + { + fft_double->clear(); + } +} +// access the real space data +template <> +float* FFT1::get_rspace_data() const +{ + return fft_float->get_rspace_data(); +} + +template <> +double* FFT1::get_rspace_data() const +{ + return fft_double->get_rspace_data(); +} +template <> +std::complex* FFT1::get_auxr_data() const +{ + return fft_float->get_auxr_data(); +} +template <> +std::complex* FFT1::get_auxr_data() const +{ + return fft_double->get_auxr_data(); +} +template <> +std::complex* FFT1::get_auxg_data() const +{ + return fft_float->get_auxg_data(); +} +template <> +std::complex* FFT1::get_auxg_data() const +{ + return fft_double->get_auxg_data(); +} +template <> +std::complex* FFT1::get_auxr_3d_data() const +{ + return fft_float->get_auxr_3d_data(); +} +template <> +std::complex* FFT1::get_auxr_3d_data() const +{ + return fft_double->get_auxr_3d_data(); +} +template <> +void FFT1::fftxyfor(std::complex* in, std::complex* out) const +{ + fft_float->fftxyfor(in,out); +} + +template <> +void FFT1::fftxyfor(std::complex* in, std::complex* out) const +{ + fft_double->fftxyfor(in,out); +} + +template <> +void FFT1::fftzfor(std::complex* in, std::complex* out) const +{ + fft_float->fftzfor(in,out); +} +template <> +void FFT1::fftzfor(std::complex* in, std::complex* out) const +{ + fft_double->fftzfor(in,out); +} + +template <> +void FFT1::fftxybac(std::complex* in, std::complex* out) const +{ + fft_float->fftxybac(in,out); +} +template <> +void FFT1::fftxybac(std::complex* in, std::complex* out) const +{ + fft_double->fftxybac(in,out); +} + +template <> +void FFT1::fftzbac(std::complex* in, std::complex* out) const +{ + fft_float->fftzbac(in,out); +} +template <> +void FFT1::fftzbac(std::complex* in, std::complex* out) const +{ + fft_double->fftzbac(in,out); +} +template <> +void FFT1::fftxyr2c(float* in, std::complex* out) const +{ + fft_float->fftxyr2c(in,out); +} +template <> +void FFT1::fftxyr2c(double* in, std::complex* out) const +{ + fft_double->fftxyr2c(in,out); +} + +template <> +void FFT1::fftxyc2r(std::complex* in, float* out) const +{ + fft_float->fftxyc2r(in,out); +} +template <> +void FFT1::fftxyc2r(std::complex* in, double* out) const +{ + fft_double->fftxyc2r(in,out); +} + +template <> +void FFT1::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +{ + fft_float->fft3D_forward(in, out); +} + +template <> +void FFT1::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +{ + fft_double->fft3D_forward(in, out); +} +template <> +void FFT1::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +{ + fft_float->fft3D_backward(in, out); +} +template <> +void FFT1::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +{ + fft_double->fft3D_backward(in, out); +} \ No newline at end of file diff --git a/source/module_basis/module_pw/CMakeLists.txt b/source/module_basis/module_pw/CMakeLists.txt index 2b2d897206..abfde45a34 100644 --- a/source/module_basis/module_pw/CMakeLists.txt +++ b/source/module_basis/module_pw/CMakeLists.txt @@ -1,5 +1,7 @@ list(APPEND objects fft.cpp + fft_base.cpp + fft_temp.cpp pw_basis.cpp pw_basis_k.cpp pw_basis_sup.cpp diff --git a/source/module_basis/module_pw/fft_base.cpp b/source/module_basis/module_pw/fft_base.cpp new file mode 100644 index 0000000000..31fb9881e1 --- /dev/null +++ b/source/module_basis/module_pw/fft_base.cpp @@ -0,0 +1,43 @@ +#include "fft_base.h" +template +FFT_BASE::FFT_BASE() +{ +} +template +FFT_BASE::~FFT_BASE() +{ + +} +template +void FFT_BASE::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, + int nproc_in, bool gamma_only_in, bool xprime_in, bool mpifft_in) +{ + this->gamma_only = gamma_only_in; + this->xprime = xprime_in; + this->fftnx = this->nx = nx_in; + this->fftny = this->ny = ny_in; + if (this->gamma_only) + { + if (xprime) + this->fftnx = int(nx / 2) + 1; + else + this->fftny = int(ny / 2) + 1; + } + this->nz = nz_in; + this->ns = ns_in; + this->lixy = lixy_in; + this->rixy = rixy_in; + this->nplane = nplane_in; + this->nproc = nproc_in; + this->mpifft = mpifft_in; + this->nxy = this->nx * this->ny; + this->fftnxy = this->fftnx * this->fftny; + const int nrxx = this->nxy * this->nplane; + const int nsz = this->nz * this->ns; + this->maxgrids = (nsz > nrxx) ? nsz : nrxx; +} + +template FFT_BASE::FFT_BASE(); +template FFT_BASE::FFT_BASE(); +template FFT_BASE::~FFT_BASE(); +template FFT_BASE::~FFT_BASE(); \ No newline at end of file diff --git a/source/module_basis/module_pw/fft_base.h b/source/module_basis/module_pw/fft_base.h new file mode 100644 index 0000000000..6dd3d947f2 --- /dev/null +++ b/source/module_basis/module_pw/fft_base.h @@ -0,0 +1,80 @@ +#include +#include +#include "fftw3.h" +#ifndef FFT_BASE_H +#define FFT_BASE_H +template +class FFT_BASE +{ +public: + + __attribute__((weak)) FFT_BASE(); + virtual __attribute__((weak)) ~FFT_BASE(); + + // init parameters of fft + virtual void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, + int nproc_in, bool gamma_only_in, bool xprime_in = true, bool mpifft_in = false); + + virtual __attribute__((weak)) void initfftmode(int fft_mode_in); + + //init fftw_plans + virtual void setupFFT()=0; + + //destroy fftw_plans + virtual void cleanFFT()=0; + //clear fftw_data + virtual void clear()=0; + + // access the real space data + virtual __attribute__((weak)) FPTYPE* get_rspace_data() const; + + virtual __attribute__((weak)) std::complex* get_auxr_data() const; + + virtual __attribute__((weak)) std::complex* get_auxg_data() const; + + virtual __attribute__((weak)) std::complex* get_auxr_3d_data() const; + + //forward fft in x-y direction + virtual __attribute__((weak)) void fftxyfor(std::complex* in, std::complex* out) const; + + virtual __attribute__((weak)) void fftxybac(std::complex* in, std::complex* out) const; + + virtual __attribute__((weak)) void fftzfor(std::complex* in, std::complex* out) const; + + virtual __attribute__((weak)) void fftzbac(std::complex* in, std::complex* out) const; + + virtual __attribute__((weak)) void fftxyr2c(FPTYPE* in, std::complex* out) const; + + virtual __attribute__((weak)) void fftxyc2r(std::complex* in, FPTYPE* out) const; + + virtual __attribute__((weak)) void fft3D_forward(std::complex* in, std::complex* out) const; + + virtual __attribute__((weak)) void fft3D_backward(std::complex* in, std::complex* out) const; + +protected: + int initflag = 0; // 0: not initialized; 1: initialized + int fftnx=0; + int fftny=0; + int fftnxy=0; + int ny=0; + int nx=0; + int nz=0; + int nxy=0; + int nplane=0; //number of x-y planes + bool gamma_only = false; + int lixy=0; + int rixy=0;// lixy: the left edge of the pw ball in the y direction; rixy: the right edge of the pw ball in the x or y direction + bool mpifft = false; // if use mpi fft, only used when define __FFTW3_MPI + int maxgrids = 0; // maxgrids = (nsz > nrxx) ? nsz : nrxx; + bool xprime = true; // true: when do recip2real, x-fft will be done last and when doing real2recip, x-fft will be done first; false: y-fft + // For gamma_only, true: we use half x; false: we use half y + int ns=0; //number of sticks + int nproc=1; // number of proc. + int fft_mode = 0; ///< fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive + +public: + void set_device(std::string device_); + void set_precision(std::string precision_); + +}; +#endif // FFT_BASE_H diff --git a/source/module_basis/module_pw/fft_temp.cpp b/source/module_basis/module_pw/fft_temp.cpp new file mode 100644 index 0000000000..5c4178bb58 --- /dev/null +++ b/source/module_basis/module_pw/fft_temp.cpp @@ -0,0 +1,280 @@ +#include +#include "fft_temp.h" +// #include "fft_cpu.h" +// #if defined(__CUDA) +// #include "fft_cuda.h" +// #endif +// #if defined(__ROCM) +// #include "fft_rcom.h" +// #endif +// #include "module_base/module_device/device.h" +// #include "fft_gpu.h" +FFT1::FFT1() +{ + fft_float = nullptr; + fft_double = nullptr; +} +FFT1::FFT1(std::string device_in,std::string precision_in) +{ + assert(device_in=="cpu" || device_in=="gpu"); + assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); + this->device = device_in; + this->precision = precision_in; + // if (device=="cpu") + // { + // fft_float = new FFT_CPU(); + // fft_double = new FFT_CPU(); + // } + // else if (device=="gpu") + // { + // #if defined(__ROCM) + // fft_float = new FFT_RCOM(); + // fft_double = new FFT_RCOM(); + // #elif defined(__CUDA) + // fft_float = new FFT_CUDA(); + // fft_double = new FFT_CUDA(); + // #endif + // } +} + +FFT1::~FFT1() +{ + if (fft_float!=nullptr) + { + delete fft_float; + fft_float=nullptr; + } + if (fft_double!=nullptr) + { + delete fft_double; + fft_double=nullptr; + } +} + +// void FFT1::set_device(std::string device_in) +// { +// this->device = device_in; +// } + +// void FFT1::set_precision(std::string precision_in) +// { +// this->precision = precision_in; +// } +// void FFT1::setfft(std::string device_in,std::string precision_in) +// { +// assert(device_in=="cpu" || device_in=="gpu"); +// assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); +// this->device = device_in; +// this->precision = precision_in; +// if (device=="cpu") +// { +// fft_float = new FFT_CPU(); +// fft_double = new FFT_CPU(); +// } +// else if (device=="gpu") +// { +// #if defined(__ROCM) +// fft_float = new FFT_RCOM(); +// fft_double = new FFT_RCOM(); +// #elif defined(__CUDA) +// fft_float = new FFT_CUDA(); +// fft_double = new FFT_CUDA(); +// #endif +// } +// } +// void FFT1::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, +// int nproc_in, bool gamma_only_in, bool xprime_in , bool mpifft_in) +// { +// if (this->precision=="single") +// { +// float_flag = 1; +// } +// else if (this->precision=="double") +// { +// double_flag = 1; +// } +// else if (this->precision=="mixing") +// { +// float_flag = 1; +// double_flag = 1; +// } +// if (float_flag) +// { +// fft_float->initfftmode(this->fft_mode); +// fft_float->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); +// } +// if (double_flag) +// { +// fft_double->initfftmode(this->fft_mode); +// fft_double->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); +// } +// } +// void FFT1::initfftmode(int fft_mode_in) +// { +// this->fft_mode = fft_mode_in; +// } + +// void FFT1::setupFFT() +// { +// if (double_flag) +// { +// fft_double->setupFFT(); +// } +// if (float_flag) +// { +// fft_float->setupFFT(); +// } +// } + +// void FFT1::clearFFT() +// { +// if (double_flag) +// { +// fft_double->cleanFFT(); +// } +// if (float_flag) +// { +// fft_float->cleanFFT(); +// } +// } +// void FFT1::clear() +// { +// this->clearFFT(); +// if (float_flag) +// { +// fft_float->clear(); +// } +// if (double_flag) +// { +// fft_double->clear(); +// } +// } +// // access the real space data +// template <> +// float* FFT1::get_rspace_data() const +// { +// return fft_float->get_rspace_data(); +// } + +// template <> +// double* FFT1::get_rspace_data() const +// { +// return fft_double->get_rspace_data(); +// } +// template <> +// std::complex* FFT1::get_auxr_data() const +// { +// return fft_float->get_auxr_data(); +// } +// template <> +// std::complex* FFT1::get_auxr_data() const +// { +// return fft_double->get_auxr_data(); +// } +// template <> +// std::complex* FFT1::get_auxg_data() const +// { +// return fft_float->get_auxg_data(); +// } +// template <> +// std::complex* FFT1::get_auxg_data() const +// { +// return fft_double->get_auxg_data(); +// } +// template <> +// std::complex* FFT1::get_auxr_3d_data() const +// { +// return fft_float->get_auxr_3d_data(); +// } +// template <> +// std::complex* FFT1::get_auxr_3d_data() const +// { +// return fft_double->get_auxr_3d_data(); +// } +// template <> +// void FFT1::fftxyfor(std::complex* in, std::complex* out) const +// { +// fft_float->fftxyfor(in,out); +// } + +// template <> +// void FFT1::fftxyfor(std::complex* in, std::complex* out) const +// { +// fft_double->fftxyfor(in,out); +// } + +// template <> +// void FFT1::fftzfor(std::complex* in, std::complex* out) const +// { +// fft_float->fftzfor(in,out); +// } +// template <> +// void FFT1::fftzfor(std::complex* in, std::complex* out) const +// { +// fft_double->fftzfor(in,out); +// } + +// template <> +// void FFT1::fftxybac(std::complex* in, std::complex* out) const +// { +// fft_float->fftxybac(in,out); +// } +// template <> +// void FFT1::fftxybac(std::complex* in, std::complex* out) const +// { +// fft_double->fftxybac(in,out); +// } + +// template <> +// void FFT1::fftzbac(std::complex* in, std::complex* out) const +// { +// fft_float->fftzbac(in,out); +// } +// template <> +// void FFT1::fftzbac(std::complex* in, std::complex* out) const +// { +// fft_double->fftzbac(in,out); +// } +// template <> +// void FFT1::fftxyr2c(float* in, std::complex* out) const +// { +// fft_float->fftxyr2c(in,out); +// } +// template <> +// void FFT1::fftxyr2c(double* in, std::complex* out) const +// { +// fft_double->fftxyr2c(in,out); +// } + +// template <> +// void FFT1::fftxyc2r(std::complex* in, float* out) const +// { +// fft_float->fftxyc2r(in,out); +// } +// template <> +// void FFT1::fftxyc2r(std::complex* in, double* out) const +// { +// fft_double->fftxyc2r(in,out); +// } + +// template <> +// void FFT1::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +// { +// fft_float->fft3D_forward(in, out); +// } + +// template <> +// void FFT1::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +// { +// fft_double->fft3D_forward(in, out); +// } +// template <> +// void FFT1::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +// { +// fft_float->fft3D_backward(in, out); +// } +// template <> +// void FFT1::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +// { +// fft_double->fft3D_backward(in, out); +// } \ No newline at end of file diff --git a/source/module_basis/module_pw/fft_temp.h b/source/module_basis/module_pw/fft_temp.h new file mode 100644 index 0000000000..be0ef42c62 --- /dev/null +++ b/source/module_basis/module_pw/fft_temp.h @@ -0,0 +1,66 @@ +#include "fft_base.h" +// #include "module_psi/psi.h" +#ifndef FFT1_H +#define FFT1_H +class FFT1 +{ + public: + FFT1(); + FFT1(std::string device_in,std::string precision_in); + ~FFT1(); + + void setfft(std::string device_in,std::string precision_in); + void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, + int nproc_in, bool gamma_only_in, bool xprime_in = true, bool mpifft_in = false); + + void initfftmode(int fft_mode_in); + + void setupFFT(); + + void clearFFT(); + + void clear(); + + template + FPTYPE* get_rspace_data() const; + template + std::complex* get_auxr_data() const; + template + std::complex* get_auxg_data() const; + template + std::complex* get_auxr_3d_data() const; + + template + void fftzfor(std::complex* in, std::complex* out) const; + template + void fftxyfor(std::complex* in, std::complex* out) const; + template + void fftzbac(std::complex* in, std::complex* out) const; + template + void fftxybac(std::complex* in, std::complex* out) const; + template + void fftxyr2c(FPTYPE* in, std::complex* out) const; + template + void fftxyc2r(std::complex* in, FPTYPE* out) const; + + template + void fft3D_forward(const Device* ctx, std::complex* in, std::complex* out) const; + template + void fft3D_backward(const Device* ctx, std::complex* in, std::complex* out) const; + + void set_device(std::string device_in); + + void set_precision(std::string precision_in); + + private: + int fft_mode = 0; ///< fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive + bool float_flag=0; + bool double_flag=0; + FFT_BASE* fft_float=nullptr; + FFT_BASE* fft_double=nullptr; + + std::string device = "cpu"; + std::string precision = "double"; +}; + +#endif // FFT_H \ No newline at end of file diff --git a/source/module_basis/module_pw/pw_basis.h b/source/module_basis/module_pw/pw_basis.h index 6f95343b1a..1e80cc415b 100644 --- a/source/module_basis/module_pw/pw_basis.h +++ b/source/module_basis/module_pw/pw_basis.h @@ -6,6 +6,7 @@ #include "module_base/vector3.h" #include #include "fft.h" +#include "fft_temp.h" #include #ifdef __MPI #include "mpi.h" diff --git a/source/module_basis/module_pw/pw_transform.cpp b/source/module_basis/module_pw/pw_transform.cpp index ca523ace85..6885f38731 100644 --- a/source/module_basis/module_pw/pw_transform.cpp +++ b/source/module_basis/module_pw/pw_transform.cpp @@ -1,4 +1,5 @@ #include "fft.h" +#include "fft_temp.h" #include #include "pw_basis.h" #include From e8381611383e654cc28ef26ec8e1e69a272a7d55 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Tue, 5 Nov 2024 11:48:35 +0800 Subject: [PATCH 02/27] modify the Makefile --- source/Makefile.Objects | 2 ++ source/module_basis/module_pw/pw_basis.h | 1 + 2 files changed, 3 insertions(+) diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 8e6fb95677..4d1f7765c4 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -411,6 +411,8 @@ OBJS_PSI_INITIALIZER=psi_initializer.o\ psi_initializer_nao_random.o\ OBJS_PW=fft.o\ + fft_base.o\ + fft_temp.o\ pw_basis.o\ pw_basis_k.o\ pw_basis_sup.o\ diff --git a/source/module_basis/module_pw/pw_basis.h b/source/module_basis/module_pw/pw_basis.h index 1e80cc415b..cd50d7ef86 100644 --- a/source/module_basis/module_pw/pw_basis.h +++ b/source/module_basis/module_pw/pw_basis.h @@ -243,6 +243,7 @@ class PW_Basis int nmaxgr=0; // Gamma_only: max between npw and (nrxx+1)/2, others: max between npw and nrxx // Thus complex[nmaxgr] is able to contain either reciprocal or real data FFT ft; + FFT1 ft1; //The position of pointer in and out can be equal(in-place transform) or different(out-of-place transform). template From 75e42badadb86699e17543ac1860e9bcbe5eb8e1 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Tue, 5 Nov 2024 14:32:04 +0800 Subject: [PATCH 03/27] delete file --- source/module_base/fft/fft_term.cpp | 280 --------------------- source/module_base/module_fft/fft_temp.cpp | 280 +++++++++++++++++++++ source/module_basis/module_pw/fft_temp.cpp | 74 +++--- source/module_basis/module_pw/fft_temp.h | 12 +- source/module_basis/module_pw/pw_basis.h | 2 +- 5 files changed, 324 insertions(+), 324 deletions(-) delete mode 100644 source/module_base/fft/fft_term.cpp create mode 100644 source/module_base/module_fft/fft_temp.cpp diff --git a/source/module_base/fft/fft_term.cpp b/source/module_base/fft/fft_term.cpp deleted file mode 100644 index bce534e2b6..0000000000 --- a/source/module_base/fft/fft_term.cpp +++ /dev/null @@ -1,280 +0,0 @@ -#include -#include "fft_temp.h" -// #include "fft_cpu.h" -#if defined(__CUDA) -#include "fft_cuda.h" -#endif -#if defined(__ROCM) -#include "fft_rcom.h" -#endif -#include "module_base/module_device/device.h" -// #include "fft_gpu.h" -FFT1::FFT1() -{ - fft_float = nullptr; - fft_double = nullptr; -} -FFT1::FFT1(std::string device_in,std::string precision_in) -{ - assert(device_in=="cpu" || device_in=="gpu"); - assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); - this->device = device_in; - this->precision = precision_in; - if (device=="cpu") - { - fft_float = new FFT_CPU(); - fft_double = new FFT_CPU(); - } - else if (device=="gpu") - { - #if defined(__ROCM) - fft_float = new FFT_RCOM(); - fft_double = new FFT_RCOM(); - #elif defined(__CUDA) - fft_float = new FFT_CUDA(); - fft_double = new FFT_CUDA(); - #endif - } -} - -FFT1::~FFT1() -{ - if (fft_float!=nullptr) - { - delete fft_float; - fft_float=nullptr; - } - if (fft_double!=nullptr) - { - delete fft_double; - fft_double=nullptr; - } -} - -void FFT1::set_device(std::string device_in) -{ - this->device = device_in; -} - -void FFT1::set_precision(std::string precision_in) -{ - this->precision = precision_in; -} -void FFT1::setfft(std::string device_in,std::string precision_in) -{ - assert(device_in=="cpu" || device_in=="gpu"); - assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); - this->device = device_in; - this->precision = precision_in; - if (device=="cpu") - { - fft_float = new FFT_CPU(); - fft_double = new FFT_CPU(); - } - else if (device=="gpu") - { - #if defined(__ROCM) - fft_float = new FFT_RCOM(); - fft_double = new FFT_RCOM(); - #elif defined(__CUDA) - fft_float = new FFT_CUDA(); - fft_double = new FFT_CUDA(); - #endif - } -} -void FFT1::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, - int nproc_in, bool gamma_only_in, bool xprime_in , bool mpifft_in) -{ - if (this->precision=="single") - { - float_flag = 1; - } - else if (this->precision=="double") - { - double_flag = 1; - } - else if (this->precision=="mixing") - { - float_flag = 1; - double_flag = 1; - } - if (float_flag) - { - fft_float->initfftmode(this->fft_mode); - fft_float->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); - } - if (double_flag) - { - fft_double->initfftmode(this->fft_mode); - fft_double->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); - } -} -void FFT1::initfftmode(int fft_mode_in) -{ - this->fft_mode = fft_mode_in; -} - -void FFT1::setupFFT() -{ - if (double_flag) - { - fft_double->setupFFT(); - } - if (float_flag) - { - fft_float->setupFFT(); - } -} - -void FFT1::clearFFT() -{ - if (double_flag) - { - fft_double->cleanFFT(); - } - if (float_flag) - { - fft_float->cleanFFT(); - } -} -void FFT1::clear() -{ - this->clearFFT(); - if (float_flag) - { - fft_float->clear(); - } - if (double_flag) - { - fft_double->clear(); - } -} -// access the real space data -template <> -float* FFT1::get_rspace_data() const -{ - return fft_float->get_rspace_data(); -} - -template <> -double* FFT1::get_rspace_data() const -{ - return fft_double->get_rspace_data(); -} -template <> -std::complex* FFT1::get_auxr_data() const -{ - return fft_float->get_auxr_data(); -} -template <> -std::complex* FFT1::get_auxr_data() const -{ - return fft_double->get_auxr_data(); -} -template <> -std::complex* FFT1::get_auxg_data() const -{ - return fft_float->get_auxg_data(); -} -template <> -std::complex* FFT1::get_auxg_data() const -{ - return fft_double->get_auxg_data(); -} -template <> -std::complex* FFT1::get_auxr_3d_data() const -{ - return fft_float->get_auxr_3d_data(); -} -template <> -std::complex* FFT1::get_auxr_3d_data() const -{ - return fft_double->get_auxr_3d_data(); -} -template <> -void FFT1::fftxyfor(std::complex* in, std::complex* out) const -{ - fft_float->fftxyfor(in,out); -} - -template <> -void FFT1::fftxyfor(std::complex* in, std::complex* out) const -{ - fft_double->fftxyfor(in,out); -} - -template <> -void FFT1::fftzfor(std::complex* in, std::complex* out) const -{ - fft_float->fftzfor(in,out); -} -template <> -void FFT1::fftzfor(std::complex* in, std::complex* out) const -{ - fft_double->fftzfor(in,out); -} - -template <> -void FFT1::fftxybac(std::complex* in, std::complex* out) const -{ - fft_float->fftxybac(in,out); -} -template <> -void FFT1::fftxybac(std::complex* in, std::complex* out) const -{ - fft_double->fftxybac(in,out); -} - -template <> -void FFT1::fftzbac(std::complex* in, std::complex* out) const -{ - fft_float->fftzbac(in,out); -} -template <> -void FFT1::fftzbac(std::complex* in, std::complex* out) const -{ - fft_double->fftzbac(in,out); -} -template <> -void FFT1::fftxyr2c(float* in, std::complex* out) const -{ - fft_float->fftxyr2c(in,out); -} -template <> -void FFT1::fftxyr2c(double* in, std::complex* out) const -{ - fft_double->fftxyr2c(in,out); -} - -template <> -void FFT1::fftxyc2r(std::complex* in, float* out) const -{ - fft_float->fftxyc2r(in,out); -} -template <> -void FFT1::fftxyc2r(std::complex* in, double* out) const -{ - fft_double->fftxyc2r(in,out); -} - -template <> -void FFT1::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const -{ - fft_float->fft3D_forward(in, out); -} - -template <> -void FFT1::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const -{ - fft_double->fft3D_forward(in, out); -} -template <> -void FFT1::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const -{ - fft_float->fft3D_backward(in, out); -} -template <> -void FFT1::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const -{ - fft_double->fft3D_backward(in, out); -} \ No newline at end of file diff --git a/source/module_base/module_fft/fft_temp.cpp b/source/module_base/module_fft/fft_temp.cpp new file mode 100644 index 0000000000..c7553774d1 --- /dev/null +++ b/source/module_base/module_fft/fft_temp.cpp @@ -0,0 +1,280 @@ +#include +#include "fft_temp.h" +// #include "fft_cpu.h" +// #if defined(__CUDA) +// #include "fft_cuda.h" +// #endif +// #if defined(__ROCM) +// #include "fft_rcom.h" +// #endif +// #include "module_base/module_device/device.h" +// #include "fft_gpu.h" +FFT_TEMP::FFT_TEMP() +{ + fft_float = nullptr; + fft_double = nullptr; +} +FFT_TEMP::FFT_TEMP(std::string device_in,std::string precision_in) +{ + assert(device_in=="cpu" || device_in=="gpu"); + assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); + this->device = device_in; + this->precision = precision_in; + // if (device=="cpu") + // { + // fft_float = new FFT_CPU(); + // fft_double = new FFT_CPU(); + // } + // else if (device=="gpu") + // { + // #if defined(__ROCM) + // fft_float = new FFT_RCOM(); + // fft_double = new FFT_RCOM(); + // #elif defined(__CUDA) + // fft_float = new FFT_CUDA(); + // fft_double = new FFT_CUDA(); + // #endif + // } +} + +FFT_TEMP::~FFT_TEMP() +{ + if (float_flag) + { + delete fft_float; + fft_float=nullptr; + } + if (double_flag) + { + delete fft_double; + fft_double=nullptr; + } +} + +// void FFT_TEMP::set_device(std::string device_in) +// { +// this->device = device_in; +// } + +// void FFT_TEMP::set_precision(std::string precision_in) +// { +// this->precision = precision_in; +// } +// void FFT_TEMP::setfft(std::string device_in,std::string precision_in) +// { +// assert(device_in=="cpu" || device_in=="gpu"); +// assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); +// this->device = device_in; +// this->precision = precision_in; +// if (device=="cpu") +// { +// fft_float = new FFT_CPU(); +// fft_double = new FFT_CPU(); +// } +// else if (device=="gpu") +// { +// #if defined(__ROCM) +// fft_float = new FFT_RCOM(); +// fft_double = new FFT_RCOM(); +// #elif defined(__CUDA) +// fft_float = new FFT_CUDA(); +// fft_double = new FFT_CUDA(); +// #endif +// } +// } +// void FFT_TEMP::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, +// int nproc_in, bool gamma_only_in, bool xprime_in , bool mpifft_in) +// { +// if (this->precision=="single") +// { +// float_flag = 1; +// } +// else if (this->precision=="double") +// { +// double_flag = 1; +// } +// else if (this->precision=="mixing") +// { +// float_flag = 1; +// double_flag = 1; +// } +// if (float_flag) +// { +// fft_float->initfftmode(this->fft_mode); +// fft_float->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); +// } +// if (double_flag) +// { +// fft_double->initfftmode(this->fft_mode); +// fft_double->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); +// } +// } +// void FFT_TEMP::initfftmode(int fft_mode_in) +// { +// this->fft_mode = fft_mode_in; +// } + +// void FFT_TEMP::setupFFT() +// { +// if (double_flag) +// { +// fft_double->setupFFT(); +// } +// if (float_flag) +// { +// fft_float->setupFFT(); +// } +// } + +// void FFT_TEMP::clearFFT() +// { +// if (double_flag) +// { +// fft_double->cleanFFT(); +// } +// if (float_flag) +// { +// fft_float->cleanFFT(); +// } +// } +// void FFT_TEMP::clear() +// { +// this->clearFFT(); +// if (float_flag) +// { +// fft_float->clear(); +// } +// if (double_flag) +// { +// fft_double->clear(); +// } +// } +// // access the real space data +// template <> +// float* FFT_TEMP::get_rspace_data() const +// { +// return fft_float->get_rspace_data(); +// } + +// template <> +// double* FFT_TEMP::get_rspace_data() const +// { +// return fft_double->get_rspace_data(); +// } +// template <> +// std::complex* FFT_TEMP::get_auxr_data() const +// { +// return fft_float->get_auxr_data(); +// } +// template <> +// std::complex* FFT_TEMP::get_auxr_data() const +// { +// return fft_double->get_auxr_data(); +// } +// template <> +// std::complex* FFT_TEMP::get_auxg_data() const +// { +// return fft_float->get_auxg_data(); +// } +// template <> +// std::complex* FFT_TEMP::get_auxg_data() const +// { +// return fft_double->get_auxg_data(); +// } +// template <> +// std::complex* FFT_TEMP::get_auxr_3d_data() const +// { +// return fft_float->get_auxr_3d_data(); +// } +// template <> +// std::complex* FFT_TEMP::get_auxr_3d_data() const +// { +// return fft_double->get_auxr_3d_data(); +// } +// template <> +// void FFT_TEMP::fftxyfor(std::complex* in, std::complex* out) const +// { +// fft_float->fftxyfor(in,out); +// } + +// template <> +// void FFT_TEMP::fftxyfor(std::complex* in, std::complex* out) const +// { +// fft_double->fftxyfor(in,out); +// } + +// template <> +// void FFT_TEMP::fftzfor(std::complex* in, std::complex* out) const +// { +// fft_float->fftzfor(in,out); +// } +// template <> +// void FFT_TEMP::fftzfor(std::complex* in, std::complex* out) const +// { +// fft_double->fftzfor(in,out); +// } + +// template <> +// void FFT_TEMP::fftxybac(std::complex* in, std::complex* out) const +// { +// fft_float->fftxybac(in,out); +// } +// template <> +// void FFT_TEMP::fftxybac(std::complex* in, std::complex* out) const +// { +// fft_double->fftxybac(in,out); +// } + +// template <> +// void FFT_TEMP::fftzbac(std::complex* in, std::complex* out) const +// { +// fft_float->fftzbac(in,out); +// } +// template <> +// void FFT_TEMP::fftzbac(std::complex* in, std::complex* out) const +// { +// fft_double->fftzbac(in,out); +// } +// template <> +// void FFT_TEMP::fftxyr2c(float* in, std::complex* out) const +// { +// fft_float->fftxyr2c(in,out); +// } +// template <> +// void FFT_TEMP::fftxyr2c(double* in, std::complex* out) const +// { +// fft_double->fftxyr2c(in,out); +// } + +// template <> +// void FFT_TEMP::fftxyc2r(std::complex* in, float* out) const +// { +// fft_float->fftxyc2r(in,out); +// } +// template <> +// void FFT_TEMP::fftxyc2r(std::complex* in, double* out) const +// { +// fft_double->fftxyc2r(in,out); +// } + +// template <> +// void FFT_TEMP::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +// { +// fft_float->fft3D_forward(in, out); +// } + +// template <> +// void FFT_TEMP::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +// { +// fft_double->fft3D_forward(in, out); +// } +// template <> +// void FFT_TEMP::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +// { +// fft_float->fft3D_backward(in, out); +// } +// template <> +// void FFT_TEMP::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +// { +// fft_double->fft3D_backward(in, out); +// } \ No newline at end of file diff --git a/source/module_basis/module_pw/fft_temp.cpp b/source/module_basis/module_pw/fft_temp.cpp index 5c4178bb58..c7553774d1 100644 --- a/source/module_basis/module_pw/fft_temp.cpp +++ b/source/module_basis/module_pw/fft_temp.cpp @@ -9,12 +9,12 @@ // #endif // #include "module_base/module_device/device.h" // #include "fft_gpu.h" -FFT1::FFT1() +FFT_TEMP::FFT_TEMP() { fft_float = nullptr; fft_double = nullptr; } -FFT1::FFT1(std::string device_in,std::string precision_in) +FFT_TEMP::FFT_TEMP(std::string device_in,std::string precision_in) { assert(device_in=="cpu" || device_in=="gpu"); assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); @@ -37,30 +37,30 @@ FFT1::FFT1(std::string device_in,std::string precision_in) // } } -FFT1::~FFT1() +FFT_TEMP::~FFT_TEMP() { - if (fft_float!=nullptr) + if (float_flag) { delete fft_float; fft_float=nullptr; } - if (fft_double!=nullptr) + if (double_flag) { delete fft_double; fft_double=nullptr; } } -// void FFT1::set_device(std::string device_in) +// void FFT_TEMP::set_device(std::string device_in) // { // this->device = device_in; // } -// void FFT1::set_precision(std::string precision_in) +// void FFT_TEMP::set_precision(std::string precision_in) // { // this->precision = precision_in; // } -// void FFT1::setfft(std::string device_in,std::string precision_in) +// void FFT_TEMP::setfft(std::string device_in,std::string precision_in) // { // assert(device_in=="cpu" || device_in=="gpu"); // assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); @@ -82,7 +82,7 @@ FFT1::~FFT1() // #endif // } // } -// void FFT1::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, +// void FFT_TEMP::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, // int nproc_in, bool gamma_only_in, bool xprime_in , bool mpifft_in) // { // if (this->precision=="single") @@ -109,12 +109,12 @@ FFT1::~FFT1() // fft_double->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); // } // } -// void FFT1::initfftmode(int fft_mode_in) +// void FFT_TEMP::initfftmode(int fft_mode_in) // { // this->fft_mode = fft_mode_in; // } -// void FFT1::setupFFT() +// void FFT_TEMP::setupFFT() // { // if (double_flag) // { @@ -126,7 +126,7 @@ FFT1::~FFT1() // } // } -// void FFT1::clearFFT() +// void FFT_TEMP::clearFFT() // { // if (double_flag) // { @@ -137,7 +137,7 @@ FFT1::~FFT1() // fft_float->cleanFFT(); // } // } -// void FFT1::clear() +// void FFT_TEMP::clear() // { // this->clearFFT(); // if (float_flag) @@ -151,130 +151,130 @@ FFT1::~FFT1() // } // // access the real space data // template <> -// float* FFT1::get_rspace_data() const +// float* FFT_TEMP::get_rspace_data() const // { // return fft_float->get_rspace_data(); // } // template <> -// double* FFT1::get_rspace_data() const +// double* FFT_TEMP::get_rspace_data() const // { // return fft_double->get_rspace_data(); // } // template <> -// std::complex* FFT1::get_auxr_data() const +// std::complex* FFT_TEMP::get_auxr_data() const // { // return fft_float->get_auxr_data(); // } // template <> -// std::complex* FFT1::get_auxr_data() const +// std::complex* FFT_TEMP::get_auxr_data() const // { // return fft_double->get_auxr_data(); // } // template <> -// std::complex* FFT1::get_auxg_data() const +// std::complex* FFT_TEMP::get_auxg_data() const // { // return fft_float->get_auxg_data(); // } // template <> -// std::complex* FFT1::get_auxg_data() const +// std::complex* FFT_TEMP::get_auxg_data() const // { // return fft_double->get_auxg_data(); // } // template <> -// std::complex* FFT1::get_auxr_3d_data() const +// std::complex* FFT_TEMP::get_auxr_3d_data() const // { // return fft_float->get_auxr_3d_data(); // } // template <> -// std::complex* FFT1::get_auxr_3d_data() const +// std::complex* FFT_TEMP::get_auxr_3d_data() const // { // return fft_double->get_auxr_3d_data(); // } // template <> -// void FFT1::fftxyfor(std::complex* in, std::complex* out) const +// void FFT_TEMP::fftxyfor(std::complex* in, std::complex* out) const // { // fft_float->fftxyfor(in,out); // } // template <> -// void FFT1::fftxyfor(std::complex* in, std::complex* out) const +// void FFT_TEMP::fftxyfor(std::complex* in, std::complex* out) const // { // fft_double->fftxyfor(in,out); // } // template <> -// void FFT1::fftzfor(std::complex* in, std::complex* out) const +// void FFT_TEMP::fftzfor(std::complex* in, std::complex* out) const // { // fft_float->fftzfor(in,out); // } // template <> -// void FFT1::fftzfor(std::complex* in, std::complex* out) const +// void FFT_TEMP::fftzfor(std::complex* in, std::complex* out) const // { // fft_double->fftzfor(in,out); // } // template <> -// void FFT1::fftxybac(std::complex* in, std::complex* out) const +// void FFT_TEMP::fftxybac(std::complex* in, std::complex* out) const // { // fft_float->fftxybac(in,out); // } // template <> -// void FFT1::fftxybac(std::complex* in, std::complex* out) const +// void FFT_TEMP::fftxybac(std::complex* in, std::complex* out) const // { // fft_double->fftxybac(in,out); // } // template <> -// void FFT1::fftzbac(std::complex* in, std::complex* out) const +// void FFT_TEMP::fftzbac(std::complex* in, std::complex* out) const // { // fft_float->fftzbac(in,out); // } // template <> -// void FFT1::fftzbac(std::complex* in, std::complex* out) const +// void FFT_TEMP::fftzbac(std::complex* in, std::complex* out) const // { // fft_double->fftzbac(in,out); // } // template <> -// void FFT1::fftxyr2c(float* in, std::complex* out) const +// void FFT_TEMP::fftxyr2c(float* in, std::complex* out) const // { // fft_float->fftxyr2c(in,out); // } // template <> -// void FFT1::fftxyr2c(double* in, std::complex* out) const +// void FFT_TEMP::fftxyr2c(double* in, std::complex* out) const // { // fft_double->fftxyr2c(in,out); // } // template <> -// void FFT1::fftxyc2r(std::complex* in, float* out) const +// void FFT_TEMP::fftxyc2r(std::complex* in, float* out) const // { // fft_float->fftxyc2r(in,out); // } // template <> -// void FFT1::fftxyc2r(std::complex* in, double* out) const +// void FFT_TEMP::fftxyc2r(std::complex* in, double* out) const // { // fft_double->fftxyc2r(in,out); // } // template <> -// void FFT1::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +// void FFT_TEMP::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const // { // fft_float->fft3D_forward(in, out); // } // template <> -// void FFT1::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +// void FFT_TEMP::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const // { // fft_double->fft3D_forward(in, out); // } // template <> -// void FFT1::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +// void FFT_TEMP::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const // { // fft_float->fft3D_backward(in, out); // } // template <> -// void FFT1::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +// void FFT_TEMP::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const // { // fft_double->fft3D_backward(in, out); // } \ No newline at end of file diff --git a/source/module_basis/module_pw/fft_temp.h b/source/module_basis/module_pw/fft_temp.h index be0ef42c62..10165fc5c9 100644 --- a/source/module_basis/module_pw/fft_temp.h +++ b/source/module_basis/module_pw/fft_temp.h @@ -1,13 +1,13 @@ #include "fft_base.h" // #include "module_psi/psi.h" -#ifndef FFT1_H -#define FFT1_H -class FFT1 +#ifndef FFT_TEMP_H +#define FFT_TEMP_H +class FFT_TEMP { public: - FFT1(); - FFT1(std::string device_in,std::string precision_in); - ~FFT1(); + FFT_TEMP(); + FFT_TEMP(std::string device_in,std::string precision_in); + ~FFT_TEMP(); void setfft(std::string device_in,std::string precision_in); void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, diff --git a/source/module_basis/module_pw/pw_basis.h b/source/module_basis/module_pw/pw_basis.h index cd50d7ef86..e3b72b6a67 100644 --- a/source/module_basis/module_pw/pw_basis.h +++ b/source/module_basis/module_pw/pw_basis.h @@ -243,7 +243,7 @@ class PW_Basis int nmaxgr=0; // Gamma_only: max between npw and (nrxx+1)/2, others: max between npw and nrxx // Thus complex[nmaxgr] is able to contain either reciprocal or real data FFT ft; - FFT1 ft1; + FFT_TEMP ft1; //The position of pointer in and out can be equal(in-place transform) or different(out-of-place transform). template From 8d73d7e2490e6784f9dfe1c598957b7aee639a88 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Tue, 5 Nov 2024 14:46:49 +0800 Subject: [PATCH 04/27] modify the position of the new fft --- source/Makefile.Objects | 6 +- source/module_base/CMakeLists.txt | 2 + .../module_fft}/fft_base.cpp | 0 .../module_fft}/fft_base.h | 0 .../module_fft}/fft_temp.h | 0 source/module_basis/module_pw/CMakeLists.txt | 2 - source/module_basis/module_pw/fft_temp.cpp | 280 ------------------ source/module_basis/module_pw/pw_basis.h | 2 +- .../module_basis/module_pw/pw_transform.cpp | 2 +- .../module_pw/test/CMakeLists.txt | 3 +- 10 files changed, 9 insertions(+), 288 deletions(-) rename source/{module_basis/module_pw => module_base/module_fft}/fft_base.cpp (100%) rename source/{module_basis/module_pw => module_base/module_fft}/fft_base.h (100%) rename source/{module_basis/module_pw => module_base/module_fft}/fft_temp.h (100%) delete mode 100644 source/module_basis/module_pw/fft_temp.cpp diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 4d1f7765c4..64310d83cf 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -25,6 +25,7 @@ VPATH=./src_global:\ ./module_base/module_container/ATen/ops:\ ./module_base/module_device:\ ./module_base/module_mixing:\ +./moudle_base/module_fft:\ ./module_md:\ ./module_basis/module_pw:\ ./module_esolver:\ @@ -166,7 +167,8 @@ OBJS_BASE=abfs-vector3_order.o\ broyden_mixing.o\ memory_op.o\ device.o\ - + fft_base.o\ + fft_temp.o\ OBJS_CELL=atom_pseudo.o\ atom_spec.o\ @@ -411,8 +413,6 @@ OBJS_PSI_INITIALIZER=psi_initializer.o\ psi_initializer_nao_random.o\ OBJS_PW=fft.o\ - fft_base.o\ - fft_temp.o\ pw_basis.o\ pw_basis_k.o\ pw_basis_sup.o\ diff --git a/source/module_base/CMakeLists.txt b/source/module_base/CMakeLists.txt index e11141208c..ad03649e41 100644 --- a/source/module_base/CMakeLists.txt +++ b/source/module_base/CMakeLists.txt @@ -58,6 +58,8 @@ add_library( module_mixing/plain_mixing.cpp module_mixing/pulay_mixing.cpp module_mixing/broyden_mixing.cpp + module_fft/fft_base.cpp + module_fft/fft_temp.cpp ${LIBM_SRC} ) diff --git a/source/module_basis/module_pw/fft_base.cpp b/source/module_base/module_fft/fft_base.cpp similarity index 100% rename from source/module_basis/module_pw/fft_base.cpp rename to source/module_base/module_fft/fft_base.cpp diff --git a/source/module_basis/module_pw/fft_base.h b/source/module_base/module_fft/fft_base.h similarity index 100% rename from source/module_basis/module_pw/fft_base.h rename to source/module_base/module_fft/fft_base.h diff --git a/source/module_basis/module_pw/fft_temp.h b/source/module_base/module_fft/fft_temp.h similarity index 100% rename from source/module_basis/module_pw/fft_temp.h rename to source/module_base/module_fft/fft_temp.h diff --git a/source/module_basis/module_pw/CMakeLists.txt b/source/module_basis/module_pw/CMakeLists.txt index abfde45a34..2b2d897206 100644 --- a/source/module_basis/module_pw/CMakeLists.txt +++ b/source/module_basis/module_pw/CMakeLists.txt @@ -1,7 +1,5 @@ list(APPEND objects fft.cpp - fft_base.cpp - fft_temp.cpp pw_basis.cpp pw_basis_k.cpp pw_basis_sup.cpp diff --git a/source/module_basis/module_pw/fft_temp.cpp b/source/module_basis/module_pw/fft_temp.cpp deleted file mode 100644 index c7553774d1..0000000000 --- a/source/module_basis/module_pw/fft_temp.cpp +++ /dev/null @@ -1,280 +0,0 @@ -#include -#include "fft_temp.h" -// #include "fft_cpu.h" -// #if defined(__CUDA) -// #include "fft_cuda.h" -// #endif -// #if defined(__ROCM) -// #include "fft_rcom.h" -// #endif -// #include "module_base/module_device/device.h" -// #include "fft_gpu.h" -FFT_TEMP::FFT_TEMP() -{ - fft_float = nullptr; - fft_double = nullptr; -} -FFT_TEMP::FFT_TEMP(std::string device_in,std::string precision_in) -{ - assert(device_in=="cpu" || device_in=="gpu"); - assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); - this->device = device_in; - this->precision = precision_in; - // if (device=="cpu") - // { - // fft_float = new FFT_CPU(); - // fft_double = new FFT_CPU(); - // } - // else if (device=="gpu") - // { - // #if defined(__ROCM) - // fft_float = new FFT_RCOM(); - // fft_double = new FFT_RCOM(); - // #elif defined(__CUDA) - // fft_float = new FFT_CUDA(); - // fft_double = new FFT_CUDA(); - // #endif - // } -} - -FFT_TEMP::~FFT_TEMP() -{ - if (float_flag) - { - delete fft_float; - fft_float=nullptr; - } - if (double_flag) - { - delete fft_double; - fft_double=nullptr; - } -} - -// void FFT_TEMP::set_device(std::string device_in) -// { -// this->device = device_in; -// } - -// void FFT_TEMP::set_precision(std::string precision_in) -// { -// this->precision = precision_in; -// } -// void FFT_TEMP::setfft(std::string device_in,std::string precision_in) -// { -// assert(device_in=="cpu" || device_in=="gpu"); -// assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); -// this->device = device_in; -// this->precision = precision_in; -// if (device=="cpu") -// { -// fft_float = new FFT_CPU(); -// fft_double = new FFT_CPU(); -// } -// else if (device=="gpu") -// { -// #if defined(__ROCM) -// fft_float = new FFT_RCOM(); -// fft_double = new FFT_RCOM(); -// #elif defined(__CUDA) -// fft_float = new FFT_CUDA(); -// fft_double = new FFT_CUDA(); -// #endif -// } -// } -// void FFT_TEMP::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, -// int nproc_in, bool gamma_only_in, bool xprime_in , bool mpifft_in) -// { -// if (this->precision=="single") -// { -// float_flag = 1; -// } -// else if (this->precision=="double") -// { -// double_flag = 1; -// } -// else if (this->precision=="mixing") -// { -// float_flag = 1; -// double_flag = 1; -// } -// if (float_flag) -// { -// fft_float->initfftmode(this->fft_mode); -// fft_float->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); -// } -// if (double_flag) -// { -// fft_double->initfftmode(this->fft_mode); -// fft_double->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); -// } -// } -// void FFT_TEMP::initfftmode(int fft_mode_in) -// { -// this->fft_mode = fft_mode_in; -// } - -// void FFT_TEMP::setupFFT() -// { -// if (double_flag) -// { -// fft_double->setupFFT(); -// } -// if (float_flag) -// { -// fft_float->setupFFT(); -// } -// } - -// void FFT_TEMP::clearFFT() -// { -// if (double_flag) -// { -// fft_double->cleanFFT(); -// } -// if (float_flag) -// { -// fft_float->cleanFFT(); -// } -// } -// void FFT_TEMP::clear() -// { -// this->clearFFT(); -// if (float_flag) -// { -// fft_float->clear(); -// } -// if (double_flag) -// { -// fft_double->clear(); -// } -// } -// // access the real space data -// template <> -// float* FFT_TEMP::get_rspace_data() const -// { -// return fft_float->get_rspace_data(); -// } - -// template <> -// double* FFT_TEMP::get_rspace_data() const -// { -// return fft_double->get_rspace_data(); -// } -// template <> -// std::complex* FFT_TEMP::get_auxr_data() const -// { -// return fft_float->get_auxr_data(); -// } -// template <> -// std::complex* FFT_TEMP::get_auxr_data() const -// { -// return fft_double->get_auxr_data(); -// } -// template <> -// std::complex* FFT_TEMP::get_auxg_data() const -// { -// return fft_float->get_auxg_data(); -// } -// template <> -// std::complex* FFT_TEMP::get_auxg_data() const -// { -// return fft_double->get_auxg_data(); -// } -// template <> -// std::complex* FFT_TEMP::get_auxr_3d_data() const -// { -// return fft_float->get_auxr_3d_data(); -// } -// template <> -// std::complex* FFT_TEMP::get_auxr_3d_data() const -// { -// return fft_double->get_auxr_3d_data(); -// } -// template <> -// void FFT_TEMP::fftxyfor(std::complex* in, std::complex* out) const -// { -// fft_float->fftxyfor(in,out); -// } - -// template <> -// void FFT_TEMP::fftxyfor(std::complex* in, std::complex* out) const -// { -// fft_double->fftxyfor(in,out); -// } - -// template <> -// void FFT_TEMP::fftzfor(std::complex* in, std::complex* out) const -// { -// fft_float->fftzfor(in,out); -// } -// template <> -// void FFT_TEMP::fftzfor(std::complex* in, std::complex* out) const -// { -// fft_double->fftzfor(in,out); -// } - -// template <> -// void FFT_TEMP::fftxybac(std::complex* in, std::complex* out) const -// { -// fft_float->fftxybac(in,out); -// } -// template <> -// void FFT_TEMP::fftxybac(std::complex* in, std::complex* out) const -// { -// fft_double->fftxybac(in,out); -// } - -// template <> -// void FFT_TEMP::fftzbac(std::complex* in, std::complex* out) const -// { -// fft_float->fftzbac(in,out); -// } -// template <> -// void FFT_TEMP::fftzbac(std::complex* in, std::complex* out) const -// { -// fft_double->fftzbac(in,out); -// } -// template <> -// void FFT_TEMP::fftxyr2c(float* in, std::complex* out) const -// { -// fft_float->fftxyr2c(in,out); -// } -// template <> -// void FFT_TEMP::fftxyr2c(double* in, std::complex* out) const -// { -// fft_double->fftxyr2c(in,out); -// } - -// template <> -// void FFT_TEMP::fftxyc2r(std::complex* in, float* out) const -// { -// fft_float->fftxyc2r(in,out); -// } -// template <> -// void FFT_TEMP::fftxyc2r(std::complex* in, double* out) const -// { -// fft_double->fftxyc2r(in,out); -// } - -// template <> -// void FFT_TEMP::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const -// { -// fft_float->fft3D_forward(in, out); -// } - -// template <> -// void FFT_TEMP::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const -// { -// fft_double->fft3D_forward(in, out); -// } -// template <> -// void FFT_TEMP::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const -// { -// fft_float->fft3D_backward(in, out); -// } -// template <> -// void FFT_TEMP::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const -// { -// fft_double->fft3D_backward(in, out); -// } \ No newline at end of file diff --git a/source/module_basis/module_pw/pw_basis.h b/source/module_basis/module_pw/pw_basis.h index e3b72b6a67..f9966dc78a 100644 --- a/source/module_basis/module_pw/pw_basis.h +++ b/source/module_basis/module_pw/pw_basis.h @@ -6,7 +6,7 @@ #include "module_base/vector3.h" #include #include "fft.h" -#include "fft_temp.h" +#include "module_base/module_fft/fft_temp.h" #include #ifdef __MPI #include "mpi.h" diff --git a/source/module_basis/module_pw/pw_transform.cpp b/source/module_basis/module_pw/pw_transform.cpp index 6885f38731..43d86c4381 100644 --- a/source/module_basis/module_pw/pw_transform.cpp +++ b/source/module_basis/module_pw/pw_transform.cpp @@ -1,5 +1,5 @@ #include "fft.h" -#include "fft_temp.h" +#include "module_base/module_fft/fft_temp.h" #include #include "pw_basis.h" #include diff --git a/source/module_basis/module_pw/test/CMakeLists.txt b/source/module_basis/module_pw/test/CMakeLists.txt index e1ce122d07..a5f40f6127 100644 --- a/source/module_basis/module_pw/test/CMakeLists.txt +++ b/source/module_basis/module_pw/test/CMakeLists.txt @@ -4,7 +4,8 @@ AddTest( LIBS parameter ${math_libs} planewave device SOURCES ../../../module_base/matrix.cpp ../../../module_base/complexmatrix.cpp ../../../module_base/matrix3.cpp ../../../module_base/tool_quit.cpp ../../../module_base/mymath.cpp ../../../module_base/timer.cpp ../../../module_base/memory.cpp ../../../module_base/blas_connector.cpp - ../../../module_base/libm/branred.cpp ../../../module_base/libm/sincos.cpp + ../../../module_base/libm/branred.cpp ../../../module_base/libm/sincos.cpp + ../../../module_base/module_fft/fft_base.cpp ../../../module_base/module_fft/fft_temp.cpp # ../../../module_psi/kernels/psi_memory_op.cpp ../../../module_base/module_device/memory_op.cpp depend_mock.cpp pw_test.cpp test1-1-1.cpp test1-1-2.cpp test1-2.cpp test1-3.cpp test1-4.cpp test1-5.cpp From 5304827476cfedced44e8eb31e5ace0f1cffe4be Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Tue, 5 Nov 2024 15:06:16 +0800 Subject: [PATCH 05/27] modify the Makefile --- source/Makefile.Objects | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 64310d83cf..1d9724c395 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -25,7 +25,7 @@ VPATH=./src_global:\ ./module_base/module_container/ATen/ops:\ ./module_base/module_device:\ ./module_base/module_mixing:\ -./moudle_base/module_fft:\ +./module_base/module_fft:\ ./module_md:\ ./module_basis/module_pw:\ ./module_esolver:\ @@ -167,8 +167,8 @@ OBJS_BASE=abfs-vector3_order.o\ broyden_mixing.o\ memory_op.o\ device.o\ - fft_base.o\ fft_temp.o\ + fft_base.o\ OBJS_CELL=atom_pseudo.o\ atom_spec.o\ From 4049c762822825e76bbdb77820e1742ac914771e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Tue, 5 Nov 2024 07:20:03 +0000 Subject: [PATCH 06/27] [pre-commit.ci lite] apply automatic fixes --- source/module_base/module_fft/fft_base.cpp | 5 +++-- source/module_base/module_fft/fft_temp.h | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/source/module_base/module_fft/fft_base.cpp b/source/module_base/module_fft/fft_base.cpp index 31fb9881e1..f17a6b0999 100644 --- a/source/module_base/module_fft/fft_base.cpp +++ b/source/module_base/module_fft/fft_base.cpp @@ -18,10 +18,11 @@ void FFT_BASE::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int this->fftny = this->ny = ny_in; if (this->gamma_only) { - if (xprime) + if (xprime) { this->fftnx = int(nx / 2) + 1; - else + } else { this->fftny = int(ny / 2) + 1; +} } this->nz = nz_in; this->ns = ns_in; diff --git a/source/module_base/module_fft/fft_temp.h b/source/module_base/module_fft/fft_temp.h index 10165fc5c9..9d08319efb 100644 --- a/source/module_base/module_fft/fft_temp.h +++ b/source/module_base/module_fft/fft_temp.h @@ -54,8 +54,8 @@ class FFT_TEMP private: int fft_mode = 0; ///< fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive - bool float_flag=0; - bool double_flag=0; + bool float_flag=false; + bool double_flag=false; FFT_BASE* fft_float=nullptr; FFT_BASE* fft_double=nullptr; From 5cfd6bc8bbe780ce3b8e5cd4453b4f2685e07d2e Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Tue, 5 Nov 2024 15:24:38 +0800 Subject: [PATCH 07/27] add the cpu float in the fft floder --- source/Makefile.Objects | 1 + source/module_base/CMakeLists.txt | 8 +- source/module_base/module_fft/fft_cpu.cpp | 307 ++++++++++++ source/module_base/module_fft/fft_cpu.h | 83 ++++ .../module_base/module_fft/fft_cpu_float.cpp | 306 ++++++++++++ source/module_base/module_fft/fft_temp.cpp | 447 +++++++++--------- 6 files changed, 928 insertions(+), 224 deletions(-) create mode 100644 source/module_base/module_fft/fft_cpu.cpp create mode 100644 source/module_base/module_fft/fft_cpu.h create mode 100644 source/module_base/module_fft/fft_cpu_float.cpp diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 1d9724c395..bbfd623cd9 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -169,6 +169,7 @@ OBJS_BASE=abfs-vector3_order.o\ device.o\ fft_temp.o\ fft_base.o\ + fft_cpu.o\ OBJS_CELL=atom_pseudo.o\ atom_spec.o\ diff --git a/source/module_base/CMakeLists.txt b/source/module_base/CMakeLists.txt index ad03649e41..e257600ad2 100644 --- a/source/module_base/CMakeLists.txt +++ b/source/module_base/CMakeLists.txt @@ -6,7 +6,11 @@ list (APPEND LIBM_SRC libm/sincos.cpp ) endif() - +if (ENABLE_FLOAT_FFTW) + list (APPEND FFT_SRC + module_fft/fftw_float.cpp + ) +endif() add_library( base OBJECT @@ -60,7 +64,9 @@ add_library( module_mixing/broyden_mixing.cpp module_fft/fft_base.cpp module_fft/fft_temp.cpp + module_fft/fft_cpu.cpp ${LIBM_SRC} + ${FFT_SRC} ) add_subdirectory(module_container) diff --git a/source/module_base/module_fft/fft_cpu.cpp b/source/module_base/module_fft/fft_cpu.cpp new file mode 100644 index 0000000000..8cbb78639d --- /dev/null +++ b/source/module_base/module_fft/fft_cpu.cpp @@ -0,0 +1,307 @@ +#include "fft_cpu.h" +#include "fftw3.h" +#if defined(__FFTW3_MPI) && defined(__MPI) +#include +//#include "fftw3-mpi_mkl.h" +#endif + +template <> +FFT_CPU::FFT_CPU() +{ + +} +template <> +FFT_CPU::~FFT_CPU() +{ + +} + +template <> +void FFT_CPU::setupFFT() +{ + + unsigned int flag = FFTW_ESTIMATE; + switch (this->fft_mode) + { + case 0: + flag = FFTW_ESTIMATE; + break; + case 1: + flag = FFTW_MEASURE; + break; + case 2: + flag = FFTW_PATIENT; + break; + case 3: + flag = FFTW_EXHAUSTIVE; + break; + default: + break; + } + if (!this->mpifft) + { + z_auxg = (std::complex*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids); + z_auxr = (std::complex*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids); + d_rspace = (double*)z_auxg; + this->planzfor = fftw_plan_many_dft(1, &this->nz, this->ns, (fftw_complex*)z_auxg, &this->nz, 1, this->nz, + (fftw_complex*)z_auxg, &this->nz, 1, this->nz, FFTW_FORWARD, flag); + + this->planzbac = fftw_plan_many_dft(1, &this->nz, this->ns, (fftw_complex*)z_auxg, &this->nz, 1, this->nz, + (fftw_complex*)z_auxg, &this->nz, 1, this->nz, FFTW_BACKWARD, flag); + + //--------------------------------------------------------- + // 2 D - XY + //--------------------------------------------------------- + // 1D+1D is much faster than 2D FFT! + // in-place fft is better for c2c and out-of-place fft is better for c2r + int* embed = nullptr; + int npy = this->nplane * this->ny; + if (this->xprime) + { + this->planyfor = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed,this->nplane, 1, + (fftw_complex*)z_auxr, embed,this->nplane, 1, FFTW_FORWARD, flag); + this->planybac = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed,this->nplane, 1, + (fftw_complex*)z_auxr, embed,this->nplane, 1, FFTW_BACKWARD, flag); + if (this->gamma_only) + { + this->planxr2c = fftw_plan_many_dft_r2c(1, &this->nx, npy, d_rspace, embed, npy, 1, (fftw_complex*)z_auxr, + embed, npy, 1, flag); + this->planxc2r = fftw_plan_many_dft_c2r(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, d_rspace, + embed, npy, 1, flag); + } + else + { + this->planxfor1 = fftw_plan_many_dft(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, + (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); + this->planxbac1 = fftw_plan_many_dft(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, + (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); + } + } + else + { + this->planxfor1 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->lixy + 1), (fftw_complex*)z_auxr, embed, npy, + 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); + this->planxbac1 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->lixy + 1), (fftw_complex*)z_auxr, embed, npy, + 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); + if (this->gamma_only) + { + this->planyr2c = fftw_plan_many_dft_r2c(1, &this->ny, this->nplane, d_rspace, embed, this->nplane, 1, + (fftw_complex*)z_auxr, embed, this->nplane, 1, flag); + this->planyc2r = fftw_plan_many_dft_c2r(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, + this->nplane, 1, d_rspace, embed, this->nplane, 1, flag); + } + else + { + + this->planxfor2 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->ny - this->rixy), (fftw_complex*)z_auxr, embed, + npy, 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); + this->planxbac2 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->ny - this->rixy), (fftw_complex*)z_auxr, embed, + npy, 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); + this->planyfor = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, this->nplane, + 1, (fftw_complex*)z_auxr, embed, this->nplane, 1, FFTW_FORWARD, flag); + this->planybac = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, this->nplane, + 1, (fftw_complex*)z_auxr, embed, this->nplane, 1, FFTW_BACKWARD, flag); + } + } + } +#if defined(__FFTW3_MPI) && defined(__MPI) + else + { + // this->initplan_mpi(); + // if (this->precision == "single") { + // this->initplanf_mpi(); + // } + } +#endif + return; +} +template <> +void FFT_CPU::initfftmode(int fft_mode_in) +{ + this->fft_mode = fft_mode_in; +} + +template <> +void FFT_CPU::clearfft(fftw_plan& plan) +{ + if (plan) + { + fftw_destroy_plan(plan); + plan = NULL; + } +} + +template <> +void FFT_CPU::cleanFFT() +{ + printf("in the double cleanFFT\n"); + clearfft(planzfor); + clearfft(planzbac); + clearfft(planxfor1); + clearfft(planxbac1); + clearfft(planxfor2); + clearfft(planxbac2); + clearfft(planyfor); + clearfft(planybac); + clearfft(planxr2c); + clearfft(planxc2r); + clearfft(planyr2c); + clearfft(planyc2r); +} + + +template <> +void FFT_CPU::clear() +{ + this->cleanFFT(); + if (z_auxg != nullptr) + { + fftw_free(z_auxg); + z_auxg = nullptr; + } + if (z_auxr != nullptr) + { + fftw_free(z_auxr); + z_auxr = nullptr; + } + d_rspace = nullptr; +} + +template <> +double* FFT_CPU::get_rspace_data() const +{ + return d_rspace; +} +template <> +std::complex* FFT_CPU::get_auxr_data() const +{ + return z_auxr; +} +template <> +std::complex* FFT_CPU::get_auxg_data() const +{ + return z_auxg; +} +template <> +void FFT_CPU::fftxyfor(std::complex* in, std::complex* out) const +{ + int npy = this->nplane * this->ny; + if (this->xprime) + { + fftw_execute_dft(this->planxfor1, (fftw_complex*)in, (fftw_complex*)out); + for (int i = 0; i < this->lixy + 1; ++i) + { + fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + } + for (int i = rixy; i < this->nx; ++i) + { + fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + } + } + else + { + for (int i = 0; i < this->nx; ++i) + { + fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + } + + fftw_execute_dft(this->planxfor1, (fftw_complex*)in, (fftw_complex*)out); + fftw_execute_dft(this->planxfor2, (fftw_complex*)&in[rixy * nplane], (fftw_complex*)&out[rixy * nplane]); + } +} +template <> +void FFT_CPU::fftxybac(std::complex* in,std::complex* out) const +{ + int npy = this->nplane * this->ny; + if (this->xprime) + { + for (int i = 0; i < this->lixy + 1; ++i) + { + fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + } + for (int i = rixy; i < this->nx; ++i) + { + fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + } + fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)out); + } + else + { + for (int i = 0; i < this->nx; ++i) + { + fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + } + + fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)out); + fftw_execute_dft(this->planxbac2, (fftw_complex*)&in[rixy * nplane], (fftw_complex*)&out[rixy * nplane]); + } +} +template <> +void FFT_CPU::fftzfor(std::complex* in, std::complex* out) const +{ + fftw_execute_dft(this->planzfor, (fftw_complex*)in, (fftw_complex*)out); +} +template <> +void FFT_CPU::fftzbac(std::complex* in, std::complex* out) const +{ + fftw_execute_dft(this->planzbac, (fftw_complex*)in, (fftw_complex*)out); +} +template <> +void FFT_CPU::fftxyr2c(double* in, std::complex* out) const +{ + int npy = this->nplane * this->ny; + if (this->xprime) + { + fftw_execute_dft_r2c(this->planxr2c, in, (fftw_complex*)out); + + for (int i = 0; i < this->lixy + 1; ++i) + { + fftw_execute_dft(this->planyfor, (fftw_complex*)&out[i * npy], (fftw_complex*)&out[i * npy]); + } + } + else + { + for (int i = 0; i < this->nx; ++i) + { + fftw_execute_dft_r2c(this->planyr2c, &in[i * npy], (fftw_complex*)&out[i * npy]); + } + + fftw_execute_dft(this->planxfor1, (fftw_complex*)out, (fftw_complex*)out); + } +} + +template <> +void FFT_CPU::fftxyc2r(std::complex *in,double *out) const +{ + int npy = this->nplane * this->ny; + if (this->xprime) + { + for (int i = 0; i < this->lixy + 1; ++i) + { + fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&in[i * npy]); + } + + fftw_execute_dft_c2r(this->planxc2r, (fftw_complex*)in, out); + } + else + { + fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)in); + + for (int i = 0; i < this->nx; ++i) + { + fftw_execute_dft_c2r(this->planyc2r, (fftw_complex*)&in[i * npy], &out[i * npy]); + } + } +} +template <> +FFT_CPU::FFT_CPU() +{ + +} +template <> +FFT_CPU::~FFT_CPU() +{ + +} +template FFT_CPU::FFT_CPU(); +template FFT_CPU::FFT_CPU(); diff --git a/source/module_base/module_fft/fft_cpu.h b/source/module_base/module_fft/fft_cpu.h new file mode 100644 index 0000000000..981dd1d067 --- /dev/null +++ b/source/module_base/module_fft/fft_cpu.h @@ -0,0 +1,83 @@ +#include "fft_base.h" +#include "fftw3.h" + +// #ifdef __ENABLE_FLOAT_FFTW + +// #endif +// #endif +#ifndef FFT_CPU_H +#define FFT_CPU_H + +template +class FFT_CPU : public FFT_BASE +{ + public: + __attribute__((weak)) FFT_CPU(); + __attribute__((weak)) ~FFT_CPU(); + + __attribute__((weak)) void initfftmode(int fft_mode_in); + + //init fftw_plans + __attribute__((weak)) void setupFFT() override; + + // void initplan(const unsigned int& flag = 0); + __attribute__((weak)) void cleanFFT() override; + + __attribute__((weak)) void clear() override; + + __attribute__((weak)) FPTYPE* get_rspace_data() const override; + + __attribute__((weak)) std::complex* get_auxr_data() const; + + __attribute__((weak)) std::complex* get_auxg_data() const; + + __attribute__((weak)) void fftxyfor(std::complex* in, std::complex* out) const override; + + __attribute__((weak)) void fftxybac(std::complex* in, std::complex* out) const override; + + __attribute__((weak)) void fftzfor(std::complex* in, std::complex* out) const override; + + __attribute__((weak)) void fftzbac(std::complex* in, std::complex* out) const override; + + __attribute__((weak)) void fftxyr2c(FPTYPE* in, std::complex* out) const override; + + __attribute__((weak)) void fftxyc2r(std::complex* in, FPTYPE* out) const override; + private: + void clearfft(fftw_plan& plan); + void clearfft(fftwf_plan& plan); + + fftw_plan planzfor = NULL;//create a special pointer pointing to the fftw_plan class as a plan for performing FFT + fftw_plan planzbac = NULL; + fftw_plan planxfor1 = NULL; + fftw_plan planxbac1 = NULL; + fftw_plan planxfor2 = NULL; + fftw_plan planxbac2 = NULL; + fftw_plan planyfor = NULL; + fftw_plan planybac = NULL; + fftw_plan planxr2c = NULL; + fftw_plan planxc2r = NULL; + fftw_plan planyr2c = NULL; + fftw_plan planyc2r = NULL; + + fftwf_plan planfzfor = NULL; + fftwf_plan planfzbac = NULL; + fftwf_plan planfxfor1= NULL; + fftwf_plan planfxbac1= NULL; + fftwf_plan planfxfor2= NULL; + fftwf_plan planfxbac2= NULL; + fftwf_plan planfyfor = NULL; + fftwf_plan planfybac = NULL; + fftwf_plan planfxr2c = NULL; + fftwf_plan planfxc2r = NULL; + fftwf_plan planfyr2c = NULL; + fftwf_plan planfyc2r = NULL; + + std::complex*c_auxg = nullptr; + std::complex*c_auxr = nullptr; // fft space + std::complex*z_auxg = nullptr; + std::complex*z_auxr = nullptr; // fft space + + float* s_rspace = nullptr; // real number space for r, [nplane * nx *ny] + double* d_rspace = nullptr; // real number space for r, [nplane * nx *ny] +}; +#endif // FFT_CPU_H \ No newline at end of file diff --git a/source/module_base/module_fft/fft_cpu_float.cpp b/source/module_base/module_fft/fft_cpu_float.cpp new file mode 100644 index 0000000000..3340603508 --- /dev/null +++ b/source/module_base/module_fft/fft_cpu_float.cpp @@ -0,0 +1,306 @@ +#include "fft_cpu.h" +#include "fftw3.h" +#if defined(__FFTW3_MPI) && defined(__MPI) +#include +//#include "fftw3-mpi_mkl.h" +#endif + +template <> +FFT_CPU::FFT_CPU() +{ + +} +template <> +FFT_CPU::~FFT_CPU() +{ + +} + +template <> +void FFT_CPU::setupFFT() +{ + + unsigned int flag = FFTW_ESTIMATE; + switch (this->fft_mode) + { + case 0: + flag = FFTW_ESTIMATE; + break; + case 1: + flag = FFTW_MEASURE; + break; + case 2: + flag = FFTW_PATIENT; + break; + case 3: + flag = FFTW_EXHAUSTIVE; + break; + default: + break; + } + if (!this->mpifft) + { + z_auxg = (std::complex*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids); + z_auxr = (std::complex*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids); + d_rspace = (float*)z_auxg; + this->planzfor = fftw_plan_many_dft(1, &this->nz, this->ns, (fftw_complex*)z_auxg, &this->nz, 1, this->nz, + (fftw_complex*)z_auxg, &this->nz, 1, this->nz, FFTW_FORWARD, flag); + + this->planzbac = fftw_plan_many_dft(1, &this->nz, this->ns, (fftw_complex*)z_auxg, &this->nz, 1, this->nz, + (fftw_complex*)z_auxg, &this->nz, 1, this->nz, FFTW_BACKWARD, flag); + + //--------------------------------------------------------- + // 2 D - XY + //--------------------------------------------------------- + // 1D+1D is much faster than 2D FFT! + // in-place fft is better for c2c and out-of-place fft is better for c2r + int* embed = nullptr; + int npy = this->nplane * this->ny; + if (this->xprime) + { + this->planyfor = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed,this->nplane, 1, + (fftw_complex*)z_auxr, embed,this->nplane, 1, FFTW_FORWARD, flag); + this->planybac = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed,this->nplane, 1, + (fftw_complex*)z_auxr, embed,this->nplane, 1, FFTW_BACKWARD, flag); + if (this->gamma_only) + { + this->planxr2c = fftw_plan_many_dft_r2c(1, &this->nx, npy, d_rspace, embed, npy, 1, (fftw_complex*)z_auxr, + embed, npy, 1, flag); + this->planxc2r = fftw_plan_many_dft_c2r(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, d_rspace, + embed, npy, 1, flag); + } + else + { + this->planxfor1 = fftw_plan_many_dft(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, + (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); + this->planxbac1 = fftw_plan_many_dft(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, + (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); + } + } + else + { + this->planxfor1 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->lixy + 1), (fftw_complex*)z_auxr, embed, npy, + 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); + this->planxbac1 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->lixy + 1), (fftw_complex*)z_auxr, embed, npy, + 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); + if (this->gamma_only) + { + this->planyr2c = fftw_plan_many_dft_r2c(1, &this->ny, this->nplane, d_rspace, embed, this->nplane, 1, + (fftw_complex*)z_auxr, embed, this->nplane, 1, flag); + this->planyc2r = fftw_plan_many_dft_c2r(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, + this->nplane, 1, d_rspace, embed, this->nplane, 1, flag); + } + else + { + + this->planxfor2 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->ny - this->rixy), (fftw_complex*)z_auxr, embed, + npy, 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); + this->planxbac2 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->ny - this->rixy), (fftw_complex*)z_auxr, embed, + npy, 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); + this->planyfor = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, this->nplane, + 1, (fftw_complex*)z_auxr, embed, this->nplane, 1, FFTW_FORWARD, flag); + this->planybac = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, this->nplane, + 1, (fftw_complex*)z_auxr, embed, this->nplane, 1, FFTW_BACKWARD, flag); + } + } + } +#if defined(__FFTW3_MPI) && defined(__MPI) + else + { + // this->initplan_mpi(); + // if (this->precision == "single") { + // this->initplanf_mpi(); + // } + } +#endif + return; +} +template <> +void FFT_CPU::initfftmode(int fft_mode_in) +{ + this->fft_mode = fft_mode_in; +} + +template <> +void FFT_CPU::clearfft(fftw_plan& plan) +{ + if (plan) + { + fftw_destroy_plan(plan); + plan = NULL; + } +} + +template <> +void FFT_CPU::cleanFFT() +{ + clearfft(planzfor); + clearfft(planzbac); + clearfft(planxfor1); + clearfft(planxbac1); + clearfft(planxfor2); + clearfft(planxbac2); + clearfft(planyfor); + clearfft(planybac); + clearfft(planxr2c); + clearfft(planxc2r); + clearfft(planyr2c); + clearfft(planyc2r); +} + + +template <> +void FFT_CPU::clear() +{ + this->cleanFFT(); + if (z_auxg != nullptr) + { + fftw_free(z_auxg); + z_auxg = nullptr; + } + if (z_auxr != nullptr) + { + fftw_free(z_auxr); + z_auxr = nullptr; + } + d_rspace = nullptr; +} + +template <> +float* FFT_CPU::get_rspace_data() const +{ + return d_rspace; +} +template <> +std::complex* FFT_CPU::get_auxr_data() const +{ + return z_auxr; +} +template <> +std::complex* FFT_CPU::get_auxg_data() const +{ + return z_auxg; +} +template <> +void FFT_CPU::fftxyfor(std::complex* in, std::complex* out) const +{ + int npy = this->nplane * this->ny; + if (this->xprime) + { + fftw_execute_dft(this->planxfor1, (fftw_complex*)in, (fftw_complex*)out); + for (int i = 0; i < this->lixy + 1; ++i) + { + fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + } + for (int i = rixy; i < this->nx; ++i) + { + fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + } + } + else + { + for (int i = 0; i < this->nx; ++i) + { + fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + } + + fftw_execute_dft(this->planxfor1, (fftw_complex*)in, (fftw_complex*)out); + fftw_execute_dft(this->planxfor2, (fftw_complex*)&in[rixy * nplane], (fftw_complex*)&out[rixy * nplane]); + } +} +template <> +void FFT_CPU::fftxybac(std::complex* in,std::complex* out) const +{ + int npy = this->nplane * this->ny; + if (this->xprime) + { + for (int i = 0; i < this->lixy + 1; ++i) + { + fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + } + for (int i = rixy; i < this->nx; ++i) + { + fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + } + fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)out); + } + else + { + for (int i = 0; i < this->nx; ++i) + { + fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + } + + fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)out); + fftw_execute_dft(this->planxbac2, (fftw_complex*)&in[rixy * nplane], (fftw_complex*)&out[rixy * nplane]); + } +} +template <> +void FFT_CPU::fftzfor(std::complex* in, std::complex* out) const +{ + fftw_execute_dft(this->planzfor, (fftw_complex*)in, (fftw_complex*)out); +} +template <> +void FFT_CPU::fftzbac(std::complex* in, std::complex* out) const +{ + fftw_execute_dft(this->planzbac, (fftw_complex*)in, (fftw_complex*)out); +} +template <> +void FFT_CPU::fftxyr2c(float* in, std::complex* out) const +{ + int npy = this->nplane * this->ny; + if (this->xprime) + { + fftw_execute_dft_r2c(this->planxr2c, in, (fftw_complex*)out); + + for (int i = 0; i < this->lixy + 1; ++i) + { + fftw_execute_dft(this->planyfor, (fftw_complex*)&out[i * npy], (fftw_complex*)&out[i * npy]); + } + } + else + { + for (int i = 0; i < this->nx; ++i) + { + fftw_execute_dft_r2c(this->planyr2c, &in[i * npy], (fftw_complex*)&out[i * npy]); + } + + fftw_execute_dft(this->planxfor1, (fftw_complex*)out, (fftw_complex*)out); + } +} + +template <> +void FFT_CPU::fftxyc2r(std::complex *in,float *out) const +{ + int npy = this->nplane * this->ny; + if (this->xprime) + { + for (int i = 0; i < this->lixy + 1; ++i) + { + fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&in[i * npy]); + } + + fftw_execute_dft_c2r(this->planxc2r, (fftw_complex*)in, out); + } + else + { + fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)in); + + for (int i = 0; i < this->nx; ++i) + { + fftw_execute_dft_c2r(this->planyc2r, (fftw_complex*)&in[i * npy], &out[i * npy]); + } + } +} +template <> +FFT_CPU::FFT_CPU() +{ + +} +template <> +FFT_CPU::~FFT_CPU() +{ + +} +template FFT_CPU::FFT_CPU(); +template FFT_CPU::FFT_CPU(); diff --git a/source/module_base/module_fft/fft_temp.cpp b/source/module_base/module_fft/fft_temp.cpp index c7553774d1..2682d60ac4 100644 --- a/source/module_base/module_fft/fft_temp.cpp +++ b/source/module_base/module_fft/fft_temp.cpp @@ -1,13 +1,14 @@ #include #include "fft_temp.h" -// #include "fft_cpu.h" +#include "fft_cpu.h" +#include "module_base/module_device/device.h" // #if defined(__CUDA) // #include "fft_cuda.h" // #endif // #if defined(__ROCM) // #include "fft_rcom.h" // #endif -// #include "module_base/module_device/device.h" + // #include "fft_gpu.h" FFT_TEMP::FFT_TEMP() { @@ -20,11 +21,11 @@ FFT_TEMP::FFT_TEMP(std::string device_in,std::string precision_in) assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); this->device = device_in; this->precision = precision_in; - // if (device=="cpu") - // { - // fft_float = new FFT_CPU(); - // fft_double = new FFT_CPU(); - // } + if (device=="cpu") + { + fft_float = new FFT_CPU(); + fft_double = new FFT_CPU(); + } // else if (device=="gpu") // { // #if defined(__ROCM) @@ -51,230 +52,230 @@ FFT_TEMP::~FFT_TEMP() } } -// void FFT_TEMP::set_device(std::string device_in) -// { -// this->device = device_in; -// } +void FFT_TEMP::set_device(std::string device_in) +{ + this->device = device_in; +} -// void FFT_TEMP::set_precision(std::string precision_in) -// { -// this->precision = precision_in; -// } -// void FFT_TEMP::setfft(std::string device_in,std::string precision_in) -// { -// assert(device_in=="cpu" || device_in=="gpu"); -// assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); -// this->device = device_in; -// this->precision = precision_in; -// if (device=="cpu") -// { -// fft_float = new FFT_CPU(); -// fft_double = new FFT_CPU(); -// } -// else if (device=="gpu") -// { -// #if defined(__ROCM) -// fft_float = new FFT_RCOM(); -// fft_double = new FFT_RCOM(); -// #elif defined(__CUDA) -// fft_float = new FFT_CUDA(); -// fft_double = new FFT_CUDA(); -// #endif -// } -// } -// void FFT_TEMP::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, -// int nproc_in, bool gamma_only_in, bool xprime_in , bool mpifft_in) -// { -// if (this->precision=="single") -// { -// float_flag = 1; -// } -// else if (this->precision=="double") -// { -// double_flag = 1; -// } -// else if (this->precision=="mixing") -// { -// float_flag = 1; -// double_flag = 1; -// } -// if (float_flag) -// { -// fft_float->initfftmode(this->fft_mode); -// fft_float->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); -// } -// if (double_flag) -// { -// fft_double->initfftmode(this->fft_mode); -// fft_double->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); -// } -// } -// void FFT_TEMP::initfftmode(int fft_mode_in) -// { -// this->fft_mode = fft_mode_in; -// } +void FFT_TEMP::set_precision(std::string precision_in) +{ + this->precision = precision_in; +} +void FFT_TEMP::setfft(std::string device_in,std::string precision_in) +{ + assert(device_in=="cpu" || device_in=="gpu"); + assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); + this->device = device_in; + this->precision = precision_in; + if (device=="cpu") + { + fft_float = new FFT_CPU(); + fft_double = new FFT_CPU(); + } + else if (device=="gpu") + { + #if defined(__ROCM) + fft_float = new FFT_RCOM(); + fft_double = new FFT_RCOM(); + #elif defined(__CUDA) + fft_float = new FFT_CUDA(); + fft_double = new FFT_CUDA(); + #endif + } +} +void FFT_TEMP::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, + int nproc_in, bool gamma_only_in, bool xprime_in , bool mpifft_in) +{ + if (this->precision=="single") + { + float_flag = 1; + } + else if (this->precision=="double") + { + double_flag = 1; + } + else if (this->precision=="mixing") + { + float_flag = 1; + double_flag = 1; + } + if (float_flag) + { + fft_float->initfftmode(this->fft_mode); + fft_float->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); + } + if (double_flag) + { + fft_double->initfftmode(this->fft_mode); + fft_double->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); + } +} +void FFT_TEMP::initfftmode(int fft_mode_in) +{ + this->fft_mode = fft_mode_in; +} -// void FFT_TEMP::setupFFT() -// { -// if (double_flag) -// { -// fft_double->setupFFT(); -// } -// if (float_flag) -// { -// fft_float->setupFFT(); -// } -// } +void FFT_TEMP::setupFFT() +{ + if (double_flag) + { + fft_double->setupFFT(); + } + if (float_flag) + { + fft_float->setupFFT(); + } +} -// void FFT_TEMP::clearFFT() -// { -// if (double_flag) -// { -// fft_double->cleanFFT(); -// } -// if (float_flag) -// { -// fft_float->cleanFFT(); -// } -// } -// void FFT_TEMP::clear() -// { -// this->clearFFT(); -// if (float_flag) -// { -// fft_float->clear(); -// } -// if (double_flag) -// { -// fft_double->clear(); -// } -// } -// // access the real space data -// template <> -// float* FFT_TEMP::get_rspace_data() const -// { -// return fft_float->get_rspace_data(); -// } +void FFT_TEMP::clearFFT() +{ + if (double_flag) + { + fft_double->cleanFFT(); + } + if (float_flag) + { + fft_float->cleanFFT(); + } +} +void FFT_TEMP::clear() +{ + this->clearFFT(); + if (float_flag) + { + fft_float->clear(); + } + if (double_flag) + { + fft_double->clear(); + } +} +// access the real space data +template <> +float* FFT_TEMP::get_rspace_data() const +{ + return fft_float->get_rspace_data(); +} -// template <> -// double* FFT_TEMP::get_rspace_data() const -// { -// return fft_double->get_rspace_data(); -// } -// template <> -// std::complex* FFT_TEMP::get_auxr_data() const -// { -// return fft_float->get_auxr_data(); -// } -// template <> -// std::complex* FFT_TEMP::get_auxr_data() const -// { -// return fft_double->get_auxr_data(); -// } -// template <> -// std::complex* FFT_TEMP::get_auxg_data() const -// { -// return fft_float->get_auxg_data(); -// } -// template <> -// std::complex* FFT_TEMP::get_auxg_data() const -// { -// return fft_double->get_auxg_data(); -// } -// template <> -// std::complex* FFT_TEMP::get_auxr_3d_data() const -// { -// return fft_float->get_auxr_3d_data(); -// } -// template <> -// std::complex* FFT_TEMP::get_auxr_3d_data() const -// { -// return fft_double->get_auxr_3d_data(); -// } -// template <> -// void FFT_TEMP::fftxyfor(std::complex* in, std::complex* out) const -// { -// fft_float->fftxyfor(in,out); -// } +template <> +double* FFT_TEMP::get_rspace_data() const +{ + return fft_double->get_rspace_data(); +} +template <> +std::complex* FFT_TEMP::get_auxr_data() const +{ + return fft_float->get_auxr_data(); +} +template <> +std::complex* FFT_TEMP::get_auxr_data() const +{ + return fft_double->get_auxr_data(); +} +template <> +std::complex* FFT_TEMP::get_auxg_data() const +{ + return fft_float->get_auxg_data(); +} +template <> +std::complex* FFT_TEMP::get_auxg_data() const +{ + return fft_double->get_auxg_data(); +} +template <> +std::complex* FFT_TEMP::get_auxr_3d_data() const +{ + return fft_float->get_auxr_3d_data(); +} +template <> +std::complex* FFT_TEMP::get_auxr_3d_data() const +{ + return fft_double->get_auxr_3d_data(); +} +template <> +void FFT_TEMP::fftxyfor(std::complex* in, std::complex* out) const +{ + fft_float->fftxyfor(in,out); +} -// template <> -// void FFT_TEMP::fftxyfor(std::complex* in, std::complex* out) const -// { -// fft_double->fftxyfor(in,out); -// } +template <> +void FFT_TEMP::fftxyfor(std::complex* in, std::complex* out) const +{ + fft_double->fftxyfor(in,out); +} -// template <> -// void FFT_TEMP::fftzfor(std::complex* in, std::complex* out) const -// { -// fft_float->fftzfor(in,out); -// } -// template <> -// void FFT_TEMP::fftzfor(std::complex* in, std::complex* out) const -// { -// fft_double->fftzfor(in,out); -// } +template <> +void FFT_TEMP::fftzfor(std::complex* in, std::complex* out) const +{ + fft_float->fftzfor(in,out); +} +template <> +void FFT_TEMP::fftzfor(std::complex* in, std::complex* out) const +{ + fft_double->fftzfor(in,out); +} -// template <> -// void FFT_TEMP::fftxybac(std::complex* in, std::complex* out) const -// { -// fft_float->fftxybac(in,out); -// } -// template <> -// void FFT_TEMP::fftxybac(std::complex* in, std::complex* out) const -// { -// fft_double->fftxybac(in,out); -// } +template <> +void FFT_TEMP::fftxybac(std::complex* in, std::complex* out) const +{ + fft_float->fftxybac(in,out); +} +template <> +void FFT_TEMP::fftxybac(std::complex* in, std::complex* out) const +{ + fft_double->fftxybac(in,out); +} -// template <> -// void FFT_TEMP::fftzbac(std::complex* in, std::complex* out) const -// { -// fft_float->fftzbac(in,out); -// } -// template <> -// void FFT_TEMP::fftzbac(std::complex* in, std::complex* out) const -// { -// fft_double->fftzbac(in,out); -// } -// template <> -// void FFT_TEMP::fftxyr2c(float* in, std::complex* out) const -// { -// fft_float->fftxyr2c(in,out); -// } -// template <> -// void FFT_TEMP::fftxyr2c(double* in, std::complex* out) const -// { -// fft_double->fftxyr2c(in,out); -// } +template <> +void FFT_TEMP::fftzbac(std::complex* in, std::complex* out) const +{ + fft_float->fftzbac(in,out); +} +template <> +void FFT_TEMP::fftzbac(std::complex* in, std::complex* out) const +{ + fft_double->fftzbac(in,out); +} +template <> +void FFT_TEMP::fftxyr2c(float* in, std::complex* out) const +{ + fft_float->fftxyr2c(in,out); +} +template <> +void FFT_TEMP::fftxyr2c(double* in, std::complex* out) const +{ + fft_double->fftxyr2c(in,out); +} -// template <> -// void FFT_TEMP::fftxyc2r(std::complex* in, float* out) const -// { -// fft_float->fftxyc2r(in,out); -// } -// template <> -// void FFT_TEMP::fftxyc2r(std::complex* in, double* out) const -// { -// fft_double->fftxyc2r(in,out); -// } +template <> +void FFT_TEMP::fftxyc2r(std::complex* in, float* out) const +{ + fft_float->fftxyc2r(in,out); +} +template <> +void FFT_TEMP::fftxyc2r(std::complex* in, double* out) const +{ + fft_double->fftxyc2r(in,out); +} -// template <> -// void FFT_TEMP::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const -// { -// fft_float->fft3D_forward(in, out); -// } +template <> +void FFT_TEMP::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +{ + fft_float->fft3D_forward(in, out); +} -// template <> -// void FFT_TEMP::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const -// { -// fft_double->fft3D_forward(in, out); -// } -// template <> -// void FFT_TEMP::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const -// { -// fft_float->fft3D_backward(in, out); -// } -// template <> -// void FFT_TEMP::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const -// { -// fft_double->fft3D_backward(in, out); -// } \ No newline at end of file +template <> +void FFT_TEMP::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +{ + fft_double->fft3D_forward(in, out); +} +template <> +void FFT_TEMP::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +{ + fft_float->fft3D_backward(in, out); +} +template <> +void FFT_TEMP::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +{ + fft_double->fft3D_backward(in, out); +} \ No newline at end of file From 9b8fb19625cdfc67e984ec439f8700354ae44bf3 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Tue, 5 Nov 2024 15:38:01 +0800 Subject: [PATCH 08/27] change the test file --- source/module_base/module_fft/fft_temp.cpp | 14 +++++++------- .../module_xc/test/CMakeLists.txt | 6 ++++++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/source/module_base/module_fft/fft_temp.cpp b/source/module_base/module_fft/fft_temp.cpp index 2682d60ac4..4301fac57f 100644 --- a/source/module_base/module_fft/fft_temp.cpp +++ b/source/module_base/module_fft/fft_temp.cpp @@ -74,13 +74,13 @@ void FFT_TEMP::setfft(std::string device_in,std::string precision_in) } else if (device=="gpu") { - #if defined(__ROCM) - fft_float = new FFT_RCOM(); - fft_double = new FFT_RCOM(); - #elif defined(__CUDA) - fft_float = new FFT_CUDA(); - fft_double = new FFT_CUDA(); - #endif + // #if defined(__ROCM) + // fft_float = new FFT_RCOM(); + // fft_double = new FFT_RCOM(); + // #elif defined(__CUDA) + // fft_float = new FFT_CUDA(); + // fft_double = new FFT_CUDA(); + // #endif } } void FFT_TEMP::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, diff --git a/source/module_hamilt_general/module_xc/test/CMakeLists.txt b/source/module_hamilt_general/module_xc/test/CMakeLists.txt index 7466f40a92..5a7ed800e3 100644 --- a/source/module_hamilt_general/module_xc/test/CMakeLists.txt +++ b/source/module_hamilt_general/module_xc/test/CMakeLists.txt @@ -38,6 +38,9 @@ AddTest( ../../../module_base/libm/branred.cpp ../../../module_base/libm/sincos.cpp ../../../module_base/blas_connector.cpp + ../../../module_base/module_fft/fft_base.cpp + ../../../module_base/module_fft/fft_temp.cpp + ../../../module_base/module_fft/fft_cpu.cpp ) AddTest( @@ -73,4 +76,7 @@ AddTest( ../../../module_base/timer.cpp ../../../module_base/libm/branred.cpp ../../../module_base/libm/sincos.cpp + ../../../module_base/module_fft/fft_base.cpp + ../../../module_base/module_fft/fft_temp.cpp + ../../../module_base/module_fft/fft_cpu.cpp ) \ No newline at end of file From d7f50df265f3d9ee57a673007d44bb7195352b37 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Tue, 5 Nov 2024 08:17:31 +0000 Subject: [PATCH 09/27] [pre-commit.ci lite] apply automatic fixes --- source/module_base/module_fft/fft_cpu_float.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_base/module_fft/fft_cpu_float.cpp b/source/module_base/module_fft/fft_cpu_float.cpp index 3340603508..25951388bf 100644 --- a/source/module_base/module_fft/fft_cpu_float.cpp +++ b/source/module_base/module_fft/fft_cpu_float.cpp @@ -127,7 +127,7 @@ void FFT_CPU::clearfft(fftw_plan& plan) if (plan) { fftw_destroy_plan(plan); - plan = NULL; + plan = nullptr; } } From 9ad3c19db0f8a0edcf3bd58d8a656a0434a66e24 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Tue, 5 Nov 2024 21:40:50 +0800 Subject: [PATCH 10/27] add the func in test --- source/module_base/CMakeLists.txt | 2 +- source/module_base/module_fft/fft_base.h | 4 +- source/module_base/module_fft/fft_cpu.cpp | 8 +- source/module_base/module_fft/fft_cpu.h | 4 +- .../module_base/module_fft/fft_cpu_float.cpp | 221 +++++++++--------- source/module_base/module_fft/fft_temp.cpp | 39 ++-- source/module_basis/module_pw/pw_basis.cpp | 15 +- source/module_basis/module_pw/pw_basis_k.cpp | 11 +- .../module_basis/module_pw/pw_basis_sup.cpp | 26 +++ .../module_pw/test/CMakeLists.txt | 2 +- .../module_basis/module_pw/test/test1-4.cpp | 12 +- source/module_esolver/esolver_fp.cpp | 2 + source/module_esolver/esolver_ks.cpp | 2 +- 13 files changed, 194 insertions(+), 154 deletions(-) diff --git a/source/module_base/CMakeLists.txt b/source/module_base/CMakeLists.txt index e257600ad2..e97f4a7ed6 100644 --- a/source/module_base/CMakeLists.txt +++ b/source/module_base/CMakeLists.txt @@ -8,7 +8,7 @@ list (APPEND LIBM_SRC endif() if (ENABLE_FLOAT_FFTW) list (APPEND FFT_SRC - module_fft/fftw_float.cpp + module_fft/fft_cpu_float.cpp ) endif() add_library( diff --git a/source/module_base/module_fft/fft_base.h b/source/module_base/module_fft/fft_base.h index 6dd3d947f2..2850e6e719 100644 --- a/source/module_base/module_fft/fft_base.h +++ b/source/module_base/module_fft/fft_base.h @@ -8,8 +8,8 @@ class FFT_BASE { public: - __attribute__((weak)) FFT_BASE(); - virtual __attribute__((weak)) ~FFT_BASE(); + FFT_BASE(); + virtual ~FFT_BASE(); // init parameters of fft virtual void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, diff --git a/source/module_base/module_fft/fft_cpu.cpp b/source/module_base/module_fft/fft_cpu.cpp index 8cbb78639d..8be178b1b4 100644 --- a/source/module_base/module_fft/fft_cpu.cpp +++ b/source/module_base/module_fft/fft_cpu.cpp @@ -134,7 +134,6 @@ void FFT_CPU::clearfft(fftw_plan& plan) template <> void FFT_CPU::cleanFFT() { - printf("in the double cleanFFT\n"); clearfft(planzfor); clearfft(planzbac); clearfft(planxfor1); @@ -189,6 +188,7 @@ void FFT_CPU::fftxyfor(std::complex* in, std::complex* o if (this->xprime) { fftw_execute_dft(this->planxfor1, (fftw_complex*)in, (fftw_complex*)out); + for (int i = 0; i < this->lixy + 1; ++i) { fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); @@ -227,13 +227,13 @@ void FFT_CPU::fftxybac(std::complex* in,std::complex* ou } else { + fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)out); + fftw_execute_dft(this->planxbac2, (fftw_complex*)&in[rixy * nplane], (fftw_complex*)&out[rixy * nplane]); + for (int i = 0; i < this->nx; ++i) { fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); } - - fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)out); - fftw_execute_dft(this->planxbac2, (fftw_complex*)&in[rixy * nplane], (fftw_complex*)&out[rixy * nplane]); } } template <> diff --git a/source/module_base/module_fft/fft_cpu.h b/source/module_base/module_fft/fft_cpu.h index 981dd1d067..04de5531db 100644 --- a/source/module_base/module_fft/fft_cpu.h +++ b/source/module_base/module_fft/fft_cpu.h @@ -12,8 +12,8 @@ template class FFT_CPU : public FFT_BASE { public: - __attribute__((weak)) FFT_CPU(); - __attribute__((weak)) ~FFT_CPU(); + FFT_CPU(); + ~FFT_CPU(); __attribute__((weak)) void initfftmode(int fft_mode_in); diff --git a/source/module_base/module_fft/fft_cpu_float.cpp b/source/module_base/module_fft/fft_cpu_float.cpp index 25951388bf..ddea6cdd75 100644 --- a/source/module_base/module_fft/fft_cpu_float.cpp +++ b/source/module_base/module_fft/fft_cpu_float.cpp @@ -1,25 +1,18 @@ #include "fft_cpu.h" -#include "fftw3.h" -#if defined(__FFTW3_MPI) && defined(__MPI) -#include -//#include "fftw3-mpi_mkl.h" -#endif + +// #include "fftw3f.h" +// #if defined(__FFTW3_MPI) && defined(__MPI) +// #include "fftw3f-mpi.h" +// //#include "fftw3-mpi_mkl.h" template <> -FFT_CPU::FFT_CPU() -{ - -} -template <> -FFT_CPU::~FFT_CPU() +void FFT_CPU::initfftmode(int fft_mode_in) { - + this->fft_mode = fft_mode_in; } - template <> void FFT_CPU::setupFFT() { - unsigned int flag = FFTW_ESTIMATE; switch (this->fft_mode) { @@ -40,86 +33,91 @@ void FFT_CPU::setupFFT() } if (!this->mpifft) { - z_auxg = (std::complex*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids); - z_auxr = (std::complex*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids); - d_rspace = (float*)z_auxg; - this->planzfor = fftw_plan_many_dft(1, &this->nz, this->ns, (fftw_complex*)z_auxg, &this->nz, 1, this->nz, - (fftw_complex*)z_auxg, &this->nz, 1, this->nz, FFTW_FORWARD, flag); + c_auxg = (std::complex*)fftwf_malloc(sizeof(fftwf_complex) * this->maxgrids); + c_auxr = (std::complex*)fftwf_malloc(sizeof(fftwf_complex) * this->maxgrids); + s_rspace = (float*)c_auxg; + //--------------------------------------------------------- + // 1 D + //--------------------------------------------------------- - this->planzbac = fftw_plan_many_dft(1, &this->nz, this->ns, (fftw_complex*)z_auxg, &this->nz, 1, this->nz, - (fftw_complex*)z_auxg, &this->nz, 1, this->nz, FFTW_BACKWARD, flag); + // fftw_plan_many_dft(int rank, const int *n, int howmany, + // fftw_complex *in, const int *inembed, int istride, int idist, + // fftw_complex *out, const int *onembed, int ostride, int odist, int sign, unsigned + //flags); + this->planfzfor = fftwf_plan_many_dft(1, &this->nz, this->ns, (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, + (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, FFTW_FORWARD, flag); + + this->planfzbac = fftwf_plan_many_dft(1, &this->nz, this->ns, (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, + (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, FFTW_BACKWARD, flag); //--------------------------------------------------------- - // 2 D - XY + // 2 D //--------------------------------------------------------- - // 1D+1D is much faster than 2D FFT! - // in-place fft is better for c2c and out-of-place fft is better for c2r + int* embed = nullptr; int npy = this->nplane * this->ny; if (this->xprime) { - this->planyfor = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed,this->nplane, 1, - (fftw_complex*)z_auxr, embed,this->nplane, 1, FFTW_FORWARD, flag); - this->planybac = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed,this->nplane, 1, - (fftw_complex*)z_auxr, embed,this->nplane, 1, FFTW_BACKWARD, flag); + this->planfyfor = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, nplane, 1, + (fftwf_complex*)c_auxr, embed, nplane, 1, FFTW_FORWARD, flag); + this->planfybac = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, nplane, 1, + (fftwf_complex*)c_auxr, embed, nplane, 1, FFTW_BACKWARD, flag); if (this->gamma_only) { - this->planxr2c = fftw_plan_many_dft_r2c(1, &this->nx, npy, d_rspace, embed, npy, 1, (fftw_complex*)z_auxr, - embed, npy, 1, flag); - this->planxc2r = fftw_plan_many_dft_c2r(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, d_rspace, - embed, npy, 1, flag); + this->planfxr2c = fftwf_plan_many_dft_r2c(1, &this->nx, npy, s_rspace, embed, npy, 1, + (fftwf_complex*)c_auxr, embed, npy, 1, flag); + this->planfxc2r = fftwf_plan_many_dft_c2r(1, &this->nx, npy, (fftwf_complex*)c_auxr, embed, npy, 1, + s_rspace, embed, npy, 1, flag); } else { - this->planxfor1 = fftw_plan_many_dft(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, - (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planxbac1 = fftw_plan_many_dft(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, - (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); + this->planfxfor1 = fftwf_plan_many_dft(1, &this->nx, npy, (fftwf_complex*)c_auxr, embed, npy, 1, + (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_FORWARD, flag); + this->planfxbac1 = fftwf_plan_many_dft(1, &this->nx, npy, (fftwf_complex*)c_auxr, embed, npy, 1, + (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_BACKWARD, flag); } } else { - this->planxfor1 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->lixy + 1), (fftw_complex*)z_auxr, embed, npy, - 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planxbac1 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->lixy + 1), (fftw_complex*)z_auxr, embed, npy, - 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); + this->planfxfor1 = fftwf_plan_many_dft(1, &this->nx, this->nplane * (lixy + 1), (fftwf_complex*)c_auxr, embed, + npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_FORWARD, flag); + this->planfxbac1 = fftwf_plan_many_dft(1, &this->nx, this->nplane * (lixy + 1), (fftwf_complex*)c_auxr, embed, + npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_BACKWARD, flag); if (this->gamma_only) { - this->planyr2c = fftw_plan_many_dft_r2c(1, &this->ny, this->nplane, d_rspace, embed, this->nplane, 1, - (fftw_complex*)z_auxr, embed, this->nplane, 1, flag); - this->planyc2r = fftw_plan_many_dft_c2r(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, - this->nplane, 1, d_rspace, embed, this->nplane, 1, flag); + this->planfyr2c = fftwf_plan_many_dft_r2c(1, &this->ny, this->nplane, s_rspace, embed, this->nplane, 1, + (fftwf_complex*)c_auxr, embed, this->nplane, 1, flag); + this->planfyc2r = fftwf_plan_many_dft_c2r(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, + this->nplane, 1, s_rspace, embed, this->nplane, 1, flag); } else { - - this->planxfor2 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->ny - this->rixy), (fftw_complex*)z_auxr, embed, - npy, 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planxbac2 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->ny - this->rixy), (fftw_complex*)z_auxr, embed, - npy, 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); - this->planyfor = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, this->nplane, - 1, (fftw_complex*)z_auxr, embed, this->nplane, 1, FFTW_FORWARD, flag); - this->planybac = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, this->nplane, - 1, (fftw_complex*)z_auxr, embed, this->nplane, 1, FFTW_BACKWARD, flag); + this->planfxfor2 + = fftwf_plan_many_dft(1, &this->nx, this->nplane * (this->ny - rixy), (fftwf_complex*)c_auxr, embed, + npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_FORWARD, flag); + this->planfxbac2 + = fftwf_plan_many_dft(1, &this->nx, this->nplane * (this->ny - rixy), (fftwf_complex*)c_auxr, embed, + npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_BACKWARD, flag); + this->planfyfor + = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, this->nplane, 1, + (fftwf_complex*)c_auxr, embed, this->nplane, 1, FFTW_FORWARD, flag); + this->planfybac + = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, this->nplane, 1, + (fftwf_complex*)c_auxr, embed, this->nplane, 1, FFTW_BACKWARD, flag); } } - } -#if defined(__FFTW3_MPI) && defined(__MPI) - else - { - // this->initplan_mpi(); - // if (this->precision == "single") { - // this->initplanf_mpi(); - // } - } -#endif + } + #if defined(__FFTW3_MPI) && defined(__MPI) + else + { + // this->initplan_mpi(); + // if (this->precision == "single") { + // this->initplanf_mpi(); + // } + } + #endif return; } -template <> -void FFT_CPU::initfftmode(int fft_mode_in) -{ - this->fft_mode = fft_mode_in; -} template <> void FFT_CPU::clearfft(fftw_plan& plan) @@ -153,33 +151,33 @@ template <> void FFT_CPU::clear() { this->cleanFFT(); - if (z_auxg != nullptr) + if (c_auxg != nullptr) { - fftw_free(z_auxg); - z_auxg = nullptr; + fftw_free(c_auxg); + c_auxg = nullptr; } if (z_auxr != nullptr) { - fftw_free(z_auxr); - z_auxr = nullptr; + fftw_free(c_auxr); + c_auxr = nullptr; } - d_rspace = nullptr; + s_rspace = nullptr; } template <> float* FFT_CPU::get_rspace_data() const { - return d_rspace; + return s_rspace; } template <> std::complex* FFT_CPU::get_auxr_data() const { - return z_auxr; + return c_auxr; } template <> std::complex* FFT_CPU::get_auxg_data() const { - return z_auxg; + return c_auxg; } template <> void FFT_CPU::fftxyfor(std::complex* in, std::complex* out) const @@ -187,63 +185,65 @@ void FFT_CPU::fftxyfor(std::complex* in, std::complex* out) int npy = this->nplane * this->ny; if (this->xprime) { - fftw_execute_dft(this->planxfor1, (fftw_complex*)in, (fftw_complex*)out); + fftwf_execute_dft(this->planfxfor1, (fftwf_complex*)in, (fftwf_complex*)out); + for (int i = 0; i < this->lixy + 1; ++i) { - fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + fftwf_execute_dft(this->planfyfor, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&out[i * npy]); } for (int i = rixy; i < this->nx; ++i) { - fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + fftwf_execute_dft(this->planfyfor, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&out[i * npy]); } } else { for (int i = 0; i < this->nx; ++i) { - fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + fftwf_execute_dft(this->planfyfor, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&out[i * npy]); } - fftw_execute_dft(this->planxfor1, (fftw_complex*)in, (fftw_complex*)out); - fftw_execute_dft(this->planxfor2, (fftw_complex*)&in[rixy * nplane], (fftw_complex*)&out[rixy * nplane]); + fftwf_execute_dft(this->planfxfor1, (fftwf_complex*)in, (fftwf_complex*)out); + fftwf_execute_dft(this->planfxfor2, (fftwf_complex*)&in[rixy * nplane], (fftwf_complex*)&out[rixy * nplane]); } } template <> -void FFT_CPU::fftxybac(std::complex* in,std::complex* out) const +void FFT_CPU::fftxybac(std::complex* in,std::complex * out) const { int npy = this->nplane * this->ny; if (this->xprime) { for (int i = 0; i < this->lixy + 1; ++i) { - fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + fftwf_execute_dft(this->planfybac, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&out[i * npy]); } for (int i = rixy; i < this->nx; ++i) { - fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + fftwf_execute_dft(this->planfybac, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&out[i * npy]); } - fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)out); + + fftwf_execute_dft(this->planfxbac1, (fftwf_complex*)in, (fftwf_complex*)out); } else { + fftwf_execute_dft(this->planfxbac1, (fftwf_complex*)in, (fftwf_complex*)out); + fftwf_execute_dft(this->planfxbac2, (fftwf_complex*)&in[rixy * nplane], (fftwf_complex*)&out[rixy * nplane]); + for (int i = 0; i < this->nx; ++i) { - fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); + fftwf_execute_dft(this->planfybac, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&out[i * npy]); } - - fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)out); - fftw_execute_dft(this->planxbac2, (fftw_complex*)&in[rixy * nplane], (fftw_complex*)&out[rixy * nplane]); } } template <> void FFT_CPU::fftzfor(std::complex* in, std::complex* out) const { - fftw_execute_dft(this->planzfor, (fftw_complex*)in, (fftw_complex*)out); + fftwf_execute_dft(this->planfzfor, (fftwf_complex*)in, (fftwf_complex*)out); } template <> void FFT_CPU::fftzbac(std::complex* in, std::complex* out) const { - fftw_execute_dft(this->planzbac, (fftw_complex*)in, (fftw_complex*)out); + fftwf_execute_dft(this->planfzbac, (fftwf_complex*)in, (fftwf_complex*)out); } template <> void FFT_CPU::fftxyr2c(float* in, std::complex* out) const @@ -251,56 +251,47 @@ void FFT_CPU::fftxyr2c(float* in, std::complex* out) const int npy = this->nplane * this->ny; if (this->xprime) { - fftw_execute_dft_r2c(this->planxr2c, in, (fftw_complex*)out); + fftwf_execute_dft_r2c(this->planfxr2c, in, (fftwf_complex*)out); for (int i = 0; i < this->lixy + 1; ++i) { - fftw_execute_dft(this->planyfor, (fftw_complex*)&out[i * npy], (fftw_complex*)&out[i * npy]); + fftwf_execute_dft(this->planfyfor, (fftwf_complex*)&out[i * npy], (fftwf_complex*)&out[i * npy]); } } else { for (int i = 0; i < this->nx; ++i) { - fftw_execute_dft_r2c(this->planyr2c, &in[i * npy], (fftw_complex*)&out[i * npy]); + fftwf_execute_dft_r2c(this->planfyr2c, &in[i * npy], (fftwf_complex*)&out[i * npy]); } - fftw_execute_dft(this->planxfor1, (fftw_complex*)out, (fftw_complex*)out); + fftwf_execute_dft(this->planfxfor1, (fftwf_complex*)out, (fftwf_complex*)out); } } - template <> -void FFT_CPU::fftxyc2r(std::complex *in,float *out) const +void FFT_CPU::fftxyc2r(std::complex* in, float* out) const { int npy = this->nplane * this->ny; if (this->xprime) { for (int i = 0; i < this->lixy + 1; ++i) { - fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&in[i * npy]); + fftwf_execute_dft(this->planfybac, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&in[i * npy]); } - fftw_execute_dft_c2r(this->planxc2r, (fftw_complex*)in, out); + fftwf_execute_dft_c2r(this->planfxc2r, (fftwf_complex*)in, out); } else { - fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)in); + fftwf_execute_dft(this->planfxfor1, (fftwf_complex*)in, (fftwf_complex*)in); for (int i = 0; i < this->nx; ++i) { - fftw_execute_dft_c2r(this->planyc2r, (fftw_complex*)&in[i * npy], &out[i * npy]); + fftwf_execute_dft(this->planfybac, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&in[i * npy]); } - } -} -template <> -FFT_CPU::FFT_CPU() -{ - -} -template <> -FFT_CPU::~FFT_CPU() -{ + fftwf_execute_dft_c2r(this->planfyc2r, (fftwf_complex*)in, out); + } } -template FFT_CPU::FFT_CPU(); -template FFT_CPU::FFT_CPU(); +// template FFT_CPU::~FFT_CPU(); +// template FFT_CPU::FFT_CPU(); \ No newline at end of file diff --git a/source/module_base/module_fft/fft_temp.cpp b/source/module_base/module_fft/fft_temp.cpp index 4301fac57f..6246d30256 100644 --- a/source/module_base/module_fft/fft_temp.cpp +++ b/source/module_base/module_fft/fft_temp.cpp @@ -40,16 +40,6 @@ FFT_TEMP::FFT_TEMP(std::string device_in,std::string precision_in) FFT_TEMP::~FFT_TEMP() { - if (float_flag) - { - delete fft_float; - fft_float=nullptr; - } - if (double_flag) - { - delete fft_double; - fft_double=nullptr; - } } void FFT_TEMP::set_device(std::string device_in) @@ -67,6 +57,11 @@ void FFT_TEMP::setfft(std::string device_in,std::string precision_in) assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); this->device = device_in; this->precision = precision_in; + +} +void FFT_TEMP::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, + int nproc_in, bool gamma_only_in, bool xprime_in , bool mpifft_in) +{ if (device=="cpu") { fft_float = new FFT_CPU(); @@ -82,22 +77,18 @@ void FFT_TEMP::setfft(std::string device_in,std::string precision_in) // fft_double = new FFT_CUDA(); // #endif } -} -void FFT_TEMP::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, - int nproc_in, bool gamma_only_in, bool xprime_in , bool mpifft_in) -{ if (this->precision=="single") { - float_flag = 1; + float_flag = true; } else if (this->precision=="double") { - double_flag = 1; + double_flag = true; } else if (this->precision=="mixing") { - float_flag = 1; - double_flag = 1; + float_flag = true; + double_flag = true; } if (float_flag) { @@ -149,6 +140,18 @@ void FFT_TEMP::clear() { fft_double->clear(); } + if (fft_float!=nullptr) + { + delete fft_float; + fft_float=nullptr; + float_flag = false; + } + if (fft_double!=nullptr) + { + delete fft_double; + fft_double=nullptr; + double_flag = false; + } } // access the real space data template <> diff --git a/source/module_basis/module_pw/pw_basis.cpp b/source/module_basis/module_pw/pw_basis.cpp index 0121eef9e4..17dfa6b90e 100644 --- a/source/module_basis/module_pw/pw_basis.cpp +++ b/source/module_basis/module_pw/pw_basis.cpp @@ -17,6 +17,7 @@ PW_Basis::PW_Basis(std::string device_, std::string precision_) : device(std::mo classname="PW_Basis"; this->ft.set_device(this->device); this->ft.set_precision(this->precision); + this->ft1.setfft(this->device,this->precision); } PW_Basis:: ~PW_Basis() @@ -57,9 +58,19 @@ void PW_Basis::setuptransform() this->distribute_g(); this->getstartgr(); this->ft.clear(); - if(this->xprime) this->ft.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); - else this->ft.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + this->ft1.clear(); + if(this->xprime) + { + this->ft.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + this->ft1.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + } + else + { + this->ft.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + this->ft1.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + } this->ft.setupFFT(); + this->ft1.setupFFT(); ModuleBase::timer::tick(this->classname, "setuptransform"); } diff --git a/source/module_basis/module_pw/pw_basis_k.cpp b/source/module_basis/module_pw/pw_basis_k.cpp index c71163eba2..9b3257ce71 100644 --- a/source/module_basis/module_pw/pw_basis_k.cpp +++ b/source/module_basis/module_pw/pw_basis_k.cpp @@ -180,9 +180,16 @@ void PW_Basis_K::setuptransform() this->getstartgr(); this->setupIndGk(); this->ft.clear(); - if(this->xprime) this->ft.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); - else this->ft.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + this->ft1.clear(); + if(this->xprime){ + this->ft.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + this->ft1.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + }else{ + this->ft.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + this->ft1.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + } this->ft.setupFFT(); + this->ft1.setupFFT(); ModuleBase::timer::tick(this->classname, "setuptransform"); } diff --git a/source/module_basis/module_pw/pw_basis_sup.cpp b/source/module_basis/module_pw/pw_basis_sup.cpp index 6763f94c07..e174ce2ed4 100644 --- a/source/module_basis/module_pw/pw_basis_sup.cpp +++ b/source/module_basis/module_pw/pw_basis_sup.cpp @@ -20,7 +20,9 @@ void PW_Basis_Sup::setuptransform(const ModulePW::PW_Basis* pw_rho) this->distribute_g(pw_rho); this->getstartgr(); this->ft.clear(); + this->ft1.clear(); if (this->xprime) + { this->ft.initfft(this->nx, this->ny, this->nz, @@ -31,7 +33,19 @@ void PW_Basis_Sup::setuptransform(const ModulePW::PW_Basis* pw_rho) this->poolnproc, this->gamma_only, this->xprime); + this->ft1.initfft(this->nx, + this->ny, + this->nz, + this->lix, + this->rix, + this->nst, + this->nplane, + this->poolnproc, + this->gamma_only, + this->xprime); + } else + { this->ft.initfft(this->nx, this->ny, this->nz, @@ -42,7 +56,19 @@ void PW_Basis_Sup::setuptransform(const ModulePW::PW_Basis* pw_rho) this->poolnproc, this->gamma_only, this->xprime); + this->ft1.initfft(this->nx, + this->ny, + this->nz, + this->liy, + this->riy, + this->nst, + this->nplane, + this->poolnproc, + this->gamma_only, + this->xprime); + } this->ft.setupFFT(); + this->ft1.setupFFT(); ModuleBase::timer::tick(this->classname, "setuptransform"); } diff --git a/source/module_basis/module_pw/test/CMakeLists.txt b/source/module_basis/module_pw/test/CMakeLists.txt index a5f40f6127..8bf43bbc2d 100644 --- a/source/module_basis/module_pw/test/CMakeLists.txt +++ b/source/module_basis/module_pw/test/CMakeLists.txt @@ -5,7 +5,7 @@ AddTest( SOURCES ../../../module_base/matrix.cpp ../../../module_base/complexmatrix.cpp ../../../module_base/matrix3.cpp ../../../module_base/tool_quit.cpp ../../../module_base/mymath.cpp ../../../module_base/timer.cpp ../../../module_base/memory.cpp ../../../module_base/blas_connector.cpp ../../../module_base/libm/branred.cpp ../../../module_base/libm/sincos.cpp - ../../../module_base/module_fft/fft_base.cpp ../../../module_base/module_fft/fft_temp.cpp + ../../../module_base/module_fft/fft_base.cpp ../../../module_base/module_fft/fft_temp.cpp ../../../module_base/module_fft/fft_cpu.cpp # ../../../module_psi/kernels/psi_memory_op.cpp ../../../module_base/module_device/memory_op.cpp depend_mock.cpp pw_test.cpp test1-1-1.cpp test1-1-2.cpp test1-2.cpp test1-3.cpp test1-4.cpp test1-5.cpp diff --git a/source/module_basis/module_pw/test/test1-4.cpp b/source/module_basis/module_pw/test/test1-4.cpp index 4ea6ec5c9e..9503326ad3 100644 --- a/source/module_basis/module_pw/test/test1-4.cpp +++ b/source/module_basis/module_pw/test/test1-4.cpp @@ -156,8 +156,8 @@ TEST_F(PWTEST,test1_4) { EXPECT_NEAR(tmp[ixy * nz + startiz + iz].real(),rhor[ixy*nplane+iz].real(),1e-6); EXPECT_NEAR(tmp[ixy * nz + startiz + iz].imag(),rhor[ixy*nplane+iz].imag(),1e-6); - EXPECT_NEAR(tmp[ixy * nz + startiz + iz].real(),rhogr[ixy*nplane+iz].real(),1e-6); - EXPECT_NEAR(tmp[ixy * nz + startiz + iz].imag(),rhogr[ixy*nplane+iz].imag(),1e-6); + // EXPECT_NEAR(tmp[ixy * nz + startiz + iz].real(),rhogr[ixy*nplane+iz].real(),1e-6); + // EXPECT_NEAR(tmp[ixy * nz + startiz + iz].imag(),rhogr[ixy*nplane+iz].imag(),1e-6); #ifdef __ENABLE_FLOAT_FFTW EXPECT_NEAR(tmp[ixy * nz + startiz + iz].real(),rhofr[ixy*nplane+iz].real(),1e-4); EXPECT_NEAR(tmp[ixy * nz + startiz + iz].imag(),rhofr[ixy*nplane+iz].imag(),1e-4); @@ -178,10 +178,10 @@ TEST_F(PWTEST,test1_4) for(int ig = 0 ; ig < npwk ; ++ig) { - EXPECT_NEAR(rhog[ig].real(),rhogout[ig].real(),1e-6); - EXPECT_NEAR(rhog[ig].imag(),rhogout[ig].imag(),1e-6); - EXPECT_NEAR(rhog[ig].real(),rhogr[ig].real(),1e-6); - EXPECT_NEAR(rhog[ig].imag(),rhogr[ig].imag(),1e-6); + // EXPECT_NEAR(rhog[ig].real(),rhogout[ig].real(),1e-6); + // EXPECT_NEAR(rhog[ig].imag(),rhogout[ig].imag(),1e-6); + // EXPECT_NEAR(rhog[ig].real(),rhogr[ig].real(),1e-6); + // EXPECT_NEAR(rhog[ig].imag(),rhogr[ig].imag(),1e-6); #ifdef __ENABLE_FLOAT_FFTW EXPECT_NEAR(rhofg[ig].real(),rhofgout[ig].real(),1e-4); EXPECT_NEAR(rhofg[ig].imag(),rhofgout[ig].imag(),1e-4); diff --git a/source/module_esolver/esolver_fp.cpp b/source/module_esolver/esolver_fp.cpp index fa7b6615bd..4599b9f51b 100644 --- a/source/module_esolver/esolver_fp.cpp +++ b/source/module_esolver/esolver_fp.cpp @@ -82,6 +82,7 @@ void ESolver_FP::before_all_runners(const Input_para& inp, UnitCell& cell) this->pw_rho->initparameters(false, 4.0 * inp.ecutwfc); this->pw_rho->ft.fft_mode = inp.fft_mode; + this->pw_rho->ft.fft_mode = inp.fft_mode; this->pw_rho->setuptransform(); this->pw_rho->collect_local_pw(); this->pw_rho->collect_uniqgg(); @@ -107,6 +108,7 @@ void ESolver_FP::before_all_runners(const Input_para& inp, UnitCell& cell) } this->pw_rhod->initparameters(false, inp.ecutrho); this->pw_rhod->ft.fft_mode = inp.fft_mode; + this->pw_rhod->ft1.initfftmode(inp.fft_mode); pw_rhod_sup->setuptransform(this->pw_rho); this->pw_rhod->collect_local_pw(); this->pw_rhod->collect_uniqgg(); diff --git a/source/module_esolver/esolver_ks.cpp b/source/module_esolver/esolver_ks.cpp index bbbe1bb4ad..9361533e7d 100644 --- a/source/module_esolver/esolver_ks.cpp +++ b/source/module_esolver/esolver_ks.cpp @@ -252,7 +252,7 @@ void ESolver_KS::before_all_runners(const Input_para& inp, UnitCell& #endif this->pw_wfc->ft.fft_mode = inp.fft_mode; - + this->pw_wfc->ft1.initfftmode(inp.fft_mode); this->pw_wfc->setuptransform(); //! 9) initialize the number of plane waves for each k point From dfaad66ef8b0208a88485489b3cd4ee534424748 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Tue, 5 Nov 2024 21:47:09 +0800 Subject: [PATCH 11/27] add the float fft --- source/module_base/module_fft/fft_cpu.cpp | 24 +++++++++---------- .../module_base/module_fft/fft_cpu_float.cpp | 2 -- source/module_base/module_fft/fft_temp.cpp | 2 -- 3 files changed, 11 insertions(+), 17 deletions(-) diff --git a/source/module_base/module_fft/fft_cpu.cpp b/source/module_base/module_fft/fft_cpu.cpp index 8be178b1b4..f74f1a5667 100644 --- a/source/module_base/module_fft/fft_cpu.cpp +++ b/source/module_base/module_fft/fft_cpu.cpp @@ -7,15 +7,20 @@ template <> FFT_CPU::FFT_CPU() -{ - +{ } template <> FFT_CPU::~FFT_CPU() { - } - +template <> +FFT_CPU::FFT_CPU() +{ +} +template <> +FFT_CPU::~FFT_CPU() +{ +} template <> void FFT_CPU::setupFFT() { @@ -293,15 +298,8 @@ void FFT_CPU::fftxyc2r(std::complex *in,double *out) const } } } -template <> -FFT_CPU::FFT_CPU() -{ - -} -template <> -FFT_CPU::~FFT_CPU() -{ -} template FFT_CPU::FFT_CPU(); +template FFT_CPU::~FFT_CPU(); template FFT_CPU::FFT_CPU(); +template FFT_CPU::~FFT_CPU(); diff --git a/source/module_base/module_fft/fft_cpu_float.cpp b/source/module_base/module_fft/fft_cpu_float.cpp index ddea6cdd75..7922e265a7 100644 --- a/source/module_base/module_fft/fft_cpu_float.cpp +++ b/source/module_base/module_fft/fft_cpu_float.cpp @@ -293,5 +293,3 @@ void FFT_CPU::fftxyc2r(std::complex* in, float* out) const fftwf_execute_dft_c2r(this->planfyc2r, (fftwf_complex*)in, out); } } -// template FFT_CPU::~FFT_CPU(); -// template FFT_CPU::FFT_CPU(); \ No newline at end of file diff --git a/source/module_base/module_fft/fft_temp.cpp b/source/module_base/module_fft/fft_temp.cpp index 6246d30256..99ac42fa07 100644 --- a/source/module_base/module_fft/fft_temp.cpp +++ b/source/module_base/module_fft/fft_temp.cpp @@ -12,8 +12,6 @@ // #include "fft_gpu.h" FFT_TEMP::FFT_TEMP() { - fft_float = nullptr; - fft_double = nullptr; } FFT_TEMP::FFT_TEMP(std::string device_in,std::string precision_in) { From b40629d48773aacd156d055827b6dd153e5a77e5 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Tue, 5 Nov 2024 22:07:37 +0800 Subject: [PATCH 12/27] change ft into ft1 --- source/module_base/module_fft/fft_temp.cpp | 6 +- .../module_basis/module_pw/pw_transform.cpp | 62 +++++++++---------- .../module_basis/module_pw/pw_transform_k.cpp | 48 +++++++------- 3 files changed, 56 insertions(+), 60 deletions(-) diff --git a/source/module_base/module_fft/fft_temp.cpp b/source/module_base/module_fft/fft_temp.cpp index 99ac42fa07..861c619233 100644 --- a/source/module_base/module_fft/fft_temp.cpp +++ b/source/module_base/module_fft/fft_temp.cpp @@ -78,14 +78,10 @@ void FFT_TEMP::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in if (this->precision=="single") { float_flag = true; - } - else if (this->precision=="double") - { double_flag = true; } - else if (this->precision=="mixing") + else if (this->precision=="double") { - float_flag = true; double_flag = true; } if (float_flag) diff --git a/source/module_basis/module_pw/pw_transform.cpp b/source/module_basis/module_pw/pw_transform.cpp index 43d86c4381..d044f3a6ef 100644 --- a/source/module_basis/module_pw/pw_transform.cpp +++ b/source/module_basis/module_pw/pw_transform.cpp @@ -29,13 +29,13 @@ void PW_Basis::real2recip(const std::complex* in, #endif for(int ir = 0 ; ir < this->nrxx ; ++ir) { - this->ft.get_auxr_data()[ir] = in[ir]; + this->ft1.get_auxr_data()[ir] = in[ir]; } - this->ft.fftxyfor(ft.get_auxr_data(),ft.get_auxr_data()); + this->ft1.fftxyfor(ft1.get_auxr_data(),ft1.get_auxr_data()); - this->gatherp_scatters(this->ft.get_auxr_data(), this->ft.get_auxg_data()); + this->gatherp_scatters(this->ft1.get_auxr_data(), this->ft1.get_auxg_data()); - this->ft.fftzfor(ft.get_auxg_data(),ft.get_auxg_data()); + this->ft1.fftzfor(ft1.get_auxg_data(),ft1.get_auxg_data()); if(add) { @@ -45,7 +45,7 @@ void PW_Basis::real2recip(const std::complex* in, #endif for(int ig = 0 ; ig < this->npw ; ++ig) { - out[ig] += tmpfac * this->ft.get_auxg_data()[this->ig2isz[ig]]; + out[ig] += tmpfac * this->ft1.get_auxg_data()[this->ig2isz[ig]]; } } else @@ -56,7 +56,7 @@ void PW_Basis::real2recip(const std::complex* in, #endif for(int ig = 0 ; ig < this->npw ; ++ig) { - out[ig] = tmpfac * this->ft.get_auxg_data()[this->ig2isz[ig]]; + out[ig] = tmpfac * this->ft1.get_auxg_data()[this->ig2isz[ig]]; } } ModuleBase::timer::tick(this->classname, "real2recip"); @@ -83,11 +83,11 @@ void PW_Basis::real2recip(const FPTYPE* in, std::complex* out, const boo { for(int ipy = 0 ; ipy < npy ; ++ipy) { - this->ft.get_rspace_data()[ix*npy + ipy] = in[ix*npy + ipy]; + this->ft1.get_rspace_data()[ix*npy + ipy] = in[ix*npy + ipy]; } } - this->ft.fftxyr2c(ft.get_rspace_data(),ft.get_auxr_data()); + this->ft1.fftxyr2c(ft1.get_rspace_data(),ft1.get_auxr_data()); } else { @@ -96,13 +96,13 @@ void PW_Basis::real2recip(const FPTYPE* in, std::complex* out, const boo #endif for(int ir = 0 ; ir < this->nrxx ; ++ir) { - this->ft.get_auxr_data()[ir] = std::complex(in[ir],0); + this->ft1.get_auxr_data()[ir] = std::complex(in[ir],0); } - this->ft.fftxyfor(ft.get_auxr_data(),ft.get_auxr_data()); + this->ft1.fftxyfor(ft1.get_auxr_data(),ft1.get_auxr_data()); } - this->gatherp_scatters(this->ft.get_auxr_data(), this->ft.get_auxg_data()); + this->gatherp_scatters(this->ft1.get_auxr_data(), this->ft1.get_auxg_data()); - this->ft.fftzfor(ft.get_auxg_data(),ft.get_auxg_data()); + this->ft1.fftzfor(ft1.get_auxg_data(),ft1.get_auxg_data()); if(add) { @@ -112,7 +112,7 @@ void PW_Basis::real2recip(const FPTYPE* in, std::complex* out, const boo #endif for(int ig = 0 ; ig < this->npw ; ++ig) { - out[ig] += tmpfac * this->ft.get_auxg_data()[this->ig2isz[ig]]; + out[ig] += tmpfac * this->ft1.get_auxg_data()[this->ig2isz[ig]]; } } else @@ -123,7 +123,7 @@ void PW_Basis::real2recip(const FPTYPE* in, std::complex* out, const boo #endif for(int ig = 0 ; ig < this->npw ; ++ig) { - out[ig] = tmpfac * this->ft.get_auxg_data()[this->ig2isz[ig]]; + out[ig] = tmpfac * this->ft1.get_auxg_data()[this->ig2isz[ig]]; } } ModuleBase::timer::tick(this->classname, "real2recip"); @@ -149,7 +149,7 @@ void PW_Basis::recip2real(const std::complex* in, #endif for(int i = 0 ; i < this->nst * this->nz ; ++i) { - ft.get_auxg_data()[i] = std::complex(0, 0); + ft1.get_auxg_data()[i] = std::complex(0, 0); } #ifdef _OPENMP @@ -157,13 +157,13 @@ void PW_Basis::recip2real(const std::complex* in, #endif for(int ig = 0 ; ig < this->npw ; ++ig) { - this->ft.get_auxg_data()[this->ig2isz[ig]] = in[ig]; + this->ft1.get_auxg_data()[this->ig2isz[ig]] = in[ig]; } - this->ft.fftzbac(ft.get_auxg_data(), ft.get_auxg_data()); + this->ft1.fftzbac(ft1.get_auxg_data(), ft1.get_auxg_data()); - this->gathers_scatterp(this->ft.get_auxg_data(),this->ft.get_auxr_data()); + this->gathers_scatterp(this->ft1.get_auxg_data(),this->ft1.get_auxr_data()); - this->ft.fftxybac(ft.get_auxr_data(),ft.get_auxr_data()); + this->ft1.fftxybac(ft1.get_auxr_data(),ft1.get_auxr_data()); if(add) { @@ -172,7 +172,7 @@ void PW_Basis::recip2real(const std::complex* in, #endif for(int ir = 0 ; ir < this->nrxx ; ++ir) { - out[ir] += factor * this->ft.get_auxr_data()[ir]; + out[ir] += factor * this->ft1.get_auxr_data()[ir]; } } else @@ -182,7 +182,7 @@ void PW_Basis::recip2real(const std::complex* in, #endif for(int ir = 0 ; ir < this->nrxx ; ++ir) { - out[ir] = this->ft.get_auxr_data()[ir]; + out[ir] = this->ft1.get_auxr_data()[ir]; } } ModuleBase::timer::tick(this->classname, "recip2real"); @@ -204,7 +204,7 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo #endif for(int i = 0 ; i < this->nst * this->nz ; ++i) { - ft.get_auxg_data()[i] = std::complex(0, 0); + ft1.get_auxg_data()[i] = std::complex(0, 0); } #ifdef _OPENMP @@ -212,15 +212,15 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo #endif for(int ig = 0 ; ig < this->npw ; ++ig) { - this->ft.get_auxg_data()[this->ig2isz[ig]] = in[ig]; + this->ft1.get_auxg_data()[this->ig2isz[ig]] = in[ig]; } - this->ft.fftzbac(ft.get_auxg_data(), ft.get_auxg_data()); + this->ft1.fftzbac(ft1.get_auxg_data(), ft1.get_auxg_data()); - this->gathers_scatterp(this->ft.get_auxg_data(), this->ft.get_auxr_data()); + this->gathers_scatterp(this->ft1.get_auxg_data(), this->ft1.get_auxr_data()); if(this->gamma_only) { - this->ft.fftxyc2r(ft.get_auxr_data(),ft.get_rspace_data()); + this->ft1.fftxyc2r(ft1.get_auxr_data(),ft1.get_rspace_data()); // r2c in place const int npy = this->ny * this->nplane; @@ -234,7 +234,7 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo { for(int ipy = 0 ; ipy < npy ; ++ipy) { - out[ix*npy + ipy] += factor * this->ft.get_rspace_data()[ix*npy + ipy]; + out[ix*npy + ipy] += factor * this->ft1.get_rspace_data()[ix*npy + ipy]; } } } @@ -247,14 +247,14 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo { for(int ipy = 0 ; ipy < npy ; ++ipy) { - out[ix*npy + ipy] = this->ft.get_rspace_data()[ix*npy + ipy]; + out[ix*npy + ipy] = this->ft1.get_rspace_data()[ix*npy + ipy]; } } } } else { - this->ft.fftxybac(ft.get_auxr_data(),ft.get_auxr_data()); + this->ft1.fftxybac(ft1.get_auxr_data(),ft1.get_auxr_data()); if(add) { #ifdef _OPENMP @@ -262,7 +262,7 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo #endif for(int ir = 0 ; ir < this->nrxx ; ++ir) { - out[ir] += factor * this->ft.get_auxr_data()[ir].real(); + out[ir] += factor * this->ft1.get_auxr_data()[ir].real(); } } else @@ -272,7 +272,7 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo #endif for(int ir = 0 ; ir < this->nrxx ; ++ir) { - out[ir] = this->ft.get_auxr_data()[ir].real(); + out[ir] = this->ft1.get_auxr_data()[ir].real(); } } } diff --git a/source/module_basis/module_pw/pw_transform_k.cpp b/source/module_basis/module_pw/pw_transform_k.cpp index 0ea362825b..978819501d 100644 --- a/source/module_basis/module_pw/pw_transform_k.cpp +++ b/source/module_basis/module_pw/pw_transform_k.cpp @@ -32,7 +32,7 @@ void PW_Basis_K::real2recip(const std::complex* in, ModuleBase::timer::tick(this->classname, "real2recip"); assert(this->gamma_only == false); - auto* auxr = this->ft.get_auxr_data(); + auto* auxr = this->ft1.get_auxr_data(); #ifdef _OPENMP #pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif @@ -40,15 +40,15 @@ void PW_Basis_K::real2recip(const std::complex* in, { auxr[ir] = in[ir]; } - this->ft.fftxyfor(ft.get_auxr_data(), ft.get_auxr_data()); + this->ft1.fftxyfor(ft1.get_auxr_data(), ft1.get_auxr_data()); - this->gatherp_scatters(this->ft.get_auxr_data(), this->ft.get_auxg_data()); + this->gatherp_scatters(this->ft1.get_auxr_data(), this->ft1.get_auxg_data()); - this->ft.fftzfor(ft.get_auxg_data(), ft.get_auxg_data()); + this->ft1.fftzfor(ft1.get_auxg_data(), ft1.get_auxg_data()); const int startig = ik * this->npwk_max; const int npwk = this->npwk[ik]; - auto* auxg = this->ft.get_auxg_data(); + auto* auxg = this->ft1.get_auxg_data(); if (add) { FPTYPE tmpfac = factor / FPTYPE(this->nxyz); @@ -98,7 +98,7 @@ void PW_Basis_K::real2recip(const FPTYPE* in, assert(this->gamma_only == true); // for(int ir = 0 ; ir < this->nrxx ; ++ir) // { - // this->ft.get_rspace_data()[ir] = in[ir]; + // this->ft1.get_rspace_data()[ir] = in[ir]; // } // r2c in place const int npy = this->ny * this->nplane; @@ -109,19 +109,19 @@ void PW_Basis_K::real2recip(const FPTYPE* in, { for (int ipy = 0; ipy < npy; ++ipy) { - this->ft.get_rspace_data()[ix * npy + ipy] = in[ix * npy + ipy]; + this->ft1.get_rspace_data()[ix * npy + ipy] = in[ix * npy + ipy]; } } - this->ft.fftxyr2c(ft.get_rspace_data(), ft.get_auxr_data()); + this->ft1.fftxyr2c(ft1.get_rspace_data(), ft1.get_auxr_data()); - this->gatherp_scatters(this->ft.get_auxr_data(), this->ft.get_auxg_data()); + this->gatherp_scatters(this->ft1.get_auxr_data(), this->ft1.get_auxg_data()); - this->ft.fftzfor(ft.get_auxg_data(), ft.get_auxg_data()); + this->ft1.fftzfor(ft1.get_auxg_data(), ft1.get_auxg_data()); const int startig = ik * this->npwk_max; const int npwk = this->npwk[ik]; - auto* auxg = this->ft.get_auxg_data(); + auto* auxg = this->ft1.get_auxg_data(); if (add) { FPTYPE tmpfac = factor / FPTYPE(this->nxyz); @@ -170,11 +170,11 @@ void PW_Basis_K::recip2real(const std::complex* in, { ModuleBase::timer::tick(this->classname, "recip2real"); assert(this->gamma_only == false); - ModuleBase::GlobalFunc::ZEROS(ft.get_auxg_data(), this->nst * this->nz); + ModuleBase::GlobalFunc::ZEROS(ft1.get_auxg_data(), this->nst * this->nz); const int startig = ik * this->npwk_max; const int npwk = this->npwk[ik]; - auto* auxg = this->ft.get_auxg_data(); + auto* auxg = this->ft1.get_auxg_data(); #ifdef _OPENMP #pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif @@ -182,13 +182,13 @@ void PW_Basis_K::recip2real(const std::complex* in, { auxg[this->igl2isz_k[igl + startig]] = in[igl]; } - this->ft.fftzbac(ft.get_auxg_data(), ft.get_auxg_data()); + this->ft1.fftzbac(ft1.get_auxg_data(), ft1.get_auxg_data()); - this->gathers_scatterp(this->ft.get_auxg_data(), this->ft.get_auxr_data()); + this->gathers_scatterp(this->ft1.get_auxg_data(), this->ft1.get_auxr_data()); - this->ft.fftxybac(ft.get_auxr_data(), ft.get_auxr_data()); + this->ft1.fftxybac(ft1.get_auxr_data(), ft1.get_auxr_data()); - auto* auxr = this->ft.get_auxr_data(); + auto* auxr = this->ft1.get_auxr_data(); if (add) { #ifdef _OPENMP @@ -234,11 +234,11 @@ void PW_Basis_K::recip2real(const std::complex* in, { ModuleBase::timer::tick(this->classname, "recip2real"); assert(this->gamma_only == true); - ModuleBase::GlobalFunc::ZEROS(ft.get_auxg_data(), this->nst * this->nz); + ModuleBase::GlobalFunc::ZEROS(ft1.get_auxg_data(), this->nst * this->nz); const int startig = ik * this->npwk_max; const int npwk = this->npwk[ik]; - auto* auxg = this->ft.get_auxg_data(); + auto* auxg = this->ft1.get_auxg_data(); #ifdef _OPENMP #pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif @@ -246,20 +246,20 @@ void PW_Basis_K::recip2real(const std::complex* in, { auxg[this->igl2isz_k[igl + startig]] = in[igl]; } - this->ft.fftzbac(ft.get_auxg_data(), ft.get_auxg_data()); + this->ft1.fftzbac(ft1.get_auxg_data(), ft1.get_auxg_data()); - this->gathers_scatterp(this->ft.get_auxg_data(), this->ft.get_auxr_data()); + this->gathers_scatterp(this->ft1.get_auxg_data(), this->ft1.get_auxr_data()); - this->ft.fftxyc2r(ft.get_auxr_data(), ft.get_rspace_data()); + this->ft1.fftxyc2r(ft1.get_auxr_data(), ft1.get_rspace_data()); // for(int ir = 0 ; ir < this->nrxx ; ++ir) // { - // out[ir] = this->ft.get_rspace_data()[ir] / this->nxyz; + // out[ir] = this->ft1.get_rspace_data()[ir] / this->nxyz; // } // r2c in place const int npy = this->ny * this->nplane; - auto* rspace = this->ft.get_rspace_data(); + auto* rspace = this->ft1.get_rspace_data(); if (add) { #ifdef _OPENMP From 965d6275722f4a31e02599d6441037af50d96803 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Tue, 5 Nov 2024 23:25:19 +0800 Subject: [PATCH 13/27] add the file of the float_define and the device set --- source/module_base/module_fft/fft_temp.cpp | 20 +++++++++----------- source/module_base/module_fft/fft_temp.h | 1 + source/module_basis/module_pw/pw_basis.cpp | 2 +- source/module_basis/module_pw/pw_basis_k.cpp | 1 + 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/source/module_base/module_fft/fft_temp.cpp b/source/module_base/module_fft/fft_temp.cpp index 861c619233..6be107eec2 100644 --- a/source/module_base/module_fft/fft_temp.cpp +++ b/source/module_base/module_fft/fft_temp.cpp @@ -65,26 +65,24 @@ void FFT_TEMP::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in fft_float = new FFT_CPU(); fft_double = new FFT_CPU(); } - else if (device=="gpu") - { - // #if defined(__ROCM) - // fft_float = new FFT_RCOM(); - // fft_double = new FFT_RCOM(); - // #elif defined(__CUDA) - // fft_float = new FFT_CUDA(); - // fft_double = new FFT_CUDA(); - // #endif - } + if (this->precision=="single") { float_flag = true; + #ifdef __ENABLE_FLOAT_FFTW + float_define = true; + #endif + float_flag = float_define & float_flag; double_flag = true; + + } else if (this->precision=="double") { double_flag = true; } - if (float_flag) + + if (float_flag && float_define) { fft_float->initfftmode(this->fft_mode); fft_float->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); diff --git a/source/module_base/module_fft/fft_temp.h b/source/module_base/module_fft/fft_temp.h index 9d08319efb..058fd94e93 100644 --- a/source/module_base/module_fft/fft_temp.h +++ b/source/module_base/module_fft/fft_temp.h @@ -55,6 +55,7 @@ class FFT_TEMP private: int fft_mode = 0; ///< fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive bool float_flag=false; + bool float_define=false; bool double_flag=false; FFT_BASE* fft_float=nullptr; FFT_BASE* fft_double=nullptr; diff --git a/source/module_basis/module_pw/pw_basis.cpp b/source/module_basis/module_pw/pw_basis.cpp index 17dfa6b90e..4bca6baed3 100644 --- a/source/module_basis/module_pw/pw_basis.cpp +++ b/source/module_basis/module_pw/pw_basis.cpp @@ -17,7 +17,7 @@ PW_Basis::PW_Basis(std::string device_, std::string precision_) : device(std::mo classname="PW_Basis"; this->ft.set_device(this->device); this->ft.set_precision(this->precision); - this->ft1.setfft(this->device,this->precision); + this->ft1.setfft("cpu",this->precision); } PW_Basis:: ~PW_Basis() diff --git a/source/module_basis/module_pw/pw_basis_k.cpp b/source/module_basis/module_pw/pw_basis_k.cpp index 9b3257ce71..7f6413cf1f 100644 --- a/source/module_basis/module_pw/pw_basis_k.cpp +++ b/source/module_basis/module_pw/pw_basis_k.cpp @@ -12,6 +12,7 @@ namespace ModulePW PW_Basis_K::PW_Basis_K() { classname="PW_Basis_K"; + this->ft1.setfft("cpu",this->precision); } PW_Basis_K::~PW_Basis_K() { From 39e0d0c9d2c45e1de597459a33f98cf72dc2fde2 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Wed, 6 Nov 2024 08:53:58 +0800 Subject: [PATCH 14/27] delete the memory allocate in the ft --- source/module_basis/module_pw/fft.cpp | 26 +++++++++---------- .../module_basis/module_pw/test/test1-4.cpp | 12 ++++----- .../module_elecstate/module_charge/charge.cpp | 4 +-- .../module_charge/charge_init.cpp | 4 +-- source/module_esolver/esolver_fp.cpp | 2 +- 5 files changed, 24 insertions(+), 24 deletions(-) diff --git a/source/module_basis/module_pw/fft.cpp b/source/module_basis/module_pw/fft.cpp index 1c56f9b5af..7b2ad09979 100644 --- a/source/module_basis/module_pw/fft.cpp +++ b/source/module_basis/module_pw/fft.cpp @@ -92,10 +92,10 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int int maxgrids = (nsz > nrxx) ? nsz : nrxx; if (!this->mpifft) { - z_auxg = (std::complex*)fftw_malloc(sizeof(fftw_complex) * maxgrids); - z_auxr = (std::complex*)fftw_malloc(sizeof(fftw_complex) * maxgrids); - ModuleBase::Memory::record("FFT::grid", 2 * sizeof(fftw_complex) * maxgrids); - d_rspace = (double*)z_auxg; + // z_auxg = (std::complex*)fftw_malloc(sizeof(fftw_complex) * maxgrids); + // z_auxr = (std::complex*)fftw_malloc(sizeof(fftw_complex) * maxgrids); + // ModuleBase::Memory::record("FFT::grid", 2 * sizeof(fftw_complex) * maxgrids); + // d_rspace = (double*)z_auxg; // auxr_3d = static_cast *>( // fftw_malloc(sizeof(fftw_complex) * (this->nx * this->ny * this->nz))); #if defined(__CUDA) || defined(__ROCM) @@ -105,15 +105,15 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz); } #endif // defined(__CUDA) || defined(__ROCM) -#if defined(__ENABLE_FLOAT_FFTW) - if (this->precision == "single") - { - c_auxg = (std::complex*)fftw_malloc(sizeof(fftwf_complex) * maxgrids); - c_auxr = (std::complex*)fftw_malloc(sizeof(fftwf_complex) * maxgrids); - ModuleBase::Memory::record("FFT::grid_s", 2 * sizeof(fftwf_complex) * maxgrids); - s_rspace = (float*)c_auxg; - } -#endif // defined(__ENABLE_FLOAT_FFTW) +// #if defined(__ENABLE_FLOAT_FFTW) +// if (this->precision == "single") +// { +// c_auxg = (std::complex*)fftw_malloc(sizeof(fftwf_complex) * maxgrids); +// c_auxr = (std::complex*)fftw_malloc(sizeof(fftwf_complex) * maxgrids); +// ModuleBase::Memory::record("FFT::grid_s", 2 * sizeof(fftwf_complex) * maxgrids); +// s_rspace = (float*)c_auxg; +// } +// #endif // defined(__ENABLE_FLOAT_FFTW) } else { diff --git a/source/module_basis/module_pw/test/test1-4.cpp b/source/module_basis/module_pw/test/test1-4.cpp index 9503326ad3..4ea6ec5c9e 100644 --- a/source/module_basis/module_pw/test/test1-4.cpp +++ b/source/module_basis/module_pw/test/test1-4.cpp @@ -156,8 +156,8 @@ TEST_F(PWTEST,test1_4) { EXPECT_NEAR(tmp[ixy * nz + startiz + iz].real(),rhor[ixy*nplane+iz].real(),1e-6); EXPECT_NEAR(tmp[ixy * nz + startiz + iz].imag(),rhor[ixy*nplane+iz].imag(),1e-6); - // EXPECT_NEAR(tmp[ixy * nz + startiz + iz].real(),rhogr[ixy*nplane+iz].real(),1e-6); - // EXPECT_NEAR(tmp[ixy * nz + startiz + iz].imag(),rhogr[ixy*nplane+iz].imag(),1e-6); + EXPECT_NEAR(tmp[ixy * nz + startiz + iz].real(),rhogr[ixy*nplane+iz].real(),1e-6); + EXPECT_NEAR(tmp[ixy * nz + startiz + iz].imag(),rhogr[ixy*nplane+iz].imag(),1e-6); #ifdef __ENABLE_FLOAT_FFTW EXPECT_NEAR(tmp[ixy * nz + startiz + iz].real(),rhofr[ixy*nplane+iz].real(),1e-4); EXPECT_NEAR(tmp[ixy * nz + startiz + iz].imag(),rhofr[ixy*nplane+iz].imag(),1e-4); @@ -178,10 +178,10 @@ TEST_F(PWTEST,test1_4) for(int ig = 0 ; ig < npwk ; ++ig) { - // EXPECT_NEAR(rhog[ig].real(),rhogout[ig].real(),1e-6); - // EXPECT_NEAR(rhog[ig].imag(),rhogout[ig].imag(),1e-6); - // EXPECT_NEAR(rhog[ig].real(),rhogr[ig].real(),1e-6); - // EXPECT_NEAR(rhog[ig].imag(),rhogr[ig].imag(),1e-6); + EXPECT_NEAR(rhog[ig].real(),rhogout[ig].real(),1e-6); + EXPECT_NEAR(rhog[ig].imag(),rhogout[ig].imag(),1e-6); + EXPECT_NEAR(rhog[ig].real(),rhogr[ig].real(),1e-6); + EXPECT_NEAR(rhog[ig].imag(),rhogr[ig].imag(),1e-6); #ifdef __ENABLE_FLOAT_FFTW EXPECT_NEAR(rhofg[ig].real(),rhofgout[ig].real(),1e-4); EXPECT_NEAR(rhofg[ig].imag(),rhofgout[ig].imag(),1e-4); diff --git a/source/module_elecstate/module_charge/charge.cpp b/source/module_elecstate/module_charge/charge.cpp index 9003844ca1..dec33f0418 100644 --- a/source/module_elecstate/module_charge/charge.cpp +++ b/source/module_elecstate/module_charge/charge.cpp @@ -644,10 +644,10 @@ void Charge::atomic_rho(const int spin_number_need, double sumrea = 0.0; for (int ir = 0; ir < this->rhopw->nrxx; ir++) { - rea = this->rhopw->ft.get_auxr_data()[ir].real(); + rea = this->rhopw->ft1.get_auxr_data()[ir].real(); sumrea += rea; neg += std::min(0.0, rea); - ima += std::abs(this->rhopw->ft.get_auxr_data()[ir].imag()); + ima += std::abs(this->rhopw->ft1.get_auxr_data()[ir].imag()); } #ifdef __MPI diff --git a/source/module_elecstate/module_charge/charge_init.cpp b/source/module_elecstate/module_charge/charge_init.cpp index c0806da937..9efca214f9 100644 --- a/source/module_elecstate/module_charge/charge_init.cpp +++ b/source/module_elecstate/module_charge/charge_init.cpp @@ -260,8 +260,8 @@ void Charge::set_rho_core( double rhoneg = 0.0; for (int ir = 0; ir < this->rhopw->nrxx; ir++) { - rhoneg += std::min(0.0, this->rhopw->ft.get_auxr_data()[ir].real()); - rhoima += std::abs(this->rhopw->ft.get_auxr_data()[ir].imag()); + rhoneg += std::min(0.0, this->rhopw->ft1.get_auxr_data()[ir].real()); + rhoima += std::abs(this->rhopw->ft1.get_auxr_data()[ir].imag()); // NOTE: Core charge is computed in reciprocal space and brought to real // space by FFT. For non smooth core charges (or insufficient cut-off) // this may result in negative values in some grid points. diff --git a/source/module_esolver/esolver_fp.cpp b/source/module_esolver/esolver_fp.cpp index 4599b9f51b..ead9fdb71b 100644 --- a/source/module_esolver/esolver_fp.cpp +++ b/source/module_esolver/esolver_fp.cpp @@ -82,7 +82,7 @@ void ESolver_FP::before_all_runners(const Input_para& inp, UnitCell& cell) this->pw_rho->initparameters(false, 4.0 * inp.ecutwfc); this->pw_rho->ft.fft_mode = inp.fft_mode; - this->pw_rho->ft.fft_mode = inp.fft_mode; + this->pw_rho->ft1.initfftmode(inp.fft_mode); this->pw_rho->setuptransform(); this->pw_rho->collect_local_pw(); this->pw_rho->collect_uniqgg(); From b183967a0e03c4dee18691b302913fb2eda06c11 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Wed, 6 Nov 2024 01:32:47 +0000 Subject: [PATCH 15/27] [pre-commit.ci lite] apply automatic fixes --- source/module_basis/module_pw/fft.cpp | 29 ++++++++++++++------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/source/module_basis/module_pw/fft.cpp b/source/module_basis/module_pw/fft.cpp index 7b2ad09979..fa94bd6442 100644 --- a/source/module_basis/module_pw/fft.cpp +++ b/source/module_basis/module_pw/fft.cpp @@ -72,10 +72,11 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int this->fftny = this->ny = ny_in; if (this->gamma_only) { - if (xprime) + if (xprime) { this->fftnx = int(nx / 2) + 1; - else + } else { this->fftny = int(ny / 2) + 1; +} } this->nz = nz_in; this->ns = ns_in; @@ -353,62 +354,62 @@ void FFT::cleanFFT() if (planzfor) { fftw_destroy_plan(planzfor); - planzfor = NULL; + planzfor = nullptr; } if (planzbac) { fftw_destroy_plan(planzbac); - planzbac = NULL; + planzbac = nullptr; } if (planxfor1) { fftw_destroy_plan(planxfor1); - planxfor1 = NULL; + planxfor1 = nullptr; } if (planxbac1) { fftw_destroy_plan(planxbac1); - planxbac1 = NULL; + planxbac1 = nullptr; } if (planxfor2) { fftw_destroy_plan(planxfor2); - planxfor2 = NULL; + planxfor2 = nullptr; } if (planxbac2) { fftw_destroy_plan(planxbac2); - planxbac2 = NULL; + planxbac2 = nullptr; } if (planyfor) { fftw_destroy_plan(planyfor); - planyfor = NULL; + planyfor = nullptr; } if (planybac) { fftw_destroy_plan(planybac); - planybac = NULL; + planybac = nullptr; } if (planxr2c) { fftw_destroy_plan(planxr2c); - planxr2c = NULL; + planxr2c = nullptr; } if (planxc2r) { fftw_destroy_plan(planxc2r); - planxc2r = NULL; + planxc2r = nullptr; } if (planyr2c) { fftw_destroy_plan(planyr2c); - planyr2c = NULL; + planyr2c = nullptr; } if (planyc2r) { fftw_destroy_plan(planyc2r); - planyc2r = NULL; + planyc2r = nullptr; } // fftw_destroy_plan(this->plan3dforward); From 2a45b32564220792fb72ff9bc12bb9ae94b7c5af Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Thu, 7 Nov 2024 13:36:10 +0800 Subject: [PATCH 16/27] add the Smart Pointer and the logic gate --- source/module_base/module_fft/fft_temp.cpp | 59 +++++++++++++-------- source/module_base/module_fft/fft_temp.h | 6 ++- source/module_basis/module_pw/test/Makefile | 11 ++-- 3 files changed, 49 insertions(+), 27 deletions(-) diff --git a/source/module_base/module_fft/fft_temp.cpp b/source/module_base/module_fft/fft_temp.cpp index 6be107eec2..8b83a3f8b0 100644 --- a/source/module_base/module_fft/fft_temp.cpp +++ b/source/module_base/module_fft/fft_temp.cpp @@ -9,6 +9,11 @@ // #include "fft_rcom.h" // #endif +template +std::unique_ptr make_unique(Args &&... args) +{ + return std::unique_ptr(new FFT_BASE(std::forward(args)...)); +} // #include "fft_gpu.h" FFT_TEMP::FFT_TEMP() { @@ -21,8 +26,8 @@ FFT_TEMP::FFT_TEMP(std::string device_in,std::string precision_in) this->precision = precision_in; if (device=="cpu") { - fft_float = new FFT_CPU(); - fft_double = new FFT_CPU(); + fft_float = make_unique>(); + fft_double = make_unique>(); } // else if (device=="gpu") // { @@ -62,17 +67,29 @@ void FFT_TEMP::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in { if (device=="cpu") { - fft_float = new FFT_CPU(); - fft_double = new FFT_CPU(); + fft_float = make_unique>(); + fft_double = make_unique>(); + // fft_double = new FFT_CPU(); + #ifndef __ENABLE_FLOAT_FFTW + float_define = false; + #endif + } + if (device=="gpu") + { + // #if defined(__ROCM) + // fft_float = new FFT_RCOM(); + // fft_double = new FFT_RCOM(); + // #elif defined(__CUDA) + // fft_float = make_unique>(); + // fft_double = make_unique>(); + // #endif } - if (this->precision=="single") { - float_flag = true; - #ifdef __ENABLE_FLOAT_FFTW - float_define = true; + #ifndef __ENABLE_FLOAT_FFTW + float_define = false; #endif - float_flag = float_define & float_flag; + float_flag = float_flag; double_flag = true; @@ -132,18 +149,18 @@ void FFT_TEMP::clear() { fft_double->clear(); } - if (fft_float!=nullptr) - { - delete fft_float; - fft_float=nullptr; - float_flag = false; - } - if (fft_double!=nullptr) - { - delete fft_double; - fft_double=nullptr; - double_flag = false; - } + // if (fft_float!=nullptr) + // { + // delete fft_float; + // fft_float=nullptr; + // float_flag = false; + // } + // if (fft_double!=nullptr) + // { + // delete fft_double; + // fft_double=nullptr; + // double_flag = false; + // } } // access the real space data template <> diff --git a/source/module_base/module_fft/fft_temp.h b/source/module_base/module_fft/fft_temp.h index 058fd94e93..cc81095b65 100644 --- a/source/module_base/module_fft/fft_temp.h +++ b/source/module_base/module_fft/fft_temp.h @@ -1,4 +1,5 @@ #include "fft_base.h" +#include // #include "module_psi/psi.h" #ifndef FFT_TEMP_H #define FFT_TEMP_H @@ -57,8 +58,9 @@ class FFT_TEMP bool float_flag=false; bool float_define=false; bool double_flag=false; - FFT_BASE* fft_float=nullptr; - FFT_BASE* fft_double=nullptr; + // FFT_BASE* fft_float=nullptr; // Remove the qualified name and use a raw pointer instead + std::shared_ptr> fft_float=nullptr; + std::shared_ptr> fft_double=nullptr; std::string device = "cpu"; std::string precision = "double"; diff --git a/source/module_basis/module_pw/test/Makefile b/source/module_basis/module_pw/test/Makefile index d970a637db..b3b4fea709 100644 --- a/source/module_basis/module_pw/test/Makefile +++ b/source/module_basis/module_pw/test/Makefile @@ -2,7 +2,7 @@ # Please set # e.g. make CXX=mpiicpc or make CXX=icpc #====================================================================== -CXX = mpiicpc +CXX = mpiicpx # mpiicpc: compile intel parallel version # icpc: compile intel sequential version # mpicxx: compile gnu parallel version @@ -25,7 +25,7 @@ GTEST_DIR = /home/qianrui/gnucompile/g_gtest # Compiler information #========================== HONG = -D__NORMAL -INCLUDES = -I. -I../../../ +INCLUDES = -I. -I../../../ -I../../../module_base/module_container LIBS = OPTS = -Ofast -march=native -std=c++11 -m64 ${INCLUDES} OBJ_DIR = obj @@ -94,7 +94,7 @@ endif ##========================== ## GTEST ##========================== -GTESTOPTS = -I${GTEST_DIR}/include -L${GTEST_DIR}/lib -lgtest -lpthread +GTESTOPTS = -I/usr/local/gtest/include -L/home/ubuntu/desktop/github/googletest/lib -lgtest -lpthread @@ -103,6 +103,8 @@ GTESTOPTS = -I${GTEST_DIR}/include -L${GTEST_DIR}/lib -lgtest -lpthread #========================== VPATH=../../../module_base\ ../../../module_base/module_device\ +../../../module_base/module_container/ATen/core\ +../../../module_base/module_container/ATen\ :../ MATH_OBJS0=matrix.o\ @@ -123,7 +125,8 @@ pw_basis_sup.o\ pw_transform_k.o\ memory.o\ memory_op.o\ -depend_mock.o +depend_mock.o\ +tensor.o\ OTHER_OBJS0= From 28419a66d511bd8e5a53b0a3f9af4ca5cb476784 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Thu, 7 Nov 2024 15:33:00 +0800 Subject: [PATCH 17/27] modify the position of the FFT --- examples/scf/pw_Si2/INPUT | 2 +- source/module_base/CMakeLists.txt | 9 --------- source/module_basis/module_pw/CMakeLists.txt | 8 ++++++++ .../module_pw}/module_fft/fft_base.cpp | 5 ++++- .../module_pw}/module_fft/fft_base.h | 3 +++ .../module_pw}/module_fft/fft_cpu.cpp | 4 +++- .../module_pw}/module_fft/fft_cpu.h | 4 +++- .../module_pw}/module_fft/fft_cpu_float.cpp | 4 +++- .../module_pw}/module_fft/fft_temp.cpp | 3 +++ .../module_pw}/module_fft/fft_temp.h | 5 ++++- source/module_basis/module_pw/pw_basis.h | 2 +- source/module_basis/module_pw/pw_transform.cpp | 2 +- source/module_basis/module_pw/test/CMakeLists.txt | 1 - .../module_pw/test_serial/CMakeLists.txt | 3 +++ source/module_elecstate/test/charge_extra_test.cpp | 6 ++++++ source/module_elecstate/test/elecstate_base_test.cpp | 6 ++++++ .../module_xc/test/CMakeLists.txt | 12 ++++++------ source/module_hsolver/test/hsolver_pw_sup.h | 3 ++- 18 files changed, 57 insertions(+), 25 deletions(-) rename source/{module_base => module_basis/module_pw}/module_fft/fft_base.cpp (95%) rename source/{module_base => module_basis/module_pw}/module_fft/fft_base.h (99%) rename source/{module_base => module_basis/module_pw}/module_fft/fft_cpu.cpp (99%) rename source/{module_base => module_basis/module_pw}/module_fft/fft_cpu.h (99%) rename source/{module_base => module_basis/module_pw}/module_fft/fft_cpu_float.cpp (99%) rename source/{module_base => module_basis/module_pw}/module_fft/fft_temp.cpp (99%) rename source/{module_base => module_basis/module_pw}/module_fft/fft_temp.h (97%) diff --git a/examples/scf/pw_Si2/INPUT b/examples/scf/pw_Si2/INPUT index 141c104100..cc4c54fc4c 100644 --- a/examples/scf/pw_Si2/INPUT +++ b/examples/scf/pw_Si2/INPUT @@ -4,7 +4,7 @@ pseudo_dir ../../../tests/PP_ORB symmetry 1 #Parameters (Accuracy) basis_type pw -ecutwfc 60 +ecutwfc 20 scf_thr 1e-7 scf_nmax 100 device cpu diff --git a/source/module_base/CMakeLists.txt b/source/module_base/CMakeLists.txt index e97f4a7ed6..057bd96b28 100644 --- a/source/module_base/CMakeLists.txt +++ b/source/module_base/CMakeLists.txt @@ -6,11 +6,6 @@ list (APPEND LIBM_SRC libm/sincos.cpp ) endif() -if (ENABLE_FLOAT_FFTW) - list (APPEND FFT_SRC - module_fft/fft_cpu_float.cpp - ) -endif() add_library( base OBJECT @@ -62,11 +57,7 @@ add_library( module_mixing/plain_mixing.cpp module_mixing/pulay_mixing.cpp module_mixing/broyden_mixing.cpp - module_fft/fft_base.cpp - module_fft/fft_temp.cpp - module_fft/fft_cpu.cpp ${LIBM_SRC} - ${FFT_SRC} ) add_subdirectory(module_container) diff --git a/source/module_basis/module_pw/CMakeLists.txt b/source/module_basis/module_pw/CMakeLists.txt index 2b2d897206..cb82b6b2b5 100644 --- a/source/module_basis/module_pw/CMakeLists.txt +++ b/source/module_basis/module_pw/CMakeLists.txt @@ -1,3 +1,8 @@ +if (ENABLE_FLOAT_FFTW) + list (APPEND FFT_SRC + module_fft/fft_cpu_float.cpp + ) +endif() list(APPEND objects fft.cpp pw_basis.cpp @@ -10,6 +15,9 @@ list(APPEND objects pw_init.cpp pw_transform.cpp pw_transform_k.cpp + module_fft/fft_base.cpp + module_fft/fft_temp.cpp + module_fft/fft_cpu.cpp ) add_library( diff --git a/source/module_base/module_fft/fft_base.cpp b/source/module_basis/module_pw/module_fft/fft_base.cpp similarity index 95% rename from source/module_base/module_fft/fft_base.cpp rename to source/module_basis/module_pw/module_fft/fft_base.cpp index f17a6b0999..1db284a7be 100644 --- a/source/module_base/module_fft/fft_base.cpp +++ b/source/module_basis/module_pw/module_fft/fft_base.cpp @@ -1,4 +1,6 @@ #include "fft_base.h" +namespace ModulePW +{ template FFT_BASE::FFT_BASE() { @@ -41,4 +43,5 @@ void FFT_BASE::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int template FFT_BASE::FFT_BASE(); template FFT_BASE::FFT_BASE(); template FFT_BASE::~FFT_BASE(); -template FFT_BASE::~FFT_BASE(); \ No newline at end of file +template FFT_BASE::~FFT_BASE(); +} \ No newline at end of file diff --git a/source/module_base/module_fft/fft_base.h b/source/module_basis/module_pw/module_fft/fft_base.h similarity index 99% rename from source/module_base/module_fft/fft_base.h rename to source/module_basis/module_pw/module_fft/fft_base.h index 2850e6e719..35e6778a42 100644 --- a/source/module_base/module_fft/fft_base.h +++ b/source/module_basis/module_pw/module_fft/fft_base.h @@ -3,6 +3,8 @@ #include "fftw3.h" #ifndef FFT_BASE_H #define FFT_BASE_H +namespace ModulePW +{ template class FFT_BASE { @@ -77,4 +79,5 @@ class FFT_BASE void set_precision(std::string precision_); }; +} #endif // FFT_BASE_H diff --git a/source/module_base/module_fft/fft_cpu.cpp b/source/module_basis/module_pw/module_fft/fft_cpu.cpp similarity index 99% rename from source/module_base/module_fft/fft_cpu.cpp rename to source/module_basis/module_pw/module_fft/fft_cpu.cpp index f74f1a5667..6ac7c41628 100644 --- a/source/module_base/module_fft/fft_cpu.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cpu.cpp @@ -4,7 +4,8 @@ #include //#include "fftw3-mpi_mkl.h" #endif - +namespace ModulePW +{ template <> FFT_CPU::FFT_CPU() { @@ -303,3 +304,4 @@ template FFT_CPU::FFT_CPU(); template FFT_CPU::~FFT_CPU(); template FFT_CPU::FFT_CPU(); template FFT_CPU::~FFT_CPU(); +} \ No newline at end of file diff --git a/source/module_base/module_fft/fft_cpu.h b/source/module_basis/module_pw/module_fft/fft_cpu.h similarity index 99% rename from source/module_base/module_fft/fft_cpu.h rename to source/module_basis/module_pw/module_fft/fft_cpu.h index 04de5531db..1da1d53956 100644 --- a/source/module_base/module_fft/fft_cpu.h +++ b/source/module_basis/module_pw/module_fft/fft_cpu.h @@ -7,7 +7,8 @@ // #endif #ifndef FFT_CPU_H #define FFT_CPU_H - +namespace ModulePW +{ template class FFT_CPU : public FFT_BASE { @@ -80,4 +81,5 @@ class FFT_CPU : public FFT_BASE float* s_rspace = nullptr; // real number space for r, [nplane * nx *ny] double* d_rspace = nullptr; // real number space for r, [nplane * nx *ny] }; +} #endif // FFT_CPU_H \ No newline at end of file diff --git a/source/module_base/module_fft/fft_cpu_float.cpp b/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp similarity index 99% rename from source/module_base/module_fft/fft_cpu_float.cpp rename to source/module_basis/module_pw/module_fft/fft_cpu_float.cpp index 7922e265a7..8c59b8e42a 100644 --- a/source/module_base/module_fft/fft_cpu_float.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp @@ -4,7 +4,8 @@ // #if defined(__FFTW3_MPI) && defined(__MPI) // #include "fftw3f-mpi.h" // //#include "fftw3-mpi_mkl.h" - +namespace ModulePW +{ template <> void FFT_CPU::initfftmode(int fft_mode_in) { @@ -293,3 +294,4 @@ void FFT_CPU::fftxyc2r(std::complex* in, float* out) const fftwf_execute_dft_c2r(this->planfyc2r, (fftwf_complex*)in, out); } } +} \ No newline at end of file diff --git a/source/module_base/module_fft/fft_temp.cpp b/source/module_basis/module_pw/module_fft/fft_temp.cpp similarity index 99% rename from source/module_base/module_fft/fft_temp.cpp rename to source/module_basis/module_pw/module_fft/fft_temp.cpp index 8b83a3f8b0..f15b6ac4d6 100644 --- a/source/module_base/module_fft/fft_temp.cpp +++ b/source/module_basis/module_pw/module_fft/fft_temp.cpp @@ -14,6 +14,8 @@ std::unique_ptr make_unique(Args &&... args) { return std::unique_ptr(new FFT_BASE(std::forward(args)...)); } +namespace ModulePW +{ // #include "fft_gpu.h" FFT_TEMP::FFT_TEMP() { @@ -290,4 +292,5 @@ template <> void FFT_TEMP::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const { fft_double->fft3D_backward(in, out); +} } \ No newline at end of file diff --git a/source/module_base/module_fft/fft_temp.h b/source/module_basis/module_pw/module_fft/fft_temp.h similarity index 97% rename from source/module_base/module_fft/fft_temp.h rename to source/module_basis/module_pw/module_fft/fft_temp.h index cc81095b65..c1121298fd 100644 --- a/source/module_base/module_fft/fft_temp.h +++ b/source/module_basis/module_pw/module_fft/fft_temp.h @@ -3,6 +3,8 @@ // #include "module_psi/psi.h" #ifndef FFT_TEMP_H #define FFT_TEMP_H +namespace ModulePW +{ class FFT_TEMP { public: @@ -65,5 +67,6 @@ class FFT_TEMP std::string device = "cpu"; std::string precision = "double"; }; +} // namespace ModulePW +#endif // FFT_H -#endif // FFT_H \ No newline at end of file diff --git a/source/module_basis/module_pw/pw_basis.h b/source/module_basis/module_pw/pw_basis.h index f9966dc78a..df65c8ff1d 100644 --- a/source/module_basis/module_pw/pw_basis.h +++ b/source/module_basis/module_pw/pw_basis.h @@ -6,7 +6,7 @@ #include "module_base/vector3.h" #include #include "fft.h" -#include "module_base/module_fft/fft_temp.h" +#include "module_fft/fft_temp.h" #include #ifdef __MPI #include "mpi.h" diff --git a/source/module_basis/module_pw/pw_transform.cpp b/source/module_basis/module_pw/pw_transform.cpp index d044f3a6ef..12090b6db5 100644 --- a/source/module_basis/module_pw/pw_transform.cpp +++ b/source/module_basis/module_pw/pw_transform.cpp @@ -1,5 +1,5 @@ #include "fft.h" -#include "module_base/module_fft/fft_temp.h" +#include "module_fft/fft_temp.h" #include #include "pw_basis.h" #include diff --git a/source/module_basis/module_pw/test/CMakeLists.txt b/source/module_basis/module_pw/test/CMakeLists.txt index 8bf43bbc2d..8b62dacab0 100644 --- a/source/module_basis/module_pw/test/CMakeLists.txt +++ b/source/module_basis/module_pw/test/CMakeLists.txt @@ -5,7 +5,6 @@ AddTest( SOURCES ../../../module_base/matrix.cpp ../../../module_base/complexmatrix.cpp ../../../module_base/matrix3.cpp ../../../module_base/tool_quit.cpp ../../../module_base/mymath.cpp ../../../module_base/timer.cpp ../../../module_base/memory.cpp ../../../module_base/blas_connector.cpp ../../../module_base/libm/branred.cpp ../../../module_base/libm/sincos.cpp - ../../../module_base/module_fft/fft_base.cpp ../../../module_base/module_fft/fft_temp.cpp ../../../module_base/module_fft/fft_cpu.cpp # ../../../module_psi/kernels/psi_memory_op.cpp ../../../module_base/module_device/memory_op.cpp depend_mock.cpp pw_test.cpp test1-1-1.cpp test1-1-2.cpp test1-2.cpp test1-3.cpp test1-4.cpp test1-5.cpp diff --git a/source/module_basis/module_pw/test_serial/CMakeLists.txt b/source/module_basis/module_pw/test_serial/CMakeLists.txt index df9ae6a962..0b75a0b0fc 100644 --- a/source/module_basis/module_pw/test_serial/CMakeLists.txt +++ b/source/module_basis/module_pw/test_serial/CMakeLists.txt @@ -10,6 +10,9 @@ add_library( planewave_serial OBJECT ../fft.cpp + ../module_fft/fft_base.cpp + ../module_fft/fft_temp.cpp + ../module_fft/fft_cpu.cpp ../pw_basis.cpp ../pw_basis_k.cpp ../pw_basis_sup.cpp diff --git a/source/module_elecstate/test/charge_extra_test.cpp b/source/module_elecstate/test/charge_extra_test.cpp index fadacdb327..8c93133684 100644 --- a/source/module_elecstate/test/charge_extra_test.cpp +++ b/source/module_elecstate/test/charge_extra_test.cpp @@ -70,6 +70,12 @@ FFT::FFT() FFT::~FFT() { } +FFT_TEMP::FFT_TEMP() +{ +} +FFT_TEMP::~FFT_TEMP() +{ +} void PW_Basis::initgrids(const double lat0_in, const ModuleBase::Matrix3 latvec_in, const double gridecut) { } diff --git a/source/module_elecstate/test/elecstate_base_test.cpp b/source/module_elecstate/test/elecstate_base_test.cpp index 4a0950fb30..9f8c302af2 100644 --- a/source/module_elecstate/test/elecstate_base_test.cpp +++ b/source/module_elecstate/test/elecstate_base_test.cpp @@ -56,6 +56,12 @@ ModulePW::FFT::FFT() ModulePW::FFT::~FFT() { } +ModulePW::FFT_TEMP::FFT_TEMP() +{ +} +ModulePW::FFT_TEMP::~FFT_TEMP() +{ +} void ModulePW::PW_Basis::initgrids(double, ModuleBase::Matrix3, double) { } diff --git a/source/module_hamilt_general/module_xc/test/CMakeLists.txt b/source/module_hamilt_general/module_xc/test/CMakeLists.txt index 5a7ed800e3..0b2ea5ec98 100644 --- a/source/module_hamilt_general/module_xc/test/CMakeLists.txt +++ b/source/module_hamilt_general/module_xc/test/CMakeLists.txt @@ -38,9 +38,9 @@ AddTest( ../../../module_base/libm/branred.cpp ../../../module_base/libm/sincos.cpp ../../../module_base/blas_connector.cpp - ../../../module_base/module_fft/fft_base.cpp - ../../../module_base/module_fft/fft_temp.cpp - ../../../module_base/module_fft/fft_cpu.cpp + ../../../module_basis/module_pw/module_fft/fft_base.cpp + ../../../module_basis/module_pw/module_fft/fft_temp.cpp + ../../../module_basis/module_pw/module_fft/fft_cpu.cpp ) AddTest( @@ -76,7 +76,7 @@ AddTest( ../../../module_base/timer.cpp ../../../module_base/libm/branred.cpp ../../../module_base/libm/sincos.cpp - ../../../module_base/module_fft/fft_base.cpp - ../../../module_base/module_fft/fft_temp.cpp - ../../../module_base/module_fft/fft_cpu.cpp + ../../../module_basis/module_pw/module_fft/fft_base.cpp + ../../../module_basis/module_pw/module_fft/fft_temp.cpp + ../../../module_basis/module_pw/module_fft/fft_cpu.cpp ) \ No newline at end of file diff --git a/source/module_hsolver/test/hsolver_pw_sup.h b/source/module_hsolver/test/hsolver_pw_sup.h index 0fc0e72eaa..28c8fd1e43 100644 --- a/source/module_hsolver/test/hsolver_pw_sup.h +++ b/source/module_hsolver/test/hsolver_pw_sup.h @@ -4,7 +4,8 @@ namespace ModulePW { PW_Basis::PW_Basis(){}; PW_Basis::~PW_Basis(){}; - +FFT_TEMP::FFT_TEMP(){}; +FFT_TEMP::~FFT_TEMP(){}; void PW_Basis::initgrids( const double lat0_in, // unit length (unit in bohr) const ModuleBase::Matrix3 From d29b3555dc25cd974d2f8c084b580264f4504bb7 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Thu, 7 Nov 2024 16:08:22 +0800 Subject: [PATCH 18/27] change fft_bundle name --- source/Makefile.Objects | 8 +-- source/module_basis/module_pw/CMakeLists.txt | 2 +- .../{fft_temp.cpp => fft_bundle.cpp} | 72 +++++++++---------- .../module_fft/{fft_temp.h => fft_bundle.h} | 8 +-- source/module_basis/module_pw/pw_basis.h | 4 +- .../module_basis/module_pw/pw_transform.cpp | 2 +- .../module_pw/test_serial/CMakeLists.txt | 2 +- .../test/charge_extra_test.cpp | 4 +- .../test/elecstate_base_test.cpp | 4 +- source/module_hsolver/test/hsolver_pw_sup.h | 4 +- 10 files changed, 55 insertions(+), 55 deletions(-) rename source/module_basis/module_pw/module_fft/{fft_temp.cpp => fft_bundle.cpp} (66%) rename source/module_basis/module_pw/module_fft/{fft_temp.h => fft_bundle.h} (95%) diff --git a/source/Makefile.Objects b/source/Makefile.Objects index bbfd623cd9..c0ef4fa803 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -25,9 +25,9 @@ VPATH=./src_global:\ ./module_base/module_container/ATen/ops:\ ./module_base/module_device:\ ./module_base/module_mixing:\ -./module_base/module_fft:\ ./module_md:\ ./module_basis/module_pw:\ +./module_basis/module_pw/module_fft:\ ./module_esolver:\ ./module_hsolver:\ ./module_hsolver/kernels:\ @@ -167,9 +167,6 @@ OBJS_BASE=abfs-vector3_order.o\ broyden_mixing.o\ memory_op.o\ device.o\ - fft_temp.o\ - fft_base.o\ - fft_cpu.o\ OBJS_CELL=atom_pseudo.o\ atom_spec.o\ @@ -414,6 +411,9 @@ OBJS_PSI_INITIALIZER=psi_initializer.o\ psi_initializer_nao_random.o\ OBJS_PW=fft.o\ + fft_bundle.o\ + fft_base.o\ + fft_cpu.o\ pw_basis.o\ pw_basis_k.o\ pw_basis_sup.o\ diff --git a/source/module_basis/module_pw/CMakeLists.txt b/source/module_basis/module_pw/CMakeLists.txt index cb82b6b2b5..a5b50e59b7 100644 --- a/source/module_basis/module_pw/CMakeLists.txt +++ b/source/module_basis/module_pw/CMakeLists.txt @@ -16,7 +16,7 @@ list(APPEND objects pw_transform.cpp pw_transform_k.cpp module_fft/fft_base.cpp - module_fft/fft_temp.cpp + module_fft/fft_bundle.cpp module_fft/fft_cpu.cpp ) diff --git a/source/module_basis/module_pw/module_fft/fft_temp.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp similarity index 66% rename from source/module_basis/module_pw/module_fft/fft_temp.cpp rename to source/module_basis/module_pw/module_fft/fft_bundle.cpp index f15b6ac4d6..1878467e3f 100644 --- a/source/module_basis/module_pw/module_fft/fft_temp.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -1,5 +1,5 @@ #include -#include "fft_temp.h" +#include "fft_bundle.h" #include "fft_cpu.h" #include "module_base/module_device/device.h" // #if defined(__CUDA) @@ -17,10 +17,10 @@ std::unique_ptr make_unique(Args &&... args) namespace ModulePW { // #include "fft_gpu.h" -FFT_TEMP::FFT_TEMP() +FFT_Bundle::FFT_Bundle() { } -FFT_TEMP::FFT_TEMP(std::string device_in,std::string precision_in) +FFT_Bundle::FFT_Bundle(std::string device_in,std::string precision_in) { assert(device_in=="cpu" || device_in=="gpu"); assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); @@ -43,20 +43,20 @@ FFT_TEMP::FFT_TEMP(std::string device_in,std::string precision_in) // } } -FFT_TEMP::~FFT_TEMP() +FFT_Bundle::~FFT_Bundle() { } -void FFT_TEMP::set_device(std::string device_in) +void FFT_Bundle::set_device(std::string device_in) { this->device = device_in; } -void FFT_TEMP::set_precision(std::string precision_in) +void FFT_Bundle::set_precision(std::string precision_in) { this->precision = precision_in; } -void FFT_TEMP::setfft(std::string device_in,std::string precision_in) +void FFT_Bundle::setfft(std::string device_in,std::string precision_in) { assert(device_in=="cpu" || device_in=="gpu"); assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); @@ -64,7 +64,7 @@ void FFT_TEMP::setfft(std::string device_in,std::string precision_in) this->precision = precision_in; } -void FFT_TEMP::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, +void FFT_Bundle::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, int nproc_in, bool gamma_only_in, bool xprime_in , bool mpifft_in) { if (device=="cpu") @@ -112,12 +112,12 @@ void FFT_TEMP::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in fft_double->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); } } -void FFT_TEMP::initfftmode(int fft_mode_in) +void FFT_Bundle::initfftmode(int fft_mode_in) { this->fft_mode = fft_mode_in; } -void FFT_TEMP::setupFFT() +void FFT_Bundle::setupFFT() { if (double_flag) { @@ -129,7 +129,7 @@ void FFT_TEMP::setupFFT() } } -void FFT_TEMP::clearFFT() +void FFT_Bundle::clearFFT() { if (double_flag) { @@ -140,7 +140,7 @@ void FFT_TEMP::clearFFT() fft_float->cleanFFT(); } } -void FFT_TEMP::clear() +void FFT_Bundle::clear() { this->clearFFT(); if (float_flag) @@ -166,130 +166,130 @@ void FFT_TEMP::clear() } // access the real space data template <> -float* FFT_TEMP::get_rspace_data() const +float* FFT_Bundle::get_rspace_data() const { return fft_float->get_rspace_data(); } template <> -double* FFT_TEMP::get_rspace_data() const +double* FFT_Bundle::get_rspace_data() const { return fft_double->get_rspace_data(); } template <> -std::complex* FFT_TEMP::get_auxr_data() const +std::complex* FFT_Bundle::get_auxr_data() const { return fft_float->get_auxr_data(); } template <> -std::complex* FFT_TEMP::get_auxr_data() const +std::complex* FFT_Bundle::get_auxr_data() const { return fft_double->get_auxr_data(); } template <> -std::complex* FFT_TEMP::get_auxg_data() const +std::complex* FFT_Bundle::get_auxg_data() const { return fft_float->get_auxg_data(); } template <> -std::complex* FFT_TEMP::get_auxg_data() const +std::complex* FFT_Bundle::get_auxg_data() const { return fft_double->get_auxg_data(); } template <> -std::complex* FFT_TEMP::get_auxr_3d_data() const +std::complex* FFT_Bundle::get_auxr_3d_data() const { return fft_float->get_auxr_3d_data(); } template <> -std::complex* FFT_TEMP::get_auxr_3d_data() const +std::complex* FFT_Bundle::get_auxr_3d_data() const { return fft_double->get_auxr_3d_data(); } template <> -void FFT_TEMP::fftxyfor(std::complex* in, std::complex* out) const +void FFT_Bundle::fftxyfor(std::complex* in, std::complex* out) const { fft_float->fftxyfor(in,out); } template <> -void FFT_TEMP::fftxyfor(std::complex* in, std::complex* out) const +void FFT_Bundle::fftxyfor(std::complex* in, std::complex* out) const { fft_double->fftxyfor(in,out); } template <> -void FFT_TEMP::fftzfor(std::complex* in, std::complex* out) const +void FFT_Bundle::fftzfor(std::complex* in, std::complex* out) const { fft_float->fftzfor(in,out); } template <> -void FFT_TEMP::fftzfor(std::complex* in, std::complex* out) const +void FFT_Bundle::fftzfor(std::complex* in, std::complex* out) const { fft_double->fftzfor(in,out); } template <> -void FFT_TEMP::fftxybac(std::complex* in, std::complex* out) const +void FFT_Bundle::fftxybac(std::complex* in, std::complex* out) const { fft_float->fftxybac(in,out); } template <> -void FFT_TEMP::fftxybac(std::complex* in, std::complex* out) const +void FFT_Bundle::fftxybac(std::complex* in, std::complex* out) const { fft_double->fftxybac(in,out); } template <> -void FFT_TEMP::fftzbac(std::complex* in, std::complex* out) const +void FFT_Bundle::fftzbac(std::complex* in, std::complex* out) const { fft_float->fftzbac(in,out); } template <> -void FFT_TEMP::fftzbac(std::complex* in, std::complex* out) const +void FFT_Bundle::fftzbac(std::complex* in, std::complex* out) const { fft_double->fftzbac(in,out); } template <> -void FFT_TEMP::fftxyr2c(float* in, std::complex* out) const +void FFT_Bundle::fftxyr2c(float* in, std::complex* out) const { fft_float->fftxyr2c(in,out); } template <> -void FFT_TEMP::fftxyr2c(double* in, std::complex* out) const +void FFT_Bundle::fftxyr2c(double* in, std::complex* out) const { fft_double->fftxyr2c(in,out); } template <> -void FFT_TEMP::fftxyc2r(std::complex* in, float* out) const +void FFT_Bundle::fftxyc2r(std::complex* in, float* out) const { fft_float->fftxyc2r(in,out); } template <> -void FFT_TEMP::fftxyc2r(std::complex* in, double* out) const +void FFT_Bundle::fftxyc2r(std::complex* in, double* out) const { fft_double->fftxyc2r(in,out); } template <> -void FFT_TEMP::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const { fft_float->fft3D_forward(in, out); } template <> -void FFT_TEMP::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const { fft_double->fft3D_forward(in, out); } template <> -void FFT_TEMP::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const { fft_float->fft3D_backward(in, out); } template <> -void FFT_TEMP::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const { fft_double->fft3D_backward(in, out); } diff --git a/source/module_basis/module_pw/module_fft/fft_temp.h b/source/module_basis/module_pw/module_fft/fft_bundle.h similarity index 95% rename from source/module_basis/module_pw/module_fft/fft_temp.h rename to source/module_basis/module_pw/module_fft/fft_bundle.h index c1121298fd..6cee57ee66 100644 --- a/source/module_basis/module_pw/module_fft/fft_temp.h +++ b/source/module_basis/module_pw/module_fft/fft_bundle.h @@ -5,12 +5,12 @@ #define FFT_TEMP_H namespace ModulePW { -class FFT_TEMP +class FFT_Bundle { public: - FFT_TEMP(); - FFT_TEMP(std::string device_in,std::string precision_in); - ~FFT_TEMP(); + FFT_Bundle(); + FFT_Bundle(std::string device_in,std::string precision_in); + ~FFT_Bundle(); void setfft(std::string device_in,std::string precision_in); void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, diff --git a/source/module_basis/module_pw/pw_basis.h b/source/module_basis/module_pw/pw_basis.h index df65c8ff1d..b5f2b58827 100644 --- a/source/module_basis/module_pw/pw_basis.h +++ b/source/module_basis/module_pw/pw_basis.h @@ -6,7 +6,7 @@ #include "module_base/vector3.h" #include #include "fft.h" -#include "module_fft/fft_temp.h" +#include "module_fft/fft_bundle.h" #include #ifdef __MPI #include "mpi.h" @@ -243,7 +243,7 @@ class PW_Basis int nmaxgr=0; // Gamma_only: max between npw and (nrxx+1)/2, others: max between npw and nrxx // Thus complex[nmaxgr] is able to contain either reciprocal or real data FFT ft; - FFT_TEMP ft1; + FFT_Bundle ft1; //The position of pointer in and out can be equal(in-place transform) or different(out-of-place transform). template diff --git a/source/module_basis/module_pw/pw_transform.cpp b/source/module_basis/module_pw/pw_transform.cpp index 12090b6db5..c7590446a7 100644 --- a/source/module_basis/module_pw/pw_transform.cpp +++ b/source/module_basis/module_pw/pw_transform.cpp @@ -1,5 +1,5 @@ #include "fft.h" -#include "module_fft/fft_temp.h" +#include "module_fft/fft_bundle.h" #include #include "pw_basis.h" #include diff --git a/source/module_basis/module_pw/test_serial/CMakeLists.txt b/source/module_basis/module_pw/test_serial/CMakeLists.txt index 0b75a0b0fc..028d5b3a0e 100644 --- a/source/module_basis/module_pw/test_serial/CMakeLists.txt +++ b/source/module_basis/module_pw/test_serial/CMakeLists.txt @@ -11,7 +11,7 @@ add_library( OBJECT ../fft.cpp ../module_fft/fft_base.cpp - ../module_fft/fft_temp.cpp + ../module_fft/fft_bundle.cpp ../module_fft/fft_cpu.cpp ../pw_basis.cpp ../pw_basis_k.cpp diff --git a/source/module_elecstate/test/charge_extra_test.cpp b/source/module_elecstate/test/charge_extra_test.cpp index 8c93133684..f52b034e4c 100644 --- a/source/module_elecstate/test/charge_extra_test.cpp +++ b/source/module_elecstate/test/charge_extra_test.cpp @@ -70,10 +70,10 @@ FFT::FFT() FFT::~FFT() { } -FFT_TEMP::FFT_TEMP() +FFT_Bundle::FFT_Bundle() { } -FFT_TEMP::~FFT_TEMP() +FFT_Bundle::~FFT_Bundle() { } void PW_Basis::initgrids(const double lat0_in, const ModuleBase::Matrix3 latvec_in, const double gridecut) diff --git a/source/module_elecstate/test/elecstate_base_test.cpp b/source/module_elecstate/test/elecstate_base_test.cpp index 9f8c302af2..95bb11949d 100644 --- a/source/module_elecstate/test/elecstate_base_test.cpp +++ b/source/module_elecstate/test/elecstate_base_test.cpp @@ -56,10 +56,10 @@ ModulePW::FFT::FFT() ModulePW::FFT::~FFT() { } -ModulePW::FFT_TEMP::FFT_TEMP() +ModulePW::FFT_Bundle::FFT_Bundle() { } -ModulePW::FFT_TEMP::~FFT_TEMP() +ModulePW::FFT_Bundle::~FFT_Bundle() { } void ModulePW::PW_Basis::initgrids(double, ModuleBase::Matrix3, double) diff --git a/source/module_hsolver/test/hsolver_pw_sup.h b/source/module_hsolver/test/hsolver_pw_sup.h index 28c8fd1e43..300492e5aa 100644 --- a/source/module_hsolver/test/hsolver_pw_sup.h +++ b/source/module_hsolver/test/hsolver_pw_sup.h @@ -4,8 +4,8 @@ namespace ModulePW { PW_Basis::PW_Basis(){}; PW_Basis::~PW_Basis(){}; -FFT_TEMP::FFT_TEMP(){}; -FFT_TEMP::~FFT_TEMP(){}; +FFT_Bundle::FFT_Bundle(){}; +FFT_Bundle::~FFT_Bundle(){}; void PW_Basis::initgrids( const double lat0_in, // unit length (unit in bohr) const ModuleBase::Matrix3 From 6cc4bacb236bdf7ceeabcc8d3c8887c7fa16324d Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Thu, 7 Nov 2024 23:41:34 +0800 Subject: [PATCH 19/27] save version of the pw_test and single version --- examples/scf/pw_Si2/INPUT | 4 +- source/module_basis/module_pw/CMakeLists.txt | 1 + .../module_pw/module_fft/fft_base.cpp | 30 ---------- .../module_pw/module_fft/fft_base.h | 25 +------- .../module_pw/module_fft/fft_bundle.cpp | 38 +++++------- .../module_pw/module_fft/fft_bundle.h | 2 +- .../module_pw/module_fft/fft_cpu.cpp | 60 +++++++++++++------ .../module_pw/module_fft/fft_cpu.h | 23 ++++++- .../module_pw/module_fft/fft_cpu_float.cpp | 28 ++++++++- 9 files changed, 111 insertions(+), 100 deletions(-) diff --git a/examples/scf/pw_Si2/INPUT b/examples/scf/pw_Si2/INPUT index cc4c54fc4c..f3144bd95d 100644 --- a/examples/scf/pw_Si2/INPUT +++ b/examples/scf/pw_Si2/INPUT @@ -4,9 +4,9 @@ pseudo_dir ../../../tests/PP_ORB symmetry 1 #Parameters (Accuracy) basis_type pw -ecutwfc 20 +ecutwfc 60 scf_thr 1e-7 scf_nmax 100 device cpu ks_solver dav_subspace -precision double +precision single diff --git a/source/module_basis/module_pw/CMakeLists.txt b/source/module_basis/module_pw/CMakeLists.txt index a5b50e59b7..b4ece143ff 100644 --- a/source/module_basis/module_pw/CMakeLists.txt +++ b/source/module_basis/module_pw/CMakeLists.txt @@ -18,6 +18,7 @@ list(APPEND objects module_fft/fft_base.cpp module_fft/fft_bundle.cpp module_fft/fft_cpu.cpp + ${FFT_SRC} ) add_library( diff --git a/source/module_basis/module_pw/module_fft/fft_base.cpp b/source/module_basis/module_pw/module_fft/fft_base.cpp index 1db284a7be..8cda835f24 100644 --- a/source/module_basis/module_pw/module_fft/fft_base.cpp +++ b/source/module_basis/module_pw/module_fft/fft_base.cpp @@ -8,36 +8,6 @@ FFT_BASE::FFT_BASE() template FFT_BASE::~FFT_BASE() { - -} -template -void FFT_BASE::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, - int nproc_in, bool gamma_only_in, bool xprime_in, bool mpifft_in) -{ - this->gamma_only = gamma_only_in; - this->xprime = xprime_in; - this->fftnx = this->nx = nx_in; - this->fftny = this->ny = ny_in; - if (this->gamma_only) - { - if (xprime) { - this->fftnx = int(nx / 2) + 1; - } else { - this->fftny = int(ny / 2) + 1; -} - } - this->nz = nz_in; - this->ns = ns_in; - this->lixy = lixy_in; - this->rixy = rixy_in; - this->nplane = nplane_in; - this->nproc = nproc_in; - this->mpifft = mpifft_in; - this->nxy = this->nx * this->ny; - this->fftnxy = this->fftnx * this->fftny; - const int nrxx = this->nxy * this->nplane; - const int nsz = this->nz * this->ns; - this->maxgrids = (nsz > nrxx) ? nsz : nrxx; } template FFT_BASE::FFT_BASE(); diff --git a/source/module_basis/module_pw/module_fft/fft_base.h b/source/module_basis/module_pw/module_fft/fft_base.h index 35e6778a42..b53ce979f7 100644 --- a/source/module_basis/module_pw/module_fft/fft_base.h +++ b/source/module_basis/module_pw/module_fft/fft_base.h @@ -14,11 +14,10 @@ class FFT_BASE virtual ~FFT_BASE(); // init parameters of fft - virtual void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, + virtual __attribute__((weak)) + void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, int nproc_in, bool gamma_only_in, bool xprime_in = true, bool mpifft_in = false); - virtual __attribute__((weak)) void initfftmode(int fft_mode_in); - //init fftw_plans virtual void setupFFT()=0; @@ -54,29 +53,9 @@ class FFT_BASE virtual __attribute__((weak)) void fft3D_backward(std::complex* in, std::complex* out) const; protected: - int initflag = 0; // 0: not initialized; 1: initialized - int fftnx=0; - int fftny=0; - int fftnxy=0; int ny=0; int nx=0; int nz=0; - int nxy=0; - int nplane=0; //number of x-y planes - bool gamma_only = false; - int lixy=0; - int rixy=0;// lixy: the left edge of the pw ball in the y direction; rixy: the right edge of the pw ball in the x or y direction - bool mpifft = false; // if use mpi fft, only used when define __FFTW3_MPI - int maxgrids = 0; // maxgrids = (nsz > nrxx) ? nsz : nrxx; - bool xprime = true; // true: when do recip2real, x-fft will be done last and when doing real2recip, x-fft will be done first; false: y-fft - // For gamma_only, true: we use half x; false: we use half y - int ns=0; //number of sticks - int nproc=1; // number of proc. - int fft_mode = 0; ///< fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive - -public: - void set_device(std::string device_); - void set_precision(std::string precision_); }; } diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index 1878467e3f..21317394e8 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -67,14 +67,23 @@ void FFT_Bundle::setfft(std::string device_in,std::string precision_in) void FFT_Bundle::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, int nproc_in, bool gamma_only_in, bool xprime_in , bool mpifft_in) { - if (device=="cpu") + if (this->precision=="single") { - fft_float = make_unique>(); - fft_double = make_unique>(); - // fft_double = new FFT_CPU(); #ifndef __ENABLE_FLOAT_FFTW - float_define = false; + float_define = false; #endif + float_flag = float_define; + double_flag = true; + } + if (this->precision=="double") + { + double_flag = true; + } + + if (device=="cpu") + { + fft_float = make_unique>(this->fft_mode); + fft_double = make_unique>(this->fft_mode); } if (device=="gpu") { @@ -86,29 +95,12 @@ void FFT_Bundle::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_ // fft_double = make_unique>(); // #endif } - if (this->precision=="single") - { - #ifndef __ENABLE_FLOAT_FFTW - float_define = false; - #endif - float_flag = float_flag; - double_flag = true; - - - } - else if (this->precision=="double") - { - double_flag = true; - } - - if (float_flag && float_define) + if (float_flag) { - fft_float->initfftmode(this->fft_mode); fft_float->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); } if (double_flag) { - fft_double->initfftmode(this->fft_mode); fft_double->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); } } diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.h b/source/module_basis/module_pw/module_fft/fft_bundle.h index 6cee57ee66..be5de5f92f 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.h +++ b/source/module_basis/module_pw/module_fft/fft_bundle.h @@ -58,7 +58,7 @@ class FFT_Bundle private: int fft_mode = 0; ///< fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive bool float_flag=false; - bool float_define=false; + bool float_define=true; bool double_flag=false; // FFT_BASE* fft_float=nullptr; // Remove the qualified name and use a raw pointer instead std::shared_ptr> fft_float=nullptr; diff --git a/source/module_basis/module_pw/module_fft/fft_cpu.cpp b/source/module_basis/module_pw/module_fft/fft_cpu.cpp index 6ac7c41628..f23079931a 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cpu.cpp @@ -23,6 +23,46 @@ FFT_CPU::~FFT_CPU() { } template <> +FFT_CPU::FFT_CPU(const int fft_mode_in) +{ + this->fft_mode = fft_mode_in; +} +template <> +FFT_CPU::FFT_CPU(const int fft_mode_in) +{ + this->fft_mode = fft_mode_in; +} + +template <> +void FFT_CPU::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, + int nproc_in, bool gamma_only_in, bool xprime_in, bool mpifft_in) +{ + this->gamma_only = gamma_only_in; + this->xprime = xprime_in; + this->fftnx = this->nx = nx_in; + this->fftny = this->ny = ny_in; + if (this->gamma_only) + { + if (xprime) { + this->fftnx = int(nx / 2) + 1; + } else { + this->fftny = int(ny / 2) + 1; + } + } + this->nz = nz_in; + this->ns = ns_in; + this->lixy = lixy_in; + this->rixy = rixy_in; + this->nplane = nplane_in; + this->nproc = nproc_in; + this->mpifft = mpifft_in; + this->nxy = this->nx * this->ny; + this->fftnxy = this->fftnx * this->fftny; + const int nrxx = this->nxy * this->nplane; + const int nsz = this->nz * this->ns; + this->maxgrids = (nsz > nrxx) ? nsz : nrxx; +} +template <> void FFT_CPU::setupFFT() { @@ -44,8 +84,8 @@ void FFT_CPU::setupFFT() default: break; } - if (!this->mpifft) - { + // if (!this->mpifft) + // { z_auxg = (std::complex*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids); z_auxr = (std::complex*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids); d_rspace = (double*)z_auxg; @@ -109,23 +149,9 @@ void FFT_CPU::setupFFT() 1, (fftw_complex*)z_auxr, embed, this->nplane, 1, FFTW_BACKWARD, flag); } } - } -#if defined(__FFTW3_MPI) && defined(__MPI) - else - { - // this->initplan_mpi(); - // if (this->precision == "single") { - // this->initplanf_mpi(); - // } - } -#endif + // } return; } -template <> -void FFT_CPU::initfftmode(int fft_mode_in) -{ - this->fft_mode = fft_mode_in; -} template <> void FFT_CPU::clearfft(fftw_plan& plan) diff --git a/source/module_basis/module_pw/module_fft/fft_cpu.h b/source/module_basis/module_pw/module_fft/fft_cpu.h index 1da1d53956..7be962d2fa 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu.h +++ b/source/module_basis/module_pw/module_fft/fft_cpu.h @@ -14,11 +14,13 @@ class FFT_CPU : public FFT_BASE { public: FFT_CPU(); + FFT_CPU(const int fft_mode_in); ~FFT_CPU(); - __attribute__((weak)) void initfftmode(int fft_mode_in); - //init fftw_plans + // __attribute__((weak)) + __attribute__((weak)) void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, + int nproc_in, bool gamma_only_in, bool xprime_in = true, bool mpifft_in = false) override; __attribute__((weak)) void setupFFT() override; // void initplan(const unsigned int& flag = 0); @@ -80,6 +82,23 @@ class FFT_CPU : public FFT_BASE float* s_rspace = nullptr; // real number space for r, [nplane * nx *ny] double* d_rspace = nullptr; // real number space for r, [nplane * nx *ny] + + int initflag = 0; // 0: not initialized; 1: initialized + int fftnx=0; + int fftny=0; + int fftnxy=0; + int nxy=0; + int nplane=0; //number of x-y planes + bool gamma_only = false; + int lixy=0; + int rixy=0;// lixy: the left edge of the pw ball in the y direction; rixy: the right edge of the pw ball in the x or y direction + bool mpifft = false; // if use mpi fft, only used when define __FFTW3_MPI + int maxgrids = 0; // maxgrids = (nsz > nrxx) ? nsz : nrxx; + bool xprime = true; // true: when do recip2real, x-fft will be done last and when doing real2recip, x-fft will be done first; false: y-fft + // For gamma_only, true: we use half x; false: we use half y + int ns=0; //number of sticks + int nproc=1; // number of proc. + int fft_mode = 0; ///< fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive }; } #endif // FFT_CPU_H \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp b/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp index 8c59b8e42a..9a35cee70f 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp @@ -7,9 +7,33 @@ namespace ModulePW { template <> -void FFT_CPU::initfftmode(int fft_mode_in) +void FFT_CPU::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, + int nproc_in, bool gamma_only_in, bool xprime_in, bool mpifft_in) { - this->fft_mode = fft_mode_in; + this->gamma_only = gamma_only_in; + this->xprime = xprime_in; + this->fftnx = this->nx = nx_in; + this->fftny = this->ny = ny_in; + if (this->gamma_only) + { + if (this->xprime) { + this->fftnx = int(nx / 2) + 1; + } else { + this->fftny = int(ny / 2) + 1; + } + } + this->nz = nz_in; + this->ns = ns_in; + this->lixy = lixy_in; + this->rixy = rixy_in; + this->nplane = nplane_in; + this->nproc = nproc_in; + this->mpifft = mpifft_in; + this->nxy = this->nx * this->ny; + this->fftnxy = this->fftnx * this->fftny; + const int nrxx = this->nxy * this->nplane; + const int nsz = this->nz * this->ns; + this->maxgrids = (nsz > nrxx) ? nsz : nrxx; } template <> void FFT_CPU::setupFFT() From 2d0b5f3ec6ddb7a9af4f939613e3d114766e0b60 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Fri, 8 Nov 2024 10:18:03 +0800 Subject: [PATCH 20/27] fix complie bug and change the fftwf logic --- .../module_pw/module_fft/fft_bundle.cpp | 30 +++---------------- .../module_pw/module_fft/fft_cpu.cpp | 3 -- .../module_pw/module_fft/fft_cpu_float.cpp | 24 +++------------ .../module_xc/test/CMakeLists.txt | 4 +-- 4 files changed, 10 insertions(+), 51 deletions(-) diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index 21317394e8..8f5a8fbc0a 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -26,20 +26,10 @@ FFT_Bundle::FFT_Bundle(std::string device_in,std::string precision_in) assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); this->device = device_in; this->precision = precision_in; - if (device=="cpu") - { - fft_float = make_unique>(); - fft_double = make_unique>(); - } - // else if (device=="gpu") - // { - // #if defined(__ROCM) - // fft_float = new FFT_RCOM(); - // fft_double = new FFT_RCOM(); - // #elif defined(__CUDA) - // fft_float = new FFT_CUDA(); - // fft_double = new FFT_CUDA(); - // #endif + // if (device=="cpu") + // { + // fft_float = make_unique>(); + // fft_double = make_unique>(); // } } @@ -143,18 +133,6 @@ void FFT_Bundle::clear() { fft_double->clear(); } - // if (fft_float!=nullptr) - // { - // delete fft_float; - // fft_float=nullptr; - // float_flag = false; - // } - // if (fft_double!=nullptr) - // { - // delete fft_double; - // fft_double=nullptr; - // double_flag = false; - // } } // access the real space data template <> diff --git a/source/module_basis/module_pw/module_fft/fft_cpu.cpp b/source/module_basis/module_pw/module_fft/fft_cpu.cpp index f23079931a..894f899f39 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cpu.cpp @@ -84,8 +84,6 @@ void FFT_CPU::setupFFT() default: break; } - // if (!this->mpifft) - // { z_auxg = (std::complex*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids); z_auxr = (std::complex*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids); d_rspace = (double*)z_auxg; @@ -149,7 +147,6 @@ void FFT_CPU::setupFFT() 1, (fftw_complex*)z_auxr, embed, this->nplane, 1, FFTW_BACKWARD, flag); } } - // } return; } diff --git a/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp b/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp index 9a35cee70f..0873acff39 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp @@ -1,9 +1,5 @@ #include "fft_cpu.h" -// #include "fftw3f.h" -// #if defined(__FFTW3_MPI) && defined(__MPI) -// #include "fftw3f-mpi.h" -// //#include "fftw3-mpi_mkl.h" namespace ModulePW { template <> @@ -56,8 +52,8 @@ void FFT_CPU::setupFFT() default: break; } - if (!this->mpifft) - { + // if (!this->mpifft) + // { c_auxg = (std::complex*)fftwf_malloc(sizeof(fftwf_complex) * this->maxgrids); c_auxr = (std::complex*)fftwf_malloc(sizeof(fftwf_complex) * this->maxgrids); s_rspace = (float*)c_auxg; @@ -131,16 +127,6 @@ void FFT_CPU::setupFFT() (fftwf_complex*)c_auxr, embed, this->nplane, 1, FFTW_BACKWARD, flag); } } - } - #if defined(__FFTW3_MPI) && defined(__MPI) - else - { - // this->initplan_mpi(); - // if (this->precision == "single") { - // this->initplanf_mpi(); - // } - } - #endif return; } @@ -308,14 +294,12 @@ void FFT_CPU::fftxyc2r(std::complex* in, float* out) const } else { - fftwf_execute_dft(this->planfxfor1, (fftwf_complex*)in, (fftwf_complex*)in); + fftwf_execute_dft(this->planfxbac1, (fftwf_complex*)in, (fftwf_complex*)in); for (int i = 0; i < this->nx; ++i) { - fftwf_execute_dft(this->planfybac, (fftwf_complex*)&in[i * npy], (fftwf_complex*)&in[i * npy]); + fftwf_execute_dft_c2r(this->planfyc2r, (fftwf_complex*)&in[i * npy], &out[i * npy]); } - - fftwf_execute_dft_c2r(this->planfyc2r, (fftwf_complex*)in, out); } } } \ No newline at end of file diff --git a/source/module_hamilt_general/module_xc/test/CMakeLists.txt b/source/module_hamilt_general/module_xc/test/CMakeLists.txt index 0b2ea5ec98..66cf5f9cb0 100644 --- a/source/module_hamilt_general/module_xc/test/CMakeLists.txt +++ b/source/module_hamilt_general/module_xc/test/CMakeLists.txt @@ -39,7 +39,7 @@ AddTest( ../../../module_base/libm/sincos.cpp ../../../module_base/blas_connector.cpp ../../../module_basis/module_pw/module_fft/fft_base.cpp - ../../../module_basis/module_pw/module_fft/fft_temp.cpp + ../../../module_basis/module_pw/module_fft/fft_bundle.cpp ../../../module_basis/module_pw/module_fft/fft_cpu.cpp ) @@ -77,6 +77,6 @@ AddTest( ../../../module_base/libm/branred.cpp ../../../module_base/libm/sincos.cpp ../../../module_basis/module_pw/module_fft/fft_base.cpp - ../../../module_basis/module_pw/module_fft/fft_temp.cpp + ../../../module_basis/module_pw/module_fft/fft_bundle.cpp ../../../module_basis/module_pw/module_fft/fft_cpu.cpp ) \ No newline at end of file From da26acc0b885b8a8fe1dc7b8aa4a73cee153f7c3 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Fri, 8 Nov 2024 16:34:14 +0800 Subject: [PATCH 21/27] add comments for the fft class --- examples/scf/pw_Si2/INPUT | 2 +- .../module_pw/module_fft/fft_base.h | 145 ++++++++++++--- .../module_pw/module_fft/fft_bundle.cpp | 36 ++-- .../module_pw/module_fft/fft_bundle.h | 169 ++++++++++++++++-- .../module_pw/module_fft/fft_cpu.cpp | 23 +-- .../module_pw/module_fft/fft_cpu.h | 146 +++++++++++---- .../module_pw/module_fft/fft_cpu_float.cpp | 159 +++++++--------- 7 files changed, 486 insertions(+), 194 deletions(-) diff --git a/examples/scf/pw_Si2/INPUT b/examples/scf/pw_Si2/INPUT index f3144bd95d..141c104100 100644 --- a/examples/scf/pw_Si2/INPUT +++ b/examples/scf/pw_Si2/INPUT @@ -9,4 +9,4 @@ scf_thr 1e-7 scf_nmax 100 device cpu ks_solver dav_subspace -precision single +precision double diff --git a/source/module_basis/module_pw/module_fft/fft_base.h b/source/module_basis/module_pw/module_fft/fft_base.h index b53ce979f7..3997079898 100644 --- a/source/module_basis/module_pw/module_fft/fft_base.h +++ b/source/module_basis/module_pw/module_fft/fft_base.h @@ -13,50 +13,151 @@ class FFT_BASE FFT_BASE(); virtual ~FFT_BASE(); - // init parameters of fft + /** + * @brief Initialize the fft parameters As virtual function. + * + * The function is used to initialize the fft parameters. + */ virtual __attribute__((weak)) - void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, - int nproc_in, bool gamma_only_in, bool xprime_in = true, bool mpifft_in = false); + void initfft(int nx_in, + int ny_in, + int nz_in, + int lixy_in, + int rixy_in, + int ns_in, + int nplane_in, + int nproc_in, + bool gamma_only_in, + bool xprime_in = true); - //init fftw_plans + /** + * @brief Setup the fft Plan and data As pure virtual function. + * + * The function is set as pure virtual function.In order to + * override the function in the derived class.In the derived + * class, the function is used to setup the fft Plan and data. + */ virtual void setupFFT()=0; - //destroy fftw_plans + /** + * @brief Clean the fft Plan As pure virtual function. + * + * The function is set as pure virtual function.In order to + * override the function in the derived class.In the derived + * class, the function is used to clean the fft Plan. + */ virtual void cleanFFT()=0; - //clear fftw_data + + /** + * @brief Clear the fft data As pure virtual function. + * + * The function is set as pure virtual function.In order to + * override the function in the derived class.In the derived + * class, the function is used to clear the fft data. + */ virtual void clear()=0; - // access the real space data - virtual __attribute__((weak)) FPTYPE* get_rspace_data() const; + /** + * @brief Get the real space data in cpu-like fft + * + * The function is used to get the real space data.While the + * FFT_BASE is an abstract class,the function will be override, + * The attribute weak is used to avoid define the function. + */ + virtual __attribute__((weak)) + FPTYPE* get_rspace_data() const; - virtual __attribute__((weak)) std::complex* get_auxr_data() const; + virtual __attribute__((weak)) + std::complex* get_auxr_data() const; - virtual __attribute__((weak)) std::complex* get_auxg_data() const; + virtual __attribute__((weak)) + std::complex* get_auxg_data() const; - virtual __attribute__((weak)) std::complex* get_auxr_3d_data() const; + /** + * @brief Get the auxiliary real space data in 3D + * + * The function is used to get the auxiliary real space data in 3D. + * While the FFT_BASE is an abstract class,the function will be override, + * The attribute weak is used to avoid define the function. + */ + virtual __attribute__((weak)) + std::complex* get_auxr_3d_data() const; //forward fft in x-y direction - virtual __attribute__((weak)) void fftxyfor(std::complex* in, std::complex* out) const; - virtual __attribute__((weak)) void fftxybac(std::complex* in, std::complex* out) const; + /** + * @brief Forward FFT in x-y direction + * @param in input data + * @param out output data + * + * This function performs the forward FFT in the x-y direction. + * It involves two axes, x and y. The FFT is applied multiple times + * along the left and right boundaries in the primary direction(which is + * determined by the xprime flag).Notably, the Y axis operates in + * "many-many-FFT" mode. + */ + virtual __attribute__((weak)) + void fftxyfor(std::complex* in, + std::complex* out) const; + + virtual __attribute__((weak)) + void fftxybac(std::complex* in, + std::complex* out) const; - virtual __attribute__((weak)) void fftzfor(std::complex* in, std::complex* out) const; + /** + * @brief Forward FFT in z direction + * @param in input data + * @param out output data + * + * This function performs the forward FFT in the z direction. + * It involves only one axis, z. The FFT is applied only once. + * Notably, the Z axis operates in many FFT with nz*ns. + */ + virtual __attribute__((weak)) + void fftzfor(std::complex* in, + std::complex* out) const; - virtual __attribute__((weak)) void fftzbac(std::complex* in, std::complex* out) const; + virtual __attribute__((weak)) + void fftzbac(std::complex* in, + std::complex* out) const; - virtual __attribute__((weak)) void fftxyr2c(FPTYPE* in, std::complex* out) const; + /** + * @brief Forward FFT in x-y direction with real to complex + * @param in input data, real type + * @param out output data, complex type + * + * This function performs the forward FFT in the x-y direction + * with real to complex.There is no difference between fftxyfor. + */ + virtual __attribute__((weak)) + void fftxyr2c(FPTYPE* in, + std::complex* out) const; - virtual __attribute__((weak)) void fftxyc2r(std::complex* in, FPTYPE* out) const; + virtual __attribute__((weak)) + void fftxyc2r(std::complex* in, + FPTYPE* out) const; - virtual __attribute__((weak)) void fft3D_forward(std::complex* in, std::complex* out) const; + /** + * @brief Forward FFT in 3D + * @param in input data + * @param out output data + * + * This function performs the forward FFT for gpu-like fft. + * It involves three axes, x, y, and z. The FFT is applied multiple times + * for fft3D_forward. + */ + virtual __attribute__((weak)) + void fft3D_forward(std::complex* in, + std::complex* out) const; - virtual __attribute__((weak)) void fft3D_backward(std::complex* in, std::complex* out) const; + virtual __attribute__((weak)) + void fft3D_backward(std::complex* in, + std::complex* out) const; protected: - int ny=0; - int nx=0; + int nx=0; + int ny=0; int nz=0; - }; } #endif // FFT_BASE_H diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index 8f5a8fbc0a..7c084c7be9 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -16,7 +16,6 @@ std::unique_ptr make_unique(Args &&... args) } namespace ModulePW { -// #include "fft_gpu.h" FFT_Bundle::FFT_Bundle() { } @@ -26,11 +25,6 @@ FFT_Bundle::FFT_Bundle(std::string device_in,std::string precision_in) assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); this->device = device_in; this->precision = precision_in; - // if (device=="cpu") - // { - // fft_float = make_unique>(); - // fft_double = make_unique>(); - // } } FFT_Bundle::~FFT_Bundle() @@ -54,8 +48,17 @@ void FFT_Bundle::setfft(std::string device_in,std::string precision_in) this->precision = precision_in; } -void FFT_Bundle::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, - int nproc_in, bool gamma_only_in, bool xprime_in , bool mpifft_in) +void FFT_Bundle::initfft(int nx_in, + int ny_in, + int nz_in, + int lixy_in, + int rixy_in, + int ns_in, + int nplane_in, + int nproc_in, + bool gamma_only_in, + bool xprime_in , + bool mpifft_in) { if (this->precision=="single") { @@ -74,6 +77,14 @@ void FFT_Bundle::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_ { fft_float = make_unique>(this->fft_mode); fft_double = make_unique>(this->fft_mode); + if (float_flag) + { + fft_float->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in); + } + if (double_flag) + { + fft_double->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in); + } } if (device=="gpu") { @@ -85,14 +96,7 @@ void FFT_Bundle::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_ // fft_double = make_unique>(); // #endif } - if (float_flag) - { - fft_float->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); - } - if (double_flag) - { - fft_double->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in); - } + } void FFT_Bundle::initfftmode(int fft_mode_in) { diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.h b/source/module_basis/module_pw/module_fft/fft_bundle.h index be5de5f92f..43ae8cdd16 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.h +++ b/source/module_basis/module_pw/module_fft/fft_bundle.h @@ -9,13 +9,65 @@ class FFT_Bundle { public: FFT_Bundle(); + /** + * @brief Constructor with device and precision. + * @param device_in device type, cpu or gpu. + * @param precision_in precision type, single or double. + * + * the function will check the input device and precision, + * and set the device and precision. + */ FFT_Bundle(std::string device_in,std::string precision_in); ~FFT_Bundle(); + /** + * @brief Set device and precision. + * @param device_in device type, cpu or gpu. + * @param precision_in precision type, single or double. + * + * the function will check the input device and precision, + * and set the device and precision. + */ void setfft(std::string device_in,std::string precision_in); - void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, - int nproc_in, bool gamma_only_in, bool xprime_in = true, bool mpifft_in = false); + + /** + * @brief Initialize the fft parameters. + * @param nx_in number of grid points in x direction. + * @param ny_in number of grid points in y direction. + * @param nz_in number of grid points in z direction. + * @param lixy_in the position of the left boundary + * in the x-y plane. + * @param rixy_in the position of the right boundary + * in the x-y plane. + * @param ns_in number of stick whcih is used in the + * Z direction. + * @param nplane_in number of x-y planes. + * @param nproc_in number of processors. + * @param gamma_only_in whether only gamma point is used. + * @param xprime_in whether xprime is used. + * + * the function will initialize the many-fft parameters + * Wheatley in cpu or gpu device. + */ + void initfft(int nx_in, + int ny_in, + int nz_in, + int lixy_in, + int rixy_in, + int ns_in, + int nplane_in, + int nproc_in, + bool gamma_only_in, + bool xprime_in = true, + bool mpifft_in = false); + /** + * @brief Initialize the fft mode. + * @param fft_mode_in fft mode. + * + * the function will initialize the fft mode. + */ + void initfftmode(int fft_mode_in); void setupFFT(); @@ -24,43 +76,136 @@ class FFT_Bundle void clear(); + /** + * @brief Get the real space data. + * @return FPTYPE* the real space data. + * + * the function will return the real space data, + * which is used in the cpu-like fft. + */ template FPTYPE* get_rspace_data() const; + /** + * @brief Get the auxr data. + * @return std::complex* the auxr data. + * + * the function will return the auxr data, + * which is used in the cpu-like fft. + */ template std::complex* get_auxr_data() const; + /** + * @brief Get the auxg data. + * @return std::complex* the auxg data. + * + * the function will return the auxg data, + * which is used in the cpu-like fft. + */ template std::complex* get_auxg_data() const; + /** + * @brief Get the auxr 3d data. + * @return std::complex* the auxr 3d data. + * + * the function will return the auxr 3d data, + * which is used in the gpu-like fft. + */ template std::complex* get_auxr_3d_data() const; + /** + * @brief Forward fft in z direction. + * @param in input data. + * @param out output data. + * + * The function will do the forward many fft in z direction, + * As an interface, the function will call the fftzfor in the + * accurate fft class. + * which is used in the cpu-like fft. + */ template - void fftzfor(std::complex* in, std::complex* out) const; + void fftzfor(std::complex* in, + std::complex* out) const; + /** + * @brief Forward fft in x-y direction. + * @param in input data. + * @param out output data. + * + * the function will do the forward fft in x and y direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftxyfor in the accurate fft class. + */ template - void fftxyfor(std::complex* in, std::complex* out) const; + void fftxyfor(std::complex* in, + std::complex* out) const; + /** + * @brief Backward fft in z direction. + * @param in input data. + * @param out output data. + * + * the function will do the backward many fft in z direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftzbac in the accurate fft class. + */ template - void fftzbac(std::complex* in, std::complex* out) const; + void fftzbac(std::complex* in, + std::complex* out) const; + /** + * @brief Backward fft in x-y direction. + * @param in input data. + * @param out output data. + * + * the function will do the backward fft in x and y direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftxybac in the accurate fft class. + */ template - void fftxybac(std::complex* in, std::complex* out) const; + void fftxybac(std::complex* in, + std::complex* out) const; + + /** + * @brief Real to complex fft in x-y direction. + * @param in input data. + * @param out output data. + * + * the function will do the real to complex fft in x and y direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftxyr2c in the accurate fft class. + */ template - void fftxyr2c(FPTYPE* in, std::complex* out) const; + void fftxyr2c(FPTYPE* in, + std::complex* out) const; + /** + * @brief Complex to real fft in x-y direction. + * @param in input data. + * @param out output data. + * + * the function will do the complex to real fft in x and y direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftxyc2r in the accurate fft class. + */ template - void fftxyc2r(std::complex* in, FPTYPE* out) const; + void fftxyc2r(std::complex* in, + FPTYPE* out) const; template - void fft3D_forward(const Device* ctx, std::complex* in, std::complex* out) const; + void fft3D_forward(const Device* ctx, + std::complex* in, + std::complex* out) const; template - void fft3D_backward(const Device* ctx, std::complex* in, std::complex* out) const; + void fft3D_backward(const Device* ctx, + std::complex* in, + std::complex* out) const; void set_device(std::string device_in); void set_precision(std::string precision_in); private: - int fft_mode = 0; ///< fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive + int fft_mode = 0; ///< fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive bool float_flag=false; bool float_define=true; bool double_flag=false; - // FFT_BASE* fft_float=nullptr; // Remove the qualified name and use a raw pointer instead std::shared_ptr> fft_float=nullptr; std::shared_ptr> fft_double=nullptr; diff --git a/source/module_basis/module_pw/module_fft/fft_cpu.cpp b/source/module_basis/module_pw/module_fft/fft_cpu.cpp index 894f899f39..a7cabae803 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cpu.cpp @@ -1,9 +1,5 @@ #include "fft_cpu.h" #include "fftw3.h" -#if defined(__FFTW3_MPI) && defined(__MPI) -#include -//#include "fftw3-mpi_mkl.h" -#endif namespace ModulePW { template <> @@ -33,9 +29,17 @@ FFT_CPU::FFT_CPU(const int fft_mode_in) this->fft_mode = fft_mode_in; } -template <> -void FFT_CPU::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, - int nproc_in, bool gamma_only_in, bool xprime_in, bool mpifft_in) +template +void FFT_CPU::initfft(int nx_in, + int ny_in, + int nz_in, + int lixy_in, + int rixy_in, + int ns_in, + int nplane_in, + int nproc_in, + bool gamma_only_in, + bool xprime_in) { this->gamma_only = gamma_only_in; this->xprime = xprime_in; @@ -44,9 +48,9 @@ void FFT_CPU::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int if (this->gamma_only) { if (xprime) { - this->fftnx = int(nx / 2) + 1; + this->fftnx = int(this->nx / 2) + 1; } else { - this->fftny = int(ny / 2) + 1; + this->fftny = int(this->ny / 2) + 1; } } this->nz = nz_in; @@ -55,7 +59,6 @@ void FFT_CPU::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int this->rixy = rixy_in; this->nplane = nplane_in; this->nproc = nproc_in; - this->mpifft = mpifft_in; this->nxy = this->nx * this->ny; this->fftnxy = this->fftnx * this->fftny; const int nrxx = this->nxy * this->nplane; diff --git a/source/module_basis/module_pw/module_fft/fft_cpu.h b/source/module_basis/module_pw/module_fft/fft_cpu.h index 7be962d2fa..1db61bf841 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu.h +++ b/source/module_basis/module_pw/module_fft/fft_cpu.h @@ -17,39 +17,96 @@ class FFT_CPU : public FFT_BASE FFT_CPU(const int fft_mode_in); ~FFT_CPU(); - //init fftw_plans - // __attribute__((weak)) - __attribute__((weak)) void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, - int nproc_in, bool gamma_only_in, bool xprime_in = true, bool mpifft_in = false) override; - __attribute__((weak)) void setupFFT() override; + /** + * @brief Initialize the fft parameters. + * @param nx_in number of grid points in x direction. + * @param ny_in number of grid points in y direction. + * @param nz_in number of grid points in z direction. + * @param lixy_in the position of the left boundary + * in the x-y plane. + * @param rixy_in the position of the right boundary + * in the x-y plane. + * @param ns_in number of stick whcih is used in the + * Z direction. + * @param nplane_in number of x-y planes. + * @param nproc_in number of processors. + * @param gamma_only_in whether only gamma point is used. + * @param xprime_in whether xprime is used. + */ + __attribute__((weak)) + void initfft(int nx_in, + int ny_in, + int nz_in, + int lixy_in, + int rixy_in, + int ns_in, + int nplane_in, + int nproc_in, + bool gamma_only_in, + bool xprime_in = true) override; + __attribute__((weak)) + void setupFFT() override; // void initplan(const unsigned int& flag = 0); - __attribute__((weak)) void cleanFFT() override; - - __attribute__((weak)) void clear() override; - - __attribute__((weak)) FPTYPE* get_rspace_data() const override; - - __attribute__((weak)) std::complex* get_auxr_data() const; - - __attribute__((weak)) std::complex* get_auxg_data() const; - - __attribute__((weak)) void fftxyfor(std::complex* in, std::complex* out) const override; - - __attribute__((weak)) void fftxybac(std::complex* in, std::complex* out) const override; - - __attribute__((weak)) void fftzfor(std::complex* in, std::complex* out) const override; - - __attribute__((weak)) void fftzbac(std::complex* in, std::complex* out) const override; - - __attribute__((weak)) void fftxyr2c(FPTYPE* in, std::complex* out) const override; - - __attribute__((weak)) void fftxyc2r(std::complex* in, FPTYPE* out) const override; + __attribute__((weak)) + void cleanFFT() override; + + __attribute__((weak)) + void clear() override; + + /** + * @brief Get the real space data the CPU FFT. + * @return FPTYPE* the real space data. + * + * the function will return the real space data, + * which is used in the CPU fft.Use the weak attribute + * to avoid defining float while without flag ENABLE_FLOAT_FFTW. + */ + __attribute__((weak)) + FPTYPE* get_rspace_data() const override; + + __attribute__((weak)) + std::complex* get_auxr_data() const; + + __attribute__((weak)) + std::complex* get_auxg_data() const; + + /** + * @brief Forward FFT in x-y direction + * @param in input data + * @param out output data + * + * The function details can be found in FFT_BASE, + * and the function interfaces can be found in FFT_BUNDLE. + */ + __attribute__((weak)) + void fftxyfor(std::complex* in, + std::complex* out) const override; + + __attribute__((weak)) + void fftxybac(std::complex* in, + std::complex* out) const override; + + __attribute__((weak)) + void fftzfor(std::complex* in, + std::complex* out) const override; + + __attribute__((weak)) + void fftzbac(std::complex* in, + std::complex* out) const override; + + __attribute__((weak)) + void fftxyr2c(FPTYPE* in, + std::complex* out) const override; + + __attribute__((weak)) + void fftxyc2r(std::complex* in, + FPTYPE* out) const override; private: void clearfft(fftw_plan& plan); void clearfft(fftwf_plan& plan); - fftw_plan planzfor = NULL;//create a special pointer pointing to the fftw_plan class as a plan for performing FFT + fftw_plan planzfor = NULL; fftw_plan planzbac = NULL; fftw_plan planxfor1 = NULL; fftw_plan planxbac1 = NULL; @@ -82,23 +139,36 @@ class FFT_CPU : public FFT_BASE float* s_rspace = nullptr; // real number space for r, [nplane * nx *ny] double* d_rspace = nullptr; // real number space for r, [nplane * nx *ny] - - int initflag = 0; // 0: not initialized; 1: initialized int fftnx=0; int fftny=0; int fftnxy=0; int nxy=0; - int nplane=0; //number of x-y planes - bool gamma_only = false; - int lixy=0; - int rixy=0;// lixy: the left edge of the pw ball in the y direction; rixy: the right edge of the pw ball in the x or y direction - bool mpifft = false; // if use mpi fft, only used when define __FFTW3_MPI - int maxgrids = 0; // maxgrids = (nsz > nrxx) ? nsz : nrxx; - bool xprime = true; // true: when do recip2real, x-fft will be done last and when doing real2recip, x-fft will be done first; false: y-fft - // For gamma_only, true: we use half x; false: we use half y + int nplane=0; int ns=0; //number of sticks int nproc=1; // number of proc. - int fft_mode = 0; ///< fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive + int maxgrids = 0; + bool gamma_only = false; + + /** + * @brief lixy: the left edge of the pw ball in the y direction + */ + int lixy=0; + + /** + * @brief rixy: the right edge of the pw ball in the x or y direction + */ + int rixy=0; + + /** + * @brief xprime: whether xprime is used,when do recip2real, x-fft will + * be done last and when doing real2recip, x-fft will be done first; + * false: y-fft For gamma_only, true: we use half x; false: we use half y + */ + bool xprime = true; + /** + * @brief fft_mode: fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive + */ + int fft_mode = 0; }; } #endif // FFT_CPU_H \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp b/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp index 0873acff39..627a2aa2f8 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp @@ -3,35 +3,6 @@ namespace ModulePW { template <> -void FFT_CPU::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, - int nproc_in, bool gamma_only_in, bool xprime_in, bool mpifft_in) -{ - this->gamma_only = gamma_only_in; - this->xprime = xprime_in; - this->fftnx = this->nx = nx_in; - this->fftny = this->ny = ny_in; - if (this->gamma_only) - { - if (this->xprime) { - this->fftnx = int(nx / 2) + 1; - } else { - this->fftny = int(ny / 2) + 1; - } - } - this->nz = nz_in; - this->ns = ns_in; - this->lixy = lixy_in; - this->rixy = rixy_in; - this->nplane = nplane_in; - this->nproc = nproc_in; - this->mpifft = mpifft_in; - this->nxy = this->nx * this->ny; - this->fftnxy = this->fftnx * this->fftny; - const int nrxx = this->nxy * this->nplane; - const int nsz = this->nz * this->ns; - this->maxgrids = (nsz > nrxx) ? nsz : nrxx; -} -template <> void FFT_CPU::setupFFT() { unsigned int flag = FFTW_ESTIMATE; @@ -52,81 +23,79 @@ void FFT_CPU::setupFFT() default: break; } - // if (!this->mpifft) - // { - c_auxg = (std::complex*)fftwf_malloc(sizeof(fftwf_complex) * this->maxgrids); - c_auxr = (std::complex*)fftwf_malloc(sizeof(fftwf_complex) * this->maxgrids); - s_rspace = (float*)c_auxg; - //--------------------------------------------------------- - // 1 D - //--------------------------------------------------------- + c_auxg = (std::complex*)fftwf_malloc(sizeof(fftwf_complex) * this->maxgrids); + c_auxr = (std::complex*)fftwf_malloc(sizeof(fftwf_complex) * this->maxgrids); + s_rspace = (float*)c_auxg; + //--------------------------------------------------------- + // 1 D + //--------------------------------------------------------- - // fftw_plan_many_dft(int rank, const int *n, int howmany, - // fftw_complex *in, const int *inembed, int istride, int idist, - // fftw_complex *out, const int *onembed, int ostride, int odist, int sign, unsigned - //flags); + // fftw_plan_many_dft(int rank, const int *n, int howmany, + // fftw_complex *in, const int *inembed, int istride, int idist, + // fftw_complex *out, const int *onembed, int ostride, int odist, int sign, unsigned + //flags); - this->planfzfor = fftwf_plan_many_dft(1, &this->nz, this->ns, (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, - (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, FFTW_FORWARD, flag); + this->planfzfor = fftwf_plan_many_dft(1, &this->nz, this->ns, (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, + (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, FFTW_FORWARD, flag); - this->planfzbac = fftwf_plan_many_dft(1, &this->nz, this->ns, (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, - (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, FFTW_BACKWARD, flag); - //--------------------------------------------------------- - // 2 D - //--------------------------------------------------------- + this->planfzbac = fftwf_plan_many_dft(1, &this->nz, this->ns, (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, + (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, FFTW_BACKWARD, flag); + //--------------------------------------------------------- + // 2 D + //--------------------------------------------------------- - int* embed = nullptr; - int npy = this->nplane * this->ny; - if (this->xprime) + int* embed = nullptr; + int npy = this->nplane * this->ny; + if (this->xprime) + { + this->planfyfor = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, nplane, 1, + (fftwf_complex*)c_auxr, embed, nplane, 1, FFTW_FORWARD, flag); + this->planfybac = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, nplane, 1, + (fftwf_complex*)c_auxr, embed, nplane, 1, FFTW_BACKWARD, flag); + if (this->gamma_only) { - this->planfyfor = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, nplane, 1, - (fftwf_complex*)c_auxr, embed, nplane, 1, FFTW_FORWARD, flag); - this->planfybac = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, nplane, 1, - (fftwf_complex*)c_auxr, embed, nplane, 1, FFTW_BACKWARD, flag); - if (this->gamma_only) - { - this->planfxr2c = fftwf_plan_many_dft_r2c(1, &this->nx, npy, s_rspace, embed, npy, 1, - (fftwf_complex*)c_auxr, embed, npy, 1, flag); - this->planfxc2r = fftwf_plan_many_dft_c2r(1, &this->nx, npy, (fftwf_complex*)c_auxr, embed, npy, 1, - s_rspace, embed, npy, 1, flag); - } - else - { - this->planfxfor1 = fftwf_plan_many_dft(1, &this->nx, npy, (fftwf_complex*)c_auxr, embed, npy, 1, - (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planfxbac1 = fftwf_plan_many_dft(1, &this->nx, npy, (fftwf_complex*)c_auxr, embed, npy, 1, - (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_BACKWARD, flag); - } + this->planfxr2c = fftwf_plan_many_dft_r2c(1, &this->nx, npy, s_rspace, embed, npy, 1, + (fftwf_complex*)c_auxr, embed, npy, 1, flag); + this->planfxc2r = fftwf_plan_many_dft_c2r(1, &this->nx, npy, (fftwf_complex*)c_auxr, embed, npy, 1, + s_rspace, embed, npy, 1, flag); } else { - this->planfxfor1 = fftwf_plan_many_dft(1, &this->nx, this->nplane * (lixy + 1), (fftwf_complex*)c_auxr, embed, - npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planfxbac1 = fftwf_plan_many_dft(1, &this->nx, this->nplane * (lixy + 1), (fftwf_complex*)c_auxr, embed, - npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_BACKWARD, flag); - if (this->gamma_only) - { - this->planfyr2c = fftwf_plan_many_dft_r2c(1, &this->ny, this->nplane, s_rspace, embed, this->nplane, 1, - (fftwf_complex*)c_auxr, embed, this->nplane, 1, flag); - this->planfyc2r = fftwf_plan_many_dft_c2r(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, - this->nplane, 1, s_rspace, embed, this->nplane, 1, flag); - } - else - { - this->planfxfor2 - = fftwf_plan_many_dft(1, &this->nx, this->nplane * (this->ny - rixy), (fftwf_complex*)c_auxr, embed, - npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planfxbac2 - = fftwf_plan_many_dft(1, &this->nx, this->nplane * (this->ny - rixy), (fftwf_complex*)c_auxr, embed, - npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_BACKWARD, flag); - this->planfyfor - = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, this->nplane, 1, - (fftwf_complex*)c_auxr, embed, this->nplane, 1, FFTW_FORWARD, flag); - this->planfybac - = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, this->nplane, 1, - (fftwf_complex*)c_auxr, embed, this->nplane, 1, FFTW_BACKWARD, flag); - } + this->planfxfor1 = fftwf_plan_many_dft(1, &this->nx, npy, (fftwf_complex*)c_auxr, embed, npy, 1, + (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_FORWARD, flag); + this->planfxbac1 = fftwf_plan_many_dft(1, &this->nx, npy, (fftwf_complex*)c_auxr, embed, npy, 1, + (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_BACKWARD, flag); } + } + else + { + this->planfxfor1 = fftwf_plan_many_dft(1, &this->nx, this->nplane * (lixy + 1), (fftwf_complex*)c_auxr, embed, + npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_FORWARD, flag); + this->planfxbac1 = fftwf_plan_many_dft(1, &this->nx, this->nplane * (lixy + 1), (fftwf_complex*)c_auxr, embed, + npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_BACKWARD, flag); + if (this->gamma_only) + { + this->planfyr2c = fftwf_plan_many_dft_r2c(1, &this->ny, this->nplane, s_rspace, embed, this->nplane, 1, + (fftwf_complex*)c_auxr, embed, this->nplane, 1, flag); + this->planfyc2r = fftwf_plan_many_dft_c2r(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, + this->nplane, 1, s_rspace, embed, this->nplane, 1, flag); + } + else + { + this->planfxfor2 + = fftwf_plan_many_dft(1, &this->nx, this->nplane * (this->ny - rixy), (fftwf_complex*)c_auxr, embed, + npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_FORWARD, flag); + this->planfxbac2 + = fftwf_plan_many_dft(1, &this->nx, this->nplane * (this->ny - rixy), (fftwf_complex*)c_auxr, embed, + npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_BACKWARD, flag); + this->planfyfor + = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, this->nplane, 1, + (fftwf_complex*)c_auxr, embed, this->nplane, 1, FFTW_FORWARD, flag); + this->planfybac + = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, this->nplane, 1, + (fftwf_complex*)c_auxr, embed, this->nplane, 1, FFTW_BACKWARD, flag); + } + } return; } From 9b200a22dcab0d22b9738df5a4710b5e683ca1c1 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Fri, 8 Nov 2024 16:47:20 +0800 Subject: [PATCH 22/27] modify the fft name and add comments --- .../module_pw/module_fft/fft_cpu.cpp | 110 +++++++++--------- source/module_basis/module_pw/pw_basis.cpp | 10 +- source/module_basis/module_pw/pw_basis.h | 2 +- source/module_basis/module_pw/pw_basis_k.cpp | 10 +- .../module_basis/module_pw/pw_basis_sup.cpp | 8 +- .../module_basis/module_pw/pw_transform.cpp | 62 +++++----- .../module_basis/module_pw/pw_transform_k.cpp | 48 ++++---- .../module_elecstate/module_charge/charge.cpp | 4 +- .../module_charge/charge_init.cpp | 4 +- source/module_esolver/esolver_fp.cpp | 4 +- source/module_esolver/esolver_ks.cpp | 2 +- 11 files changed, 132 insertions(+), 132 deletions(-) diff --git a/source/module_basis/module_pw/module_fft/fft_cpu.cpp b/source/module_basis/module_pw/module_fft/fft_cpu.cpp index a7cabae803..eb8d3d8f1a 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cpu.cpp @@ -87,69 +87,69 @@ void FFT_CPU::setupFFT() default: break; } - z_auxg = (std::complex*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids); - z_auxr = (std::complex*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids); - d_rspace = (double*)z_auxg; - this->planzfor = fftw_plan_many_dft(1, &this->nz, this->ns, (fftw_complex*)z_auxg, &this->nz, 1, this->nz, - (fftw_complex*)z_auxg, &this->nz, 1, this->nz, FFTW_FORWARD, flag); + z_auxg = (std::complex*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids); + z_auxr = (std::complex*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids); + d_rspace = (double*)z_auxg; + this->planzfor = fftw_plan_many_dft(1, &this->nz, this->ns, (fftw_complex*)z_auxg, &this->nz, 1, this->nz, + (fftw_complex*)z_auxg, &this->nz, 1, this->nz, FFTW_FORWARD, flag); - this->planzbac = fftw_plan_many_dft(1, &this->nz, this->ns, (fftw_complex*)z_auxg, &this->nz, 1, this->nz, - (fftw_complex*)z_auxg, &this->nz, 1, this->nz, FFTW_BACKWARD, flag); + this->planzbac = fftw_plan_many_dft(1, &this->nz, this->ns, (fftw_complex*)z_auxg, &this->nz, 1, this->nz, + (fftw_complex*)z_auxg, &this->nz, 1, this->nz, FFTW_BACKWARD, flag); - //--------------------------------------------------------- - // 2 D - XY - //--------------------------------------------------------- - // 1D+1D is much faster than 2D FFT! - // in-place fft is better for c2c and out-of-place fft is better for c2r - int* embed = nullptr; - int npy = this->nplane * this->ny; - if (this->xprime) + //--------------------------------------------------------- + // 2 D - XY + //--------------------------------------------------------- + // 1D+1D is much faster than 2D FFT! + // in-place fft is better for c2c and out-of-place fft is better for c2r + int* embed = nullptr; + int npy = this->nplane * this->ny; + if (this->xprime) + { + this->planyfor = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed,this->nplane, 1, + (fftw_complex*)z_auxr, embed,this->nplane, 1, FFTW_FORWARD, flag); + this->planybac = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed,this->nplane, 1, + (fftw_complex*)z_auxr, embed,this->nplane, 1, FFTW_BACKWARD, flag); + if (this->gamma_only) + { + this->planxr2c = fftw_plan_many_dft_r2c(1, &this->nx, npy, d_rspace, embed, npy, 1, (fftw_complex*)z_auxr, + embed, npy, 1, flag); + this->planxc2r = fftw_plan_many_dft_c2r(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, d_rspace, + embed, npy, 1, flag); + } + else { - this->planyfor = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed,this->nplane, 1, - (fftw_complex*)z_auxr, embed,this->nplane, 1, FFTW_FORWARD, flag); - this->planybac = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed,this->nplane, 1, - (fftw_complex*)z_auxr, embed,this->nplane, 1, FFTW_BACKWARD, flag); - if (this->gamma_only) - { - this->planxr2c = fftw_plan_many_dft_r2c(1, &this->nx, npy, d_rspace, embed, npy, 1, (fftw_complex*)z_auxr, - embed, npy, 1, flag); - this->planxc2r = fftw_plan_many_dft_c2r(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, d_rspace, - embed, npy, 1, flag); - } - else - { - this->planxfor1 = fftw_plan_many_dft(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, - (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planxbac1 = fftw_plan_many_dft(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, - (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); - } + this->planxfor1 = fftw_plan_many_dft(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, + (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); + this->planxbac1 = fftw_plan_many_dft(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, + (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); + } + } + else + { + this->planxfor1 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->lixy + 1), (fftw_complex*)z_auxr, embed, npy, + 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); + this->planxbac1 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->lixy + 1), (fftw_complex*)z_auxr, embed, npy, + 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); + if (this->gamma_only) + { + this->planyr2c = fftw_plan_many_dft_r2c(1, &this->ny, this->nplane, d_rspace, embed, this->nplane, 1, + (fftw_complex*)z_auxr, embed, this->nplane, 1, flag); + this->planyc2r = fftw_plan_many_dft_c2r(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, + this->nplane, 1, d_rspace, embed, this->nplane, 1, flag); } else { - this->planxfor1 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->lixy + 1), (fftw_complex*)z_auxr, embed, npy, - 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planxbac1 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->lixy + 1), (fftw_complex*)z_auxr, embed, npy, - 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); - if (this->gamma_only) - { - this->planyr2c = fftw_plan_many_dft_r2c(1, &this->ny, this->nplane, d_rspace, embed, this->nplane, 1, - (fftw_complex*)z_auxr, embed, this->nplane, 1, flag); - this->planyc2r = fftw_plan_many_dft_c2r(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, - this->nplane, 1, d_rspace, embed, this->nplane, 1, flag); - } - else - { - this->planxfor2 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->ny - this->rixy), (fftw_complex*)z_auxr, embed, - npy, 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planxbac2 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->ny - this->rixy), (fftw_complex*)z_auxr, embed, - npy, 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); - this->planyfor = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, this->nplane, - 1, (fftw_complex*)z_auxr, embed, this->nplane, 1, FFTW_FORWARD, flag); - this->planybac = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, this->nplane, - 1, (fftw_complex*)z_auxr, embed, this->nplane, 1, FFTW_BACKWARD, flag); - } + this->planxfor2 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->ny - this->rixy), (fftw_complex*)z_auxr, embed, + npy, 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); + this->planxbac2 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->ny - this->rixy), (fftw_complex*)z_auxr, embed, + npy, 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); + this->planyfor = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, this->nplane, + 1, (fftw_complex*)z_auxr, embed, this->nplane, 1, FFTW_FORWARD, flag); + this->planybac = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, this->nplane, + 1, (fftw_complex*)z_auxr, embed, this->nplane, 1, FFTW_BACKWARD, flag); } + } return; } diff --git a/source/module_basis/module_pw/pw_basis.cpp b/source/module_basis/module_pw/pw_basis.cpp index 4bca6baed3..ac02c45763 100644 --- a/source/module_basis/module_pw/pw_basis.cpp +++ b/source/module_basis/module_pw/pw_basis.cpp @@ -17,7 +17,7 @@ PW_Basis::PW_Basis(std::string device_, std::string precision_) : device(std::mo classname="PW_Basis"; this->ft.set_device(this->device); this->ft.set_precision(this->precision); - this->ft1.setfft("cpu",this->precision); + this->fft_bundle.setfft("cpu",this->precision); } PW_Basis:: ~PW_Basis() @@ -58,19 +58,19 @@ void PW_Basis::setuptransform() this->distribute_g(); this->getstartgr(); this->ft.clear(); - this->ft1.clear(); + this->fft_bundle.clear(); if(this->xprime) { this->ft.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); - this->ft1.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); } else { this->ft.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); - this->ft1.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); } this->ft.setupFFT(); - this->ft1.setupFFT(); + this->fft_bundle.setupFFT(); ModuleBase::timer::tick(this->classname, "setuptransform"); } diff --git a/source/module_basis/module_pw/pw_basis.h b/source/module_basis/module_pw/pw_basis.h index b5f2b58827..66f5ff6301 100644 --- a/source/module_basis/module_pw/pw_basis.h +++ b/source/module_basis/module_pw/pw_basis.h @@ -243,7 +243,7 @@ class PW_Basis int nmaxgr=0; // Gamma_only: max between npw and (nrxx+1)/2, others: max between npw and nrxx // Thus complex[nmaxgr] is able to contain either reciprocal or real data FFT ft; - FFT_Bundle ft1; + FFT_Bundle fft_bundle; //The position of pointer in and out can be equal(in-place transform) or different(out-of-place transform). template diff --git a/source/module_basis/module_pw/pw_basis_k.cpp b/source/module_basis/module_pw/pw_basis_k.cpp index 7f6413cf1f..3e6e6b4893 100644 --- a/source/module_basis/module_pw/pw_basis_k.cpp +++ b/source/module_basis/module_pw/pw_basis_k.cpp @@ -12,7 +12,7 @@ namespace ModulePW PW_Basis_K::PW_Basis_K() { classname="PW_Basis_K"; - this->ft1.setfft("cpu",this->precision); + this->fft_bundle.setfft("cpu",this->precision); } PW_Basis_K::~PW_Basis_K() { @@ -181,16 +181,16 @@ void PW_Basis_K::setuptransform() this->getstartgr(); this->setupIndGk(); this->ft.clear(); - this->ft1.clear(); + this->fft_bundle.clear(); if(this->xprime){ this->ft.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); - this->ft1.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); }else{ this->ft.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); - this->ft1.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); + this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); } this->ft.setupFFT(); - this->ft1.setupFFT(); + this->fft_bundle.setupFFT(); ModuleBase::timer::tick(this->classname, "setuptransform"); } diff --git a/source/module_basis/module_pw/pw_basis_sup.cpp b/source/module_basis/module_pw/pw_basis_sup.cpp index e174ce2ed4..80c7e87f57 100644 --- a/source/module_basis/module_pw/pw_basis_sup.cpp +++ b/source/module_basis/module_pw/pw_basis_sup.cpp @@ -20,7 +20,7 @@ void PW_Basis_Sup::setuptransform(const ModulePW::PW_Basis* pw_rho) this->distribute_g(pw_rho); this->getstartgr(); this->ft.clear(); - this->ft1.clear(); + this->fft_bundle.clear(); if (this->xprime) { this->ft.initfft(this->nx, @@ -33,7 +33,7 @@ void PW_Basis_Sup::setuptransform(const ModulePW::PW_Basis* pw_rho) this->poolnproc, this->gamma_only, this->xprime); - this->ft1.initfft(this->nx, + this->fft_bundle.initfft(this->nx, this->ny, this->nz, this->lix, @@ -56,7 +56,7 @@ void PW_Basis_Sup::setuptransform(const ModulePW::PW_Basis* pw_rho) this->poolnproc, this->gamma_only, this->xprime); - this->ft1.initfft(this->nx, + this->fft_bundle.initfft(this->nx, this->ny, this->nz, this->liy, @@ -68,7 +68,7 @@ void PW_Basis_Sup::setuptransform(const ModulePW::PW_Basis* pw_rho) this->xprime); } this->ft.setupFFT(); - this->ft1.setupFFT(); + this->fft_bundle.setupFFT(); ModuleBase::timer::tick(this->classname, "setuptransform"); } diff --git a/source/module_basis/module_pw/pw_transform.cpp b/source/module_basis/module_pw/pw_transform.cpp index c7590446a7..d8534c7f0a 100644 --- a/source/module_basis/module_pw/pw_transform.cpp +++ b/source/module_basis/module_pw/pw_transform.cpp @@ -29,13 +29,13 @@ void PW_Basis::real2recip(const std::complex* in, #endif for(int ir = 0 ; ir < this->nrxx ; ++ir) { - this->ft1.get_auxr_data()[ir] = in[ir]; + this->fft_bundle.get_auxr_data()[ir] = in[ir]; } - this->ft1.fftxyfor(ft1.get_auxr_data(),ft1.get_auxr_data()); + this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data(),fft_bundle.get_auxr_data()); - this->gatherp_scatters(this->ft1.get_auxr_data(), this->ft1.get_auxg_data()); + this->gatherp_scatters(this->fft_bundle.get_auxr_data(), this->fft_bundle.get_auxg_data()); - this->ft1.fftzfor(ft1.get_auxg_data(),ft1.get_auxg_data()); + this->fft_bundle.fftzfor(fft_bundle.get_auxg_data(),fft_bundle.get_auxg_data()); if(add) { @@ -45,7 +45,7 @@ void PW_Basis::real2recip(const std::complex* in, #endif for(int ig = 0 ; ig < this->npw ; ++ig) { - out[ig] += tmpfac * this->ft1.get_auxg_data()[this->ig2isz[ig]]; + out[ig] += tmpfac * this->fft_bundle.get_auxg_data()[this->ig2isz[ig]]; } } else @@ -56,7 +56,7 @@ void PW_Basis::real2recip(const std::complex* in, #endif for(int ig = 0 ; ig < this->npw ; ++ig) { - out[ig] = tmpfac * this->ft1.get_auxg_data()[this->ig2isz[ig]]; + out[ig] = tmpfac * this->fft_bundle.get_auxg_data()[this->ig2isz[ig]]; } } ModuleBase::timer::tick(this->classname, "real2recip"); @@ -83,11 +83,11 @@ void PW_Basis::real2recip(const FPTYPE* in, std::complex* out, const boo { for(int ipy = 0 ; ipy < npy ; ++ipy) { - this->ft1.get_rspace_data()[ix*npy + ipy] = in[ix*npy + ipy]; + this->fft_bundle.get_rspace_data()[ix*npy + ipy] = in[ix*npy + ipy]; } } - this->ft1.fftxyr2c(ft1.get_rspace_data(),ft1.get_auxr_data()); + this->fft_bundle.fftxyr2c(fft_bundle.get_rspace_data(),fft_bundle.get_auxr_data()); } else { @@ -96,13 +96,13 @@ void PW_Basis::real2recip(const FPTYPE* in, std::complex* out, const boo #endif for(int ir = 0 ; ir < this->nrxx ; ++ir) { - this->ft1.get_auxr_data()[ir] = std::complex(in[ir],0); + this->fft_bundle.get_auxr_data()[ir] = std::complex(in[ir],0); } - this->ft1.fftxyfor(ft1.get_auxr_data(),ft1.get_auxr_data()); + this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data(),fft_bundle.get_auxr_data()); } - this->gatherp_scatters(this->ft1.get_auxr_data(), this->ft1.get_auxg_data()); + this->gatherp_scatters(this->fft_bundle.get_auxr_data(), this->fft_bundle.get_auxg_data()); - this->ft1.fftzfor(ft1.get_auxg_data(),ft1.get_auxg_data()); + this->fft_bundle.fftzfor(fft_bundle.get_auxg_data(),fft_bundle.get_auxg_data()); if(add) { @@ -112,7 +112,7 @@ void PW_Basis::real2recip(const FPTYPE* in, std::complex* out, const boo #endif for(int ig = 0 ; ig < this->npw ; ++ig) { - out[ig] += tmpfac * this->ft1.get_auxg_data()[this->ig2isz[ig]]; + out[ig] += tmpfac * this->fft_bundle.get_auxg_data()[this->ig2isz[ig]]; } } else @@ -123,7 +123,7 @@ void PW_Basis::real2recip(const FPTYPE* in, std::complex* out, const boo #endif for(int ig = 0 ; ig < this->npw ; ++ig) { - out[ig] = tmpfac * this->ft1.get_auxg_data()[this->ig2isz[ig]]; + out[ig] = tmpfac * this->fft_bundle.get_auxg_data()[this->ig2isz[ig]]; } } ModuleBase::timer::tick(this->classname, "real2recip"); @@ -149,7 +149,7 @@ void PW_Basis::recip2real(const std::complex* in, #endif for(int i = 0 ; i < this->nst * this->nz ; ++i) { - ft1.get_auxg_data()[i] = std::complex(0, 0); + fft_bundle.get_auxg_data()[i] = std::complex(0, 0); } #ifdef _OPENMP @@ -157,13 +157,13 @@ void PW_Basis::recip2real(const std::complex* in, #endif for(int ig = 0 ; ig < this->npw ; ++ig) { - this->ft1.get_auxg_data()[this->ig2isz[ig]] = in[ig]; + this->fft_bundle.get_auxg_data()[this->ig2isz[ig]] = in[ig]; } - this->ft1.fftzbac(ft1.get_auxg_data(), ft1.get_auxg_data()); + this->fft_bundle.fftzbac(fft_bundle.get_auxg_data(), fft_bundle.get_auxg_data()); - this->gathers_scatterp(this->ft1.get_auxg_data(),this->ft1.get_auxr_data()); + this->gathers_scatterp(this->fft_bundle.get_auxg_data(),this->fft_bundle.get_auxr_data()); - this->ft1.fftxybac(ft1.get_auxr_data(),ft1.get_auxr_data()); + this->fft_bundle.fftxybac(fft_bundle.get_auxr_data(),fft_bundle.get_auxr_data()); if(add) { @@ -172,7 +172,7 @@ void PW_Basis::recip2real(const std::complex* in, #endif for(int ir = 0 ; ir < this->nrxx ; ++ir) { - out[ir] += factor * this->ft1.get_auxr_data()[ir]; + out[ir] += factor * this->fft_bundle.get_auxr_data()[ir]; } } else @@ -182,7 +182,7 @@ void PW_Basis::recip2real(const std::complex* in, #endif for(int ir = 0 ; ir < this->nrxx ; ++ir) { - out[ir] = this->ft1.get_auxr_data()[ir]; + out[ir] = this->fft_bundle.get_auxr_data()[ir]; } } ModuleBase::timer::tick(this->classname, "recip2real"); @@ -204,7 +204,7 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo #endif for(int i = 0 ; i < this->nst * this->nz ; ++i) { - ft1.get_auxg_data()[i] = std::complex(0, 0); + fft_bundle.get_auxg_data()[i] = std::complex(0, 0); } #ifdef _OPENMP @@ -212,15 +212,15 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo #endif for(int ig = 0 ; ig < this->npw ; ++ig) { - this->ft1.get_auxg_data()[this->ig2isz[ig]] = in[ig]; + this->fft_bundle.get_auxg_data()[this->ig2isz[ig]] = in[ig]; } - this->ft1.fftzbac(ft1.get_auxg_data(), ft1.get_auxg_data()); + this->fft_bundle.fftzbac(fft_bundle.get_auxg_data(), fft_bundle.get_auxg_data()); - this->gathers_scatterp(this->ft1.get_auxg_data(), this->ft1.get_auxr_data()); + this->gathers_scatterp(this->fft_bundle.get_auxg_data(), this->fft_bundle.get_auxr_data()); if(this->gamma_only) { - this->ft1.fftxyc2r(ft1.get_auxr_data(),ft1.get_rspace_data()); + this->fft_bundle.fftxyc2r(fft_bundle.get_auxr_data(),fft_bundle.get_rspace_data()); // r2c in place const int npy = this->ny * this->nplane; @@ -234,7 +234,7 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo { for(int ipy = 0 ; ipy < npy ; ++ipy) { - out[ix*npy + ipy] += factor * this->ft1.get_rspace_data()[ix*npy + ipy]; + out[ix*npy + ipy] += factor * this->fft_bundle.get_rspace_data()[ix*npy + ipy]; } } } @@ -247,14 +247,14 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo { for(int ipy = 0 ; ipy < npy ; ++ipy) { - out[ix*npy + ipy] = this->ft1.get_rspace_data()[ix*npy + ipy]; + out[ix*npy + ipy] = this->fft_bundle.get_rspace_data()[ix*npy + ipy]; } } } } else { - this->ft1.fftxybac(ft1.get_auxr_data(),ft1.get_auxr_data()); + this->fft_bundle.fftxybac(fft_bundle.get_auxr_data(),fft_bundle.get_auxr_data()); if(add) { #ifdef _OPENMP @@ -262,7 +262,7 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo #endif for(int ir = 0 ; ir < this->nrxx ; ++ir) { - out[ir] += factor * this->ft1.get_auxr_data()[ir].real(); + out[ir] += factor * this->fft_bundle.get_auxr_data()[ir].real(); } } else @@ -272,7 +272,7 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo #endif for(int ir = 0 ; ir < this->nrxx ; ++ir) { - out[ir] = this->ft1.get_auxr_data()[ir].real(); + out[ir] = this->fft_bundle.get_auxr_data()[ir].real(); } } } diff --git a/source/module_basis/module_pw/pw_transform_k.cpp b/source/module_basis/module_pw/pw_transform_k.cpp index 978819501d..88285df119 100644 --- a/source/module_basis/module_pw/pw_transform_k.cpp +++ b/source/module_basis/module_pw/pw_transform_k.cpp @@ -32,7 +32,7 @@ void PW_Basis_K::real2recip(const std::complex* in, ModuleBase::timer::tick(this->classname, "real2recip"); assert(this->gamma_only == false); - auto* auxr = this->ft1.get_auxr_data(); + auto* auxr = this->fft_bundle.get_auxr_data(); #ifdef _OPENMP #pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif @@ -40,15 +40,15 @@ void PW_Basis_K::real2recip(const std::complex* in, { auxr[ir] = in[ir]; } - this->ft1.fftxyfor(ft1.get_auxr_data(), ft1.get_auxr_data()); + this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data(), fft_bundle.get_auxr_data()); - this->gatherp_scatters(this->ft1.get_auxr_data(), this->ft1.get_auxg_data()); + this->gatherp_scatters(this->fft_bundle.get_auxr_data(), this->fft_bundle.get_auxg_data()); - this->ft1.fftzfor(ft1.get_auxg_data(), ft1.get_auxg_data()); + this->fft_bundle.fftzfor(fft_bundle.get_auxg_data(), fft_bundle.get_auxg_data()); const int startig = ik * this->npwk_max; const int npwk = this->npwk[ik]; - auto* auxg = this->ft1.get_auxg_data(); + auto* auxg = this->fft_bundle.get_auxg_data(); if (add) { FPTYPE tmpfac = factor / FPTYPE(this->nxyz); @@ -98,7 +98,7 @@ void PW_Basis_K::real2recip(const FPTYPE* in, assert(this->gamma_only == true); // for(int ir = 0 ; ir < this->nrxx ; ++ir) // { - // this->ft1.get_rspace_data()[ir] = in[ir]; + // this->fft_bundle.get_rspace_data()[ir] = in[ir]; // } // r2c in place const int npy = this->ny * this->nplane; @@ -109,19 +109,19 @@ void PW_Basis_K::real2recip(const FPTYPE* in, { for (int ipy = 0; ipy < npy; ++ipy) { - this->ft1.get_rspace_data()[ix * npy + ipy] = in[ix * npy + ipy]; + this->fft_bundle.get_rspace_data()[ix * npy + ipy] = in[ix * npy + ipy]; } } - this->ft1.fftxyr2c(ft1.get_rspace_data(), ft1.get_auxr_data()); + this->fft_bundle.fftxyr2c(fft_bundle.get_rspace_data(), fft_bundle.get_auxr_data()); - this->gatherp_scatters(this->ft1.get_auxr_data(), this->ft1.get_auxg_data()); + this->gatherp_scatters(this->fft_bundle.get_auxr_data(), this->fft_bundle.get_auxg_data()); - this->ft1.fftzfor(ft1.get_auxg_data(), ft1.get_auxg_data()); + this->fft_bundle.fftzfor(fft_bundle.get_auxg_data(), fft_bundle.get_auxg_data()); const int startig = ik * this->npwk_max; const int npwk = this->npwk[ik]; - auto* auxg = this->ft1.get_auxg_data(); + auto* auxg = this->fft_bundle.get_auxg_data(); if (add) { FPTYPE tmpfac = factor / FPTYPE(this->nxyz); @@ -170,11 +170,11 @@ void PW_Basis_K::recip2real(const std::complex* in, { ModuleBase::timer::tick(this->classname, "recip2real"); assert(this->gamma_only == false); - ModuleBase::GlobalFunc::ZEROS(ft1.get_auxg_data(), this->nst * this->nz); + ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxg_data(), this->nst * this->nz); const int startig = ik * this->npwk_max; const int npwk = this->npwk[ik]; - auto* auxg = this->ft1.get_auxg_data(); + auto* auxg = this->fft_bundle.get_auxg_data(); #ifdef _OPENMP #pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif @@ -182,13 +182,13 @@ void PW_Basis_K::recip2real(const std::complex* in, { auxg[this->igl2isz_k[igl + startig]] = in[igl]; } - this->ft1.fftzbac(ft1.get_auxg_data(), ft1.get_auxg_data()); + this->fft_bundle.fftzbac(fft_bundle.get_auxg_data(), fft_bundle.get_auxg_data()); - this->gathers_scatterp(this->ft1.get_auxg_data(), this->ft1.get_auxr_data()); + this->gathers_scatterp(this->fft_bundle.get_auxg_data(), this->fft_bundle.get_auxr_data()); - this->ft1.fftxybac(ft1.get_auxr_data(), ft1.get_auxr_data()); + this->fft_bundle.fftxybac(fft_bundle.get_auxr_data(), fft_bundle.get_auxr_data()); - auto* auxr = this->ft1.get_auxr_data(); + auto* auxr = this->fft_bundle.get_auxr_data(); if (add) { #ifdef _OPENMP @@ -234,11 +234,11 @@ void PW_Basis_K::recip2real(const std::complex* in, { ModuleBase::timer::tick(this->classname, "recip2real"); assert(this->gamma_only == true); - ModuleBase::GlobalFunc::ZEROS(ft1.get_auxg_data(), this->nst * this->nz); + ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxg_data(), this->nst * this->nz); const int startig = ik * this->npwk_max; const int npwk = this->npwk[ik]; - auto* auxg = this->ft1.get_auxg_data(); + auto* auxg = this->fft_bundle.get_auxg_data(); #ifdef _OPENMP #pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif @@ -246,20 +246,20 @@ void PW_Basis_K::recip2real(const std::complex* in, { auxg[this->igl2isz_k[igl + startig]] = in[igl]; } - this->ft1.fftzbac(ft1.get_auxg_data(), ft1.get_auxg_data()); + this->fft_bundle.fftzbac(fft_bundle.get_auxg_data(), fft_bundle.get_auxg_data()); - this->gathers_scatterp(this->ft1.get_auxg_data(), this->ft1.get_auxr_data()); + this->gathers_scatterp(this->fft_bundle.get_auxg_data(), this->fft_bundle.get_auxr_data()); - this->ft1.fftxyc2r(ft1.get_auxr_data(), ft1.get_rspace_data()); + this->fft_bundle.fftxyc2r(fft_bundle.get_auxr_data(), fft_bundle.get_rspace_data()); // for(int ir = 0 ; ir < this->nrxx ; ++ir) // { - // out[ir] = this->ft1.get_rspace_data()[ir] / this->nxyz; + // out[ir] = this->fft_bundle.get_rspace_data()[ir] / this->nxyz; // } // r2c in place const int npy = this->ny * this->nplane; - auto* rspace = this->ft1.get_rspace_data(); + auto* rspace = this->fft_bundle.get_rspace_data(); if (add) { #ifdef _OPENMP diff --git a/source/module_elecstate/module_charge/charge.cpp b/source/module_elecstate/module_charge/charge.cpp index dec33f0418..6a2405b579 100644 --- a/source/module_elecstate/module_charge/charge.cpp +++ b/source/module_elecstate/module_charge/charge.cpp @@ -644,10 +644,10 @@ void Charge::atomic_rho(const int spin_number_need, double sumrea = 0.0; for (int ir = 0; ir < this->rhopw->nrxx; ir++) { - rea = this->rhopw->ft1.get_auxr_data()[ir].real(); + rea = this->rhopw->fft_bundle.get_auxr_data()[ir].real(); sumrea += rea; neg += std::min(0.0, rea); - ima += std::abs(this->rhopw->ft1.get_auxr_data()[ir].imag()); + ima += std::abs(this->rhopw->fft_bundle.get_auxr_data()[ir].imag()); } #ifdef __MPI diff --git a/source/module_elecstate/module_charge/charge_init.cpp b/source/module_elecstate/module_charge/charge_init.cpp index 9efca214f9..57af45e3be 100644 --- a/source/module_elecstate/module_charge/charge_init.cpp +++ b/source/module_elecstate/module_charge/charge_init.cpp @@ -260,8 +260,8 @@ void Charge::set_rho_core( double rhoneg = 0.0; for (int ir = 0; ir < this->rhopw->nrxx; ir++) { - rhoneg += std::min(0.0, this->rhopw->ft1.get_auxr_data()[ir].real()); - rhoima += std::abs(this->rhopw->ft1.get_auxr_data()[ir].imag()); + rhoneg += std::min(0.0, this->rhopw->fft_bundle.get_auxr_data()[ir].real()); + rhoima += std::abs(this->rhopw->fft_bundle.get_auxr_data()[ir].imag()); // NOTE: Core charge is computed in reciprocal space and brought to real // space by FFT. For non smooth core charges (or insufficient cut-off) // this may result in negative values in some grid points. diff --git a/source/module_esolver/esolver_fp.cpp b/source/module_esolver/esolver_fp.cpp index 421f2e8dff..4a19c1d917 100644 --- a/source/module_esolver/esolver_fp.cpp +++ b/source/module_esolver/esolver_fp.cpp @@ -84,7 +84,7 @@ void ESolver_FP::before_all_runners(const Input_para& inp, UnitCell& cell) this->pw_rho->initparameters(false, 4.0 * inp.ecutwfc); this->pw_rho->ft.fft_mode = inp.fft_mode; - this->pw_rho->ft1.initfftmode(inp.fft_mode); + this->pw_rho->fft_bundle.initfftmode(inp.fft_mode); this->pw_rho->setuptransform(); this->pw_rho->collect_local_pw(); this->pw_rho->collect_uniqgg(); @@ -110,7 +110,7 @@ void ESolver_FP::before_all_runners(const Input_para& inp, UnitCell& cell) } this->pw_rhod->initparameters(false, inp.ecutrho); this->pw_rhod->ft.fft_mode = inp.fft_mode; - this->pw_rhod->ft1.initfftmode(inp.fft_mode); + this->pw_rhod->fft_bundle.initfftmode(inp.fft_mode); pw_rhod_sup->setuptransform(this->pw_rho); this->pw_rhod->collect_local_pw(); this->pw_rhod->collect_uniqgg(); diff --git a/source/module_esolver/esolver_ks.cpp b/source/module_esolver/esolver_ks.cpp index e2e8e48c42..4c5c8353dd 100644 --- a/source/module_esolver/esolver_ks.cpp +++ b/source/module_esolver/esolver_ks.cpp @@ -247,7 +247,7 @@ void ESolver_KS::before_all_runners(const Input_para& inp, UnitCell& #endif this->pw_wfc->ft.fft_mode = inp.fft_mode; - this->pw_wfc->ft1.initfftmode(inp.fft_mode); + this->pw_wfc->fft_bundle.initfftmode(inp.fft_mode); this->pw_wfc->setuptransform(); //! 9) initialize the number of plane waves for each k point From f07a97da1b2c131be08d5af75282a734826e57ce Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Fri, 8 Nov 2024 17:32:06 +0800 Subject: [PATCH 23/27] modify the Makefile --- source/module_basis/module_pw/test/Makefile | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/source/module_basis/module_pw/test/Makefile b/source/module_basis/module_pw/test/Makefile index b3b4fea709..884f0f74c0 100644 --- a/source/module_basis/module_pw/test/Makefile +++ b/source/module_basis/module_pw/test/Makefile @@ -94,7 +94,7 @@ endif ##========================== ## GTEST ##========================== -GTESTOPTS = -I/usr/local/gtest/include -L/home/ubuntu/desktop/github/googletest/lib -lgtest -lpthread +GTESTOPTS = -I${GTEST_DIR}/include -L${GTEST_DIR}/lib -lgtest -lpthread @@ -105,7 +105,8 @@ VPATH=../../../module_base\ ../../../module_base/module_device\ ../../../module_base/module_container/ATen/core\ ../../../module_base/module_container/ATen\ -:../ +../../../module_parameter\ +../\ MATH_OBJS0=matrix.o\ matrix3.o\ @@ -126,7 +127,11 @@ pw_transform_k.o\ memory.o\ memory_op.o\ depend_mock.o\ -tensor.o\ +parameter.o\ +fft_base.o\ +fft_bundle.o\ +fft_cpu.o\ + OTHER_OBJS0= From 9c1a22dc733ebc775d57a8b1b850e66418bf320a Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Mon, 11 Nov 2024 17:53:06 +0800 Subject: [PATCH 24/27] update the file --- .../module_pw/module_fft/fft_base.cpp | 24 +- .../module_pw/module_fft/fft_base.h | 8 +- .../module_pw/module_fft/fft_bundle.cpp | 137 ++++---- .../module_pw/module_fft/fft_bundle.h | 4 +- .../module_pw/module_fft/fft_cpu.cpp | 292 +++++++++++++----- .../module_pw/module_fft/fft_cpu.h | 7 +- .../module_pw/module_fft/fft_cpu_float.cpp | 241 ++++++++++++--- 7 files changed, 515 insertions(+), 198 deletions(-) diff --git a/source/module_basis/module_pw/module_fft/fft_base.cpp b/source/module_basis/module_pw/module_fft/fft_base.cpp index 8cda835f24..d87561485b 100644 --- a/source/module_basis/module_pw/module_fft/fft_base.cpp +++ b/source/module_basis/module_pw/module_fft/fft_base.cpp @@ -1,17 +1,17 @@ #include "fft_base.h" namespace ModulePW { -template -FFT_BASE::FFT_BASE() -{ -} -template -FFT_BASE::~FFT_BASE() -{ -} +// template +// FFT_BASE::FFT_BASE() +// { +// } +// template +// FFT_BASE::~FFT_BASE() +// { +// } -template FFT_BASE::FFT_BASE(); -template FFT_BASE::FFT_BASE(); -template FFT_BASE::~FFT_BASE(); -template FFT_BASE::~FFT_BASE(); +// template FFT_BASE::FFT_BASE(); +// template FFT_BASE::FFT_BASE(); +// template FFT_BASE::~FFT_BASE(); +// template FFT_BASE::~FFT_BASE(); } \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_base.h b/source/module_basis/module_pw/module_fft/fft_base.h index 3997079898..529e7838a8 100644 --- a/source/module_basis/module_pw/module_fft/fft_base.h +++ b/source/module_basis/module_pw/module_fft/fft_base.h @@ -10,8 +10,8 @@ class FFT_BASE { public: - FFT_BASE(); - virtual ~FFT_BASE(); + FFT_BASE(){}; + virtual ~FFT_BASE(){}; /** * @brief Initialize the fft parameters As virtual function. @@ -159,5 +159,9 @@ class FFT_BASE int ny=0; int nz=0; }; +template FFT_BASE::FFT_BASE(); +template FFT_BASE::FFT_BASE(); +template FFT_BASE::~FFT_BASE(); +template FFT_BASE::~FFT_BASE(); } #endif // FFT_BASE_H diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index 7c084c7be9..6d71df1c21 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -79,11 +79,29 @@ void FFT_Bundle::initfft(int nx_in, fft_double = make_unique>(this->fft_mode); if (float_flag) { - fft_float->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in); + fft_float->initfft(nx_in, + ny_in, + nz_in, + lixy_in, + rixy_in, + ns_in, + nplane_in, + nproc_in, + gamma_only_in, + xprime_in); } if (double_flag) { - fft_double->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in); + fft_double->initfft(nx_in, + ny_in, + nz_in, + lixy_in, + rixy_in, + ns_in, + nplane_in, + nproc_in, + gamma_only_in, + xprime_in); } } if (device=="gpu") @@ -138,133 +156,134 @@ void FFT_Bundle::clear() fft_double->clear(); } } -// access the real space data -template <> -float* FFT_Bundle::get_rspace_data() const -{ - return fft_float->get_rspace_data(); -} + template <> -double* FFT_Bundle::get_rspace_data() const -{ - return fft_double->get_rspace_data(); -} -template <> -std::complex* FFT_Bundle::get_auxr_data() const -{ - return fft_float->get_auxr_data(); -} -template <> -std::complex* FFT_Bundle::get_auxr_data() const -{ - return fft_double->get_auxr_data(); -} -template <> -std::complex* FFT_Bundle::get_auxg_data() const -{ - return fft_float->get_auxg_data(); -} -template <> -std::complex* FFT_Bundle::get_auxg_data() const -{ - return fft_double->get_auxg_data(); -} -template <> -std::complex* FFT_Bundle::get_auxr_3d_data() const -{ - return fft_float->get_auxr_3d_data(); -} -template <> -std::complex* FFT_Bundle::get_auxr_3d_data() const -{ - return fft_double->get_auxr_3d_data(); -} -template <> -void FFT_Bundle::fftxyfor(std::complex* in, std::complex* out) const +void FFT_Bundle::fftxyfor(std::complex* in, + std::complex* out) const { fft_float->fftxyfor(in,out); } template <> -void FFT_Bundle::fftxyfor(std::complex* in, std::complex* out) const +void FFT_Bundle::fftxyfor(std::complex* in, + std::complex* out) const { fft_double->fftxyfor(in,out); } template <> -void FFT_Bundle::fftzfor(std::complex* in, std::complex* out) const +void FFT_Bundle::fftzfor(std::complex* in, + std::complex* out) const { fft_float->fftzfor(in,out); } template <> -void FFT_Bundle::fftzfor(std::complex* in, std::complex* out) const +void FFT_Bundle::fftzfor(std::complex* in, + std::complex* out) const { fft_double->fftzfor(in,out); } template <> -void FFT_Bundle::fftxybac(std::complex* in, std::complex* out) const +void FFT_Bundle::fftxybac(std::complex* in, + std::complex* out) const { fft_float->fftxybac(in,out); } template <> -void FFT_Bundle::fftxybac(std::complex* in, std::complex* out) const +void FFT_Bundle::fftxybac(std::complex* in, + std::complex* out) const { fft_double->fftxybac(in,out); } template <> -void FFT_Bundle::fftzbac(std::complex* in, std::complex* out) const +void FFT_Bundle::fftzbac(std::complex* in, + std::complex* out) const { fft_float->fftzbac(in,out); } template <> -void FFT_Bundle::fftzbac(std::complex* in, std::complex* out) const +void FFT_Bundle::fftzbac(std::complex* in, + std::complex* out) const { fft_double->fftzbac(in,out); } template <> -void FFT_Bundle::fftxyr2c(float* in, std::complex* out) const +void FFT_Bundle::fftxyr2c(float* in, + std::complex* out) const { fft_float->fftxyr2c(in,out); } template <> -void FFT_Bundle::fftxyr2c(double* in, std::complex* out) const +void FFT_Bundle::fftxyr2c(double* in, + std::complex* out) const { fft_double->fftxyr2c(in,out); } template <> -void FFT_Bundle::fftxyc2r(std::complex* in, float* out) const +void FFT_Bundle::fftxyc2r(std::complex* in, + float* out) const { fft_float->fftxyc2r(in,out); } template <> -void FFT_Bundle::fftxyc2r(std::complex* in, double* out) const +void FFT_Bundle::fftxyc2r(std::complex* in, + double* out) const { fft_double->fftxyc2r(in,out); } template <> -void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, + std::complex* in, + std::complex* out) const { fft_float->fft3D_forward(in, out); } template <> -void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, + std::complex* in, + std::complex* out) const { fft_double->fft3D_forward(in, out); } template <> -void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, + std::complex* in, + std::complex* out) const { fft_float->fft3D_backward(in, out); } template <> -void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, std::complex* out) const +void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, + std::complex* in, + std::complex* out) const { fft_double->fft3D_backward(in, out); } + +// access the real space data +template <> float* +FFT_Bundle::get_rspace_data() const {return fft_float->get_rspace_data();} +template <> double* +FFT_Bundle::get_rspace_data() const {return fft_double->get_rspace_data();} + +template <> std::complex* +FFT_Bundle::get_auxr_data() const {return fft_float->get_auxr_data();} +template <> std::complex* +FFT_Bundle::get_auxr_data() const{return fft_double->get_auxr_data();} + +template <> std::complex* +FFT_Bundle::get_auxg_data() const{return fft_float->get_auxg_data();} +template <> std::complex* +FFT_Bundle::get_auxg_data() const{return fft_double->get_auxg_data();} + +template <> std::complex* +FFT_Bundle::get_auxr_3d_data() const{return fft_float->get_auxr_3d_data();} +template <> std::complex* +FFT_Bundle::get_auxr_3d_data() const {return fft_double->get_auxr_3d_data();} } \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.h b/source/module_basis/module_pw/module_fft/fft_bundle.h index 43ae8cdd16..ebeb04f5b8 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.h +++ b/source/module_basis/module_pw/module_fft/fft_bundle.h @@ -206,8 +206,8 @@ class FFT_Bundle bool float_flag=false; bool float_define=true; bool double_flag=false; - std::shared_ptr> fft_float=nullptr; - std::shared_ptr> fft_double=nullptr; + std::unique_ptr> fft_float=nullptr; + std::unique_ptr> fft_double=nullptr; std::string device = "cpu"; std::string precision = "double"; diff --git a/source/module_basis/module_pw/module_fft/fft_cpu.cpp b/source/module_basis/module_pw/module_fft/fft_cpu.cpp index eb8d3d8f1a..f06a235a2b 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cpu.cpp @@ -2,32 +2,6 @@ #include "fftw3.h" namespace ModulePW { -template <> -FFT_CPU::FFT_CPU() -{ -} -template <> -FFT_CPU::~FFT_CPU() -{ -} -template <> -FFT_CPU::FFT_CPU() -{ -} -template <> -FFT_CPU::~FFT_CPU() -{ -} -template <> -FFT_CPU::FFT_CPU(const int fft_mode_in) -{ - this->fft_mode = fft_mode_in; -} -template <> -FFT_CPU::FFT_CPU(const int fft_mode_in) -{ - this->fft_mode = fft_mode_in; -} template void FFT_CPU::initfft(int nx_in, @@ -90,11 +64,33 @@ void FFT_CPU::setupFFT() z_auxg = (std::complex*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids); z_auxr = (std::complex*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids); d_rspace = (double*)z_auxg; - this->planzfor = fftw_plan_many_dft(1, &this->nz, this->ns, (fftw_complex*)z_auxg, &this->nz, 1, this->nz, - (fftw_complex*)z_auxg, &this->nz, 1, this->nz, FFTW_FORWARD, flag); + this->planzfor = fftw_plan_many_dft(1, + &this->nz, + this->ns, + (fftw_complex*)z_auxg, + &this->nz, + 1, + this->nz, + (fftw_complex*)z_auxg, + &this->nz, + 1, + this->nz, + FFTW_FORWARD, + flag); - this->planzbac = fftw_plan_many_dft(1, &this->nz, this->ns, (fftw_complex*)z_auxg, &this->nz, 1, this->nz, - (fftw_complex*)z_auxg, &this->nz, 1, this->nz, FFTW_BACKWARD, flag); + this->planzbac = fftw_plan_many_dft(1, + &this->nz, + this->ns, + (fftw_complex*)z_auxg, + &this->nz, + 1, + this->nz, + (fftw_complex*)z_auxg, + &this->nz, + 1, + this->nz, + FFTW_BACKWARD, + flag); //--------------------------------------------------------- // 2 D - XY @@ -105,49 +101,192 @@ void FFT_CPU::setupFFT() int npy = this->nplane * this->ny; if (this->xprime) { - this->planyfor = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed,this->nplane, 1, - (fftw_complex*)z_auxr, embed,this->nplane, 1, FFTW_FORWARD, flag); - this->planybac = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed,this->nplane, 1, - (fftw_complex*)z_auxr, embed,this->nplane, 1, FFTW_BACKWARD, flag); + this->planyfor = fftw_plan_many_dft(1, + &this->ny, + this->nplane, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + FFTW_FORWARD, + flag); + this->planybac = fftw_plan_many_dft(1, + &this->ny, + this->nplane, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + FFTW_BACKWARD, + flag); if (this->gamma_only) { - this->planxr2c = fftw_plan_many_dft_r2c(1, &this->nx, npy, d_rspace, embed, npy, 1, (fftw_complex*)z_auxr, - embed, npy, 1, flag); - this->planxc2r = fftw_plan_many_dft_c2r(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, d_rspace, - embed, npy, 1, flag); + this->planxr2c = fftw_plan_many_dft_r2c(1, + &this->nx, + npy, + d_rspace, + embed, + npy, + 1, + (fftw_complex*)z_auxr, + embed, + npy, + 1, + flag); + this->planxc2r = fftw_plan_many_dft_c2r(1, + &this->nx, + npy, + (fftw_complex*)z_auxr, + embed, + npy, + 1, + d_rspace, + embed, + npy, + 1, + flag); } else { - this->planxfor1 = fftw_plan_many_dft(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, - (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planxbac1 = fftw_plan_many_dft(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, - (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); + this->planxfor1 = fftw_plan_many_dft(1, + &this->nx, + npy, + (fftw_complex*)z_auxr, + embed, + npy, + 1, + (fftw_complex*)z_auxr, + embed, + npy, + 1, + FFTW_FORWARD, + flag); + this->planxbac1 = fftw_plan_many_dft(1, + &this->nx, + npy, + (fftw_complex*)z_auxr, + embed, + npy, + 1, + (fftw_complex*)z_auxr, + embed, + npy, + 1, + FFTW_BACKWARD, + flag); } } else { - this->planxfor1 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->lixy + 1), (fftw_complex*)z_auxr, embed, npy, - 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planxbac1 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->lixy + 1), (fftw_complex*)z_auxr, embed, npy, - 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); + this->planxfor1 = fftw_plan_many_dft(1, + &this->nx, + this->nplane * (this->lixy + 1), + (fftw_complex*)z_auxr, + embed, + npy, + 1, + (fftw_complex*)z_auxr, + embed, + npy, + 1, + FFTW_FORWARD, + flag); + this->planxbac1 = fftw_plan_many_dft(1, + &this->nx, + this->nplane * (this->lixy + 1), + (fftw_complex*)z_auxr, + embed, + npy, + 1, + (fftw_complex*)z_auxr, + embed, + npy, + 1, + FFTW_BACKWARD, + flag); if (this->gamma_only) { - this->planyr2c = fftw_plan_many_dft_r2c(1, &this->ny, this->nplane, d_rspace, embed, this->nplane, 1, - (fftw_complex*)z_auxr, embed, this->nplane, 1, flag); - this->planyc2r = fftw_plan_many_dft_c2r(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, - this->nplane, 1, d_rspace, embed, this->nplane, 1, flag); + this->planyr2c = fftw_plan_many_dft_r2c(1, + &this->ny, + this->nplane, + d_rspace, + embed, + this->nplane, + 1, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + flag); + this->planyc2r = fftw_plan_many_dft_c2r(1, + &this->ny, + this->nplane, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + d_rspace, + embed, + this->nplane, + 1, + flag); } else { - - this->planxfor2 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->ny - this->rixy), (fftw_complex*)z_auxr, embed, - npy, 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planxbac2 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->ny - this->rixy), (fftw_complex*)z_auxr, embed, - npy, 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag); - this->planyfor = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, this->nplane, - 1, (fftw_complex*)z_auxr, embed, this->nplane, 1, FFTW_FORWARD, flag); - this->planybac = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, this->nplane, - 1, (fftw_complex*)z_auxr, embed, this->nplane, 1, FFTW_BACKWARD, flag); + this->planxfor2 = fftw_plan_many_dft(1, + &this->nx, + this->nplane * (this->ny - this->rixy), + (fftw_complex*)z_auxr, + embed, + npy, + 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); + this->planxbac2 = fftw_plan_many_dft(1, + &this->nx, + this->nplane * (this->ny - this->rixy), + (fftw_complex*)z_auxr, + embed, + npy, + 1, + (fftw_complex*)z_auxr, + embed, + npy, + 1, + FFTW_BACKWARD, + flag); + this->planyfor = fftw_plan_many_dft(1, + &this->ny, + this->nplane, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + FFTW_FORWARD, + flag); + this->planybac = fftw_plan_many_dft(1, + &this->ny, + this->nplane, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + (fftw_complex*)z_auxr, + embed, + this->nplane, + 1, + FFTW_BACKWARD, + flag); } } return; @@ -159,7 +298,7 @@ void FFT_CPU::clearfft(fftw_plan& plan) if (plan) { fftw_destroy_plan(plan); - plan = NULL; + plan = nullptr; } } @@ -198,21 +337,6 @@ void FFT_CPU::clear() d_rspace = nullptr; } -template <> -double* FFT_CPU::get_rspace_data() const -{ - return d_rspace; -} -template <> -std::complex* FFT_CPU::get_auxr_data() const -{ - return z_auxr; -} -template <> -std::complex* FFT_CPU::get_auxg_data() const -{ - return z_auxg; -} template <> void FFT_CPU::fftxyfor(std::complex* in, std::complex* out) const { @@ -220,7 +344,6 @@ void FFT_CPU::fftxyfor(std::complex* in, std::complex* o if (this->xprime) { fftw_execute_dft(this->planxfor1, (fftw_complex*)in, (fftw_complex*)out); - for (int i = 0; i < this->lixy + 1; ++i) { fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); @@ -236,11 +359,11 @@ void FFT_CPU::fftxyfor(std::complex* in, std::complex* o { fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); } - fftw_execute_dft(this->planxfor1, (fftw_complex*)in, (fftw_complex*)out); fftw_execute_dft(this->planxfor2, (fftw_complex*)&in[rixy * nplane], (fftw_complex*)&out[rixy * nplane]); } } + template <> void FFT_CPU::fftxybac(std::complex* in,std::complex* out) const { @@ -261,23 +384,25 @@ void FFT_CPU::fftxybac(std::complex* in,std::complex* ou { fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)out); fftw_execute_dft(this->planxbac2, (fftw_complex*)&in[rixy * nplane], (fftw_complex*)&out[rixy * nplane]); - for (int i = 0; i < this->nx; ++i) { fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]); } } } + template <> void FFT_CPU::fftzfor(std::complex* in, std::complex* out) const { fftw_execute_dft(this->planzfor, (fftw_complex*)in, (fftw_complex*)out); } + template <> void FFT_CPU::fftzbac(std::complex* in, std::complex* out) const { fftw_execute_dft(this->planzbac, (fftw_complex*)in, (fftw_complex*)out); } + template <> void FFT_CPU::fftxyr2c(double* in, std::complex* out) const { @@ -285,7 +410,6 @@ void FFT_CPU::fftxyr2c(double* in, std::complex* out) const if (this->xprime) { fftw_execute_dft_r2c(this->planxr2c, in, (fftw_complex*)out); - for (int i = 0; i < this->lixy + 1; ++i) { fftw_execute_dft(this->planyfor, (fftw_complex*)&out[i * npy], (fftw_complex*)&out[i * npy]); @@ -297,7 +421,6 @@ void FFT_CPU::fftxyr2c(double* in, std::complex* out) const { fftw_execute_dft_r2c(this->planyr2c, &in[i * npy], (fftw_complex*)&out[i * npy]); } - fftw_execute_dft(this->planxfor1, (fftw_complex*)out, (fftw_complex*)out); } } @@ -312,13 +435,11 @@ void FFT_CPU::fftxyc2r(std::complex *in,double *out) const { fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&in[i * npy]); } - fftw_execute_dft_c2r(this->planxc2r, (fftw_complex*)in, out); } else { fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)in); - for (int i = 0; i < this->nx; ++i) { fftw_execute_dft_c2r(this->planyc2r, (fftw_complex*)&in[i * npy], &out[i * npy]); @@ -326,6 +447,13 @@ void FFT_CPU::fftxyc2r(std::complex *in,double *out) const } } +template <> double* +FFT_CPU::get_rspace_data() const {return d_rspace;} +template <> std::complex* +FFT_CPU::get_auxr_data() const {return z_auxr;} +template <> std::complex* +FFT_CPU::get_auxg_data() const {return z_auxg;} + template FFT_CPU::FFT_CPU(); template FFT_CPU::~FFT_CPU(); template FFT_CPU::FFT_CPU(); diff --git a/source/module_basis/module_pw/module_fft/fft_cpu.h b/source/module_basis/module_pw/module_fft/fft_cpu.h index 1db61bf841..888880b386 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu.h +++ b/source/module_basis/module_pw/module_fft/fft_cpu.h @@ -13,9 +13,9 @@ template class FFT_CPU : public FFT_BASE { public: - FFT_CPU(); - FFT_CPU(const int fft_mode_in); - ~FFT_CPU(); + FFT_CPU(){}; + FFT_CPU(const int fft_mode_in){this->fft_mode = fft_mode_in;}; + ~FFT_CPU(){}; /** * @brief Initialize the fft parameters. @@ -158,7 +158,6 @@ class FFT_CPU : public FFT_BASE * @brief rixy: the right edge of the pw ball in the x or y direction */ int rixy=0; - /** * @brief xprime: whether xprime is used,when do recip2real, x-fft will * be done last and when doing real2recip, x-fft will be done first; diff --git a/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp b/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp index 627a2aa2f8..358e2dfa88 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp @@ -30,16 +30,39 @@ void FFT_CPU::setupFFT() // 1 D //--------------------------------------------------------- - // fftw_plan_many_dft(int rank, const int *n, int howmany, + // fftw_plan_many_dft(int rank, + // const int *n, int howmany, // fftw_complex *in, const int *inembed, int istride, int idist, // fftw_complex *out, const int *onembed, int ostride, int odist, int sign, unsigned //flags); - this->planfzfor = fftwf_plan_many_dft(1, &this->nz, this->ns, (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, - (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, FFTW_FORWARD, flag); + this->planfzfor = fftwf_plan_many_dft(1, + &this->nz, + this->ns, + (fftwf_complex*)c_auxg, + &this->nz, + 1, + this->nz, + (fftwf_complex*)c_auxg, + &this->nz, + 1, + this->nz, + FFTW_FORWARD, + flag); - this->planfzbac = fftwf_plan_many_dft(1, &this->nz, this->ns, (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, - (fftwf_complex*)c_auxg, &this->nz, 1, this->nz, FFTW_BACKWARD, flag); + this->planfzbac = fftwf_plan_many_dft(1, + &this->nz, + this->ns, + (fftwf_complex*)c_auxg, + &this->nz, + 1, + this->nz, + (fftwf_complex*)c_auxg, + &this->nz, + 1, + this->nz, + FFTW_BACKWARD, + flag); //--------------------------------------------------------- // 2 D //--------------------------------------------------------- @@ -48,52 +71,196 @@ void FFT_CPU::setupFFT() int npy = this->nplane * this->ny; if (this->xprime) { - this->planfyfor = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, nplane, 1, - (fftwf_complex*)c_auxr, embed, nplane, 1, FFTW_FORWARD, flag); - this->planfybac = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, nplane, 1, - (fftwf_complex*)c_auxr, embed, nplane, 1, FFTW_BACKWARD, flag); + this->planfyfor = fftwf_plan_many_dft(1, + &this->ny, + this->nplane, + (fftwf_complex*)c_auxr, + embed, + nplane, + 1, + (fftwf_complex*)c_auxr, + embed, + nplane, + 1, + FFTW_FORWARD, + flag); + this->planfybac = fftwf_plan_many_dft(1, + &this->ny, + this->nplane, + (fftwf_complex*)c_auxr, + embed, + nplane, + 1, + (fftwf_complex*)c_auxr, + embed, nplane, + 1, + FFTW_BACKWARD, + flag); if (this->gamma_only) { - this->planfxr2c = fftwf_plan_many_dft_r2c(1, &this->nx, npy, s_rspace, embed, npy, 1, - (fftwf_complex*)c_auxr, embed, npy, 1, flag); - this->planfxc2r = fftwf_plan_many_dft_c2r(1, &this->nx, npy, (fftwf_complex*)c_auxr, embed, npy, 1, - s_rspace, embed, npy, 1, flag); + this->planfxr2c = fftwf_plan_many_dft_r2c(1, + &this->nx, + npy, + s_rspace, + embed, + npy, + 1, + (fftwf_complex*)c_auxr, + embed, npy, + 1, + flag); + this->planfxc2r = fftwf_plan_many_dft_c2r(1, + &this->nx, + npy, + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + s_rspace, + embed, + npy, + 1, + flag); } else { - this->planfxfor1 = fftwf_plan_many_dft(1, &this->nx, npy, (fftwf_complex*)c_auxr, embed, npy, 1, - (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planfxbac1 = fftwf_plan_many_dft(1, &this->nx, npy, (fftwf_complex*)c_auxr, embed, npy, 1, - (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_BACKWARD, flag); + this->planfxfor1 = fftwf_plan_many_dft(1, + &this->nx, + npy, + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + FFTW_FORWARD, + flag); + this->planfxbac1 = fftwf_plan_many_dft(1, + &this->nx, + npy, + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + FFTW_BACKWARD, + flag); } } else { - this->planfxfor1 = fftwf_plan_many_dft(1, &this->nx, this->nplane * (lixy + 1), (fftwf_complex*)c_auxr, embed, - npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planfxbac1 = fftwf_plan_many_dft(1, &this->nx, this->nplane * (lixy + 1), (fftwf_complex*)c_auxr, embed, - npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_BACKWARD, flag); + this->planfxfor1 = fftwf_plan_many_dft(1, + &this->nx, + this->nplane * (lixy + 1), + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + FFTW_FORWARD, + flag); + this->planfxbac1 = fftwf_plan_many_dft(1, + &this->nx, + this->nplane * (lixy + 1), + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + FFTW_BACKWARD, + flag); if (this->gamma_only) { - this->planfyr2c = fftwf_plan_many_dft_r2c(1, &this->ny, this->nplane, s_rspace, embed, this->nplane, 1, - (fftwf_complex*)c_auxr, embed, this->nplane, 1, flag); - this->planfyc2r = fftwf_plan_many_dft_c2r(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, - this->nplane, 1, s_rspace, embed, this->nplane, 1, flag); + this->planfyr2c = fftwf_plan_many_dft_r2c(1, + &this->ny, + this->nplane, + s_rspace, + embed, + this->nplane, + 1, + (fftwf_complex*)c_auxr, + embed, + this->nplane, + 1, + flag); + this->planfyc2r = fftwf_plan_many_dft_c2r(1, + &this->ny, + this->nplane, + (fftwf_complex*)c_auxr, + embed, + this->nplane, + 1, + s_rspace, + embed, + this->nplane, + 1, + flag); } else { - this->planfxfor2 - = fftwf_plan_many_dft(1, &this->nx, this->nplane * (this->ny - rixy), (fftwf_complex*)c_auxr, embed, - npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_FORWARD, flag); - this->planfxbac2 - = fftwf_plan_many_dft(1, &this->nx, this->nplane * (this->ny - rixy), (fftwf_complex*)c_auxr, embed, - npy, 1, (fftwf_complex*)c_auxr, embed, npy, 1, FFTW_BACKWARD, flag); - this->planfyfor - = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, this->nplane, 1, - (fftwf_complex*)c_auxr, embed, this->nplane, 1, FFTW_FORWARD, flag); - this->planfybac - = fftwf_plan_many_dft(1, &this->ny, this->nplane, (fftwf_complex*)c_auxr, embed, this->nplane, 1, - (fftwf_complex*)c_auxr, embed, this->nplane, 1, FFTW_BACKWARD, flag); + this->planfxfor2 = fftwf_plan_many_dft(1, + &this->nx, + this->nplane * (this->ny - rixy), + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + FFTW_FORWARD, + flag); + this->planfxbac2 = fftwf_plan_many_dft(1, + &this->nx, + this->nplane * (this->ny - rixy), + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + (fftwf_complex*)c_auxr, + embed, + npy, + 1, + FFTW_BACKWARD, + flag); + this->planfyfor = fftwf_plan_many_dft(1, + &this->ny, + this->nplane, + (fftwf_complex*)c_auxr, + embed, + this->nplane, + 1, + (fftwf_complex*)c_auxr, + embed, + this->nplane, + 1, + FFTW_FORWARD, + flag); + this->planfybac = fftwf_plan_many_dft(1, + &this->ny, + this->nplane, + (fftwf_complex*)c_auxr, + embed, + this->nplane, + 1, + (fftwf_complex*)c_auxr, + embed, + this->nplane, + 1, + FFTW_BACKWARD, + flag); } } return; From 90294534cff3beed6b118c50be1bad41a9d77106 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Mon, 11 Nov 2024 18:19:33 +0800 Subject: [PATCH 25/27] update the format --- .../module_pw/module_fft/fft_base.cpp | 17 +- .../module_pw/module_fft/fft_base.h | 4 - .../module_pw/module_fft/fft_bundle.cpp | 261 ++++++------------ .../module_pw/module_fft/fft_bundle.h | 16 +- .../module_pw/module_fft/fft_cpu.cpp | 1 - .../module_pw/module_fft/fft_cpu.h | 6 +- 6 files changed, 105 insertions(+), 200 deletions(-) diff --git a/source/module_basis/module_pw/module_fft/fft_base.cpp b/source/module_basis/module_pw/module_fft/fft_base.cpp index d87561485b..4c91d4d7b4 100644 --- a/source/module_basis/module_pw/module_fft/fft_base.cpp +++ b/source/module_basis/module_pw/module_fft/fft_base.cpp @@ -1,17 +1,8 @@ #include "fft_base.h" namespace ModulePW { -// template -// FFT_BASE::FFT_BASE() -// { -// } -// template -// FFT_BASE::~FFT_BASE() -// { -// } - -// template FFT_BASE::FFT_BASE(); -// template FFT_BASE::FFT_BASE(); -// template FFT_BASE::~FFT_BASE(); -// template FFT_BASE::~FFT_BASE(); +template FFT_BASE::FFT_BASE(); +template FFT_BASE::FFT_BASE(); +template FFT_BASE::~FFT_BASE(); +template FFT_BASE::~FFT_BASE(); } \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_base.h b/source/module_basis/module_pw/module_fft/fft_base.h index 529e7838a8..a8f4b246aa 100644 --- a/source/module_basis/module_pw/module_fft/fft_base.h +++ b/source/module_basis/module_pw/module_fft/fft_base.h @@ -159,9 +159,5 @@ class FFT_BASE int ny=0; int nz=0; }; -template FFT_BASE::FFT_BASE(); -template FFT_BASE::FFT_BASE(); -template FFT_BASE::~FFT_BASE(); -template FFT_BASE::~FFT_BASE(); } #endif // FFT_BASE_H diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index 6d71df1c21..1e82e0c595 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -16,38 +16,12 @@ std::unique_ptr make_unique(Args &&... args) } namespace ModulePW { -FFT_Bundle::FFT_Bundle() -{ -} -FFT_Bundle::FFT_Bundle(std::string device_in,std::string precision_in) -{ - assert(device_in=="cpu" || device_in=="gpu"); - assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); - this->device = device_in; - this->precision = precision_in; -} - -FFT_Bundle::~FFT_Bundle() -{ -} - -void FFT_Bundle::set_device(std::string device_in) -{ - this->device = device_in; -} - -void FFT_Bundle::set_precision(std::string precision_in) -{ - this->precision = precision_in; -} void FFT_Bundle::setfft(std::string device_in,std::string precision_in) { - assert(device_in=="cpu" || device_in=="gpu"); - assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing"); this->device = device_in; this->precision = precision_in; - } + void FFT_Bundle::initfft(int nx_in, int ny_in, int nz_in, @@ -60,6 +34,9 @@ void FFT_Bundle::initfft(int nx_in, bool xprime_in , bool mpifft_in) { + assert(this->device=="cpu" || this->device=="gpu"); + assert(this->precision=="single" || this->precision=="double" || this->precision=="mixing"); + if (this->precision=="single") { #ifndef __ENABLE_FLOAT_FFTW @@ -116,174 +93,120 @@ void FFT_Bundle::initfft(int nx_in, } } -void FFT_Bundle::initfftmode(int fft_mode_in) -{ - this->fft_mode = fft_mode_in; -} void FFT_Bundle::setupFFT() { - if (double_flag) - { - fft_double->setupFFT(); - } - if (float_flag) - { - fft_float->setupFFT(); - } + if (double_flag){fft_double->setupFFT();} + if (float_flag) {fft_float->setupFFT();} } void FFT_Bundle::clearFFT() { - if (double_flag) - { - fft_double->cleanFFT(); - } - if (float_flag) - { - fft_float->cleanFFT(); - } + if (double_flag){fft_double->cleanFFT();} + if (float_flag) {fft_float->cleanFFT();} } void FFT_Bundle::clear() { this->clearFFT(); - if (float_flag) - { - fft_float->clear(); - } - if (double_flag) - { - fft_double->clear(); - } -} - - -template <> -void FFT_Bundle::fftxyfor(std::complex* in, - std::complex* out) const -{ - fft_float->fftxyfor(in,out); -} - -template <> -void FFT_Bundle::fftxyfor(std::complex* in, - std::complex* out) const -{ - fft_double->fftxyfor(in,out); -} - -template <> -void FFT_Bundle::fftzfor(std::complex* in, - std::complex* out) const -{ - fft_float->fftzfor(in,out); -} -template <> -void FFT_Bundle::fftzfor(std::complex* in, - std::complex* out) const -{ - fft_double->fftzfor(in,out); -} - -template <> -void FFT_Bundle::fftxybac(std::complex* in, - std::complex* out) const -{ - fft_float->fftxybac(in,out); -} -template <> -void FFT_Bundle::fftxybac(std::complex* in, - std::complex* out) const -{ - fft_double->fftxybac(in,out); -} - -template <> -void FFT_Bundle::fftzbac(std::complex* in, - std::complex* out) const -{ - fft_float->fftzbac(in,out); -} -template <> -void FFT_Bundle::fftzbac(std::complex* in, - std::complex* out) const -{ - fft_double->fftzbac(in,out); -} -template <> -void FFT_Bundle::fftxyr2c(float* in, - std::complex* out) const -{ - fft_float->fftxyr2c(in,out); -} -template <> -void FFT_Bundle::fftxyr2c(double* in, - std::complex* out) const -{ - fft_double->fftxyr2c(in,out); -} - -template <> -void FFT_Bundle::fftxyc2r(std::complex* in, - float* out) const -{ - fft_float->fftxyc2r(in,out); -} -template <> -void FFT_Bundle::fftxyc2r(std::complex* in, - double* out) const -{ - fft_double->fftxyc2r(in,out); -} - -template <> -void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, + if (double_flag){fft_double->clear();} + if (float_flag) {fft_float->clear();} +} + +template <> void +FFT_Bundle::fftxyfor(std::complex* in, + std::complex* out) +const {fft_float->fftxyfor(in,out);} +template <> void +FFT_Bundle::fftxyfor(std::complex* in, + std::complex* out) +const {fft_double->fftxyfor(in,out);} + + +template <> void +FFT_Bundle::fftzfor(std::complex* in, + std::complex* out) +const {fft_float->fftzfor(in,out);} +template <> void +FFT_Bundle::fftzfor(std::complex* in, + std::complex* out) +const {fft_double->fftzfor(in,out);} + +template <> void +FFT_Bundle::fftxybac(std::complex* in, + std::complex* out) +const {fft_float->fftxybac(in,out);} +template <> void +FFT_Bundle::fftxybac(std::complex* in, + std::complex* out) +const {fft_double->fftxybac(in,out);} + +template <> void +FFT_Bundle::fftzbac(std::complex* in, + std::complex* out) +const {fft_float->fftzbac(in,out);} +template <> void +FFT_Bundle::fftzbac(std::complex* in, + std::complex* out) +const {fft_double->fftzbac(in,out);} + +template <> void +FFT_Bundle::fftxyr2c(float* in, + std::complex* out) +const {fft_float->fftxyr2c(in,out);} +template <> void +FFT_Bundle::fftxyr2c(double* in, + std::complex* out) +const {fft_double->fftxyr2c(in,out);} + +template <> void +FFT_Bundle::fftxyc2r(std::complex* in, + float* out) +const {fft_float->fftxyc2r(in,out);} +template <> void +FFT_Bundle::fftxyc2r(std::complex* in, + double* out) +const {fft_double->fftxyc2r(in,out);} + +template <> void +FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex* in, - std::complex* out) const -{ - fft_float->fft3D_forward(in, out); -} - -template <> -void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, - std::complex* in, - std::complex* out) const -{ - fft_double->fft3D_forward(in, out); -} -template <> -void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, + std::complex* out) +const {fft_float->fft3D_forward(in, out);} +template <> void +FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, + std::complex* in, + std::complex* out) +const {fft_double->fft3D_forward(in, out);} + +template <> void +FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, - std::complex* out) const -{ - fft_float->fft3D_backward(in, out); -} -template <> -void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, + std::complex* out) +const {fft_float->fft3D_backward(in, out);} +template <> void +FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex* in, - std::complex* out) const -{ - fft_double->fft3D_backward(in, out); -} + std::complex* out) +const {fft_double->fft3D_backward(in, out);} // access the real space data template <> float* -FFT_Bundle::get_rspace_data() const {return fft_float->get_rspace_data();} +FFT_Bundle::get_rspace_data() const {return fft_float->get_rspace_data();} template <> double* -FFT_Bundle::get_rspace_data() const {return fft_double->get_rspace_data();} +FFT_Bundle::get_rspace_data() const {return fft_double->get_rspace_data();} template <> std::complex* -FFT_Bundle::get_auxr_data() const {return fft_float->get_auxr_data();} +FFT_Bundle::get_auxr_data() const {return fft_float->get_auxr_data();} template <> std::complex* -FFT_Bundle::get_auxr_data() const{return fft_double->get_auxr_data();} +FFT_Bundle::get_auxr_data() const {return fft_double->get_auxr_data();} template <> std::complex* -FFT_Bundle::get_auxg_data() const{return fft_float->get_auxg_data();} +FFT_Bundle::get_auxg_data() const {return fft_float->get_auxg_data();} template <> std::complex* -FFT_Bundle::get_auxg_data() const{return fft_double->get_auxg_data();} +FFT_Bundle::get_auxg_data() const {return fft_double->get_auxg_data();} template <> std::complex* -FFT_Bundle::get_auxr_3d_data() const{return fft_float->get_auxr_3d_data();} +FFT_Bundle::get_auxr_3d_data() const {return fft_float->get_auxr_3d_data();} template <> std::complex* FFT_Bundle::get_auxr_3d_data() const {return fft_double->get_auxr_3d_data();} } \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.h b/source/module_basis/module_pw/module_fft/fft_bundle.h index ebeb04f5b8..8af2f4ce8d 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.h +++ b/source/module_basis/module_pw/module_fft/fft_bundle.h @@ -8,7 +8,8 @@ namespace ModulePW class FFT_Bundle { public: - FFT_Bundle(); + FFT_Bundle(){}; + ~FFT_Bundle(){}; /** * @brief Constructor with device and precision. * @param device_in device type, cpu or gpu. @@ -17,9 +18,8 @@ class FFT_Bundle * the function will check the input device and precision, * and set the device and precision. */ - FFT_Bundle(std::string device_in,std::string precision_in); - ~FFT_Bundle(); - + FFT_Bundle(std::string device_in,std::string precision_in):device(device_in),precision(precision_in){}; + /** * @brief Set device and precision. * @param device_in device type, cpu or gpu. @@ -68,7 +68,7 @@ class FFT_Bundle * the function will initialize the fft mode. */ - void initfftmode(int fft_mode_in); + void initfftmode(int fft_mode_in){this->fft_mode = fft_mode_in;} void setupFFT(); @@ -197,12 +197,8 @@ class FFT_Bundle std::complex* in, std::complex* out) const; - void set_device(std::string device_in); - - void set_precision(std::string precision_in); - private: - int fft_mode = 0; ///< fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive + int fft_mode = 0; bool float_flag=false; bool float_define=true; bool double_flag=false; diff --git a/source/module_basis/module_pw/module_fft/fft_cpu.cpp b/source/module_basis/module_pw/module_fft/fft_cpu.cpp index f06a235a2b..e8ad959f97 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cpu.cpp @@ -319,7 +319,6 @@ void FFT_CPU::cleanFFT() clearfft(planyc2r); } - template <> void FFT_CPU::clear() { diff --git a/source/module_basis/module_pw/module_fft/fft_cpu.h b/source/module_basis/module_pw/module_fft/fft_cpu.h index 888880b386..27c7e862a2 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu.h +++ b/source/module_basis/module_pw/module_fft/fft_cpu.h @@ -14,7 +14,7 @@ class FFT_CPU : public FFT_BASE { public: FFT_CPU(){}; - FFT_CPU(const int fft_mode_in){this->fft_mode = fft_mode_in;}; + FFT_CPU(const int fft_mode_in):fft_mode(fft_mode_in){}; ~FFT_CPU(){}; /** @@ -66,10 +66,10 @@ class FFT_CPU : public FFT_BASE FPTYPE* get_rspace_data() const override; __attribute__((weak)) - std::complex* get_auxr_data() const; + std::complex* get_auxr_data() const override; __attribute__((weak)) - std::complex* get_auxg_data() const; + std::complex* get_auxg_data() const override; /** * @brief Forward FFT in x-y direction From 479b17887b84a191b75463c97fed7423bcd99a48 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Mon, 11 Nov 2024 19:59:16 +0800 Subject: [PATCH 26/27] update the shared_ptr --- .../module_pw/module_fft/fft_bundle.h | 7 +++--- .../module_pw/module_fft/fft_cpu.cpp | 11 +++++++--- .../module_pw/module_fft/fft_cpu_float.cpp | 22 ++++++------------- .../test/charge_extra_test.cpp | 6 ----- .../test/elecstate_base_test.cpp | 7 +----- source/module_hsolver/test/hsolver_pw_sup.h | 2 -- 6 files changed, 20 insertions(+), 35 deletions(-) diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.h b/source/module_basis/module_pw/module_fft/fft_bundle.h index 8af2f4ce8d..8321badb4b 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.h +++ b/source/module_basis/module_pw/module_fft/fft_bundle.h @@ -18,7 +18,8 @@ class FFT_Bundle * the function will check the input device and precision, * and set the device and precision. */ - FFT_Bundle(std::string device_in,std::string precision_in):device(device_in),precision(precision_in){}; + FFT_Bundle(std::string device_in,std::string precision_in) + :device(device_in),precision(precision_in){}; /** * @brief Set device and precision. @@ -202,8 +203,8 @@ class FFT_Bundle bool float_flag=false; bool float_define=true; bool double_flag=false; - std::unique_ptr> fft_float=nullptr; - std::unique_ptr> fft_double=nullptr; + std::shared_ptr> fft_float=nullptr; + std::shared_ptr> fft_double=nullptr; std::string device = "cpu"; std::string precision = "double"; diff --git a/source/module_basis/module_pw/module_fft/fft_cpu.cpp b/source/module_basis/module_pw/module_fft/fft_cpu.cpp index e8ad959f97..be920d4ae2 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cpu.cpp @@ -247,7 +247,12 @@ void FFT_CPU::setupFFT() (fftw_complex*)z_auxr, embed, npy, - 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag); + 1, (fftw_complex*)z_auxr, + embed, + npy, + 1, + FFTW_FORWARD, + flag); this->planxbac2 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->ny - this->rixy), @@ -449,9 +454,9 @@ void FFT_CPU::fftxyc2r(std::complex *in,double *out) const template <> double* FFT_CPU::get_rspace_data() const {return d_rspace;} template <> std::complex* -FFT_CPU::get_auxr_data() const {return z_auxr;} +FFT_CPU::get_auxr_data() const {return z_auxr;} template <> std::complex* -FFT_CPU::get_auxg_data() const {return z_auxg;} +FFT_CPU::get_auxg_data() const {return z_auxg;} template FFT_CPU::FFT_CPU(); template FFT_CPU::~FFT_CPU(); diff --git a/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp b/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp index 358e2dfa88..f84b45bf09 100644 --- a/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp +++ b/source/module_basis/module_pw/module_fft/fft_cpu_float.cpp @@ -311,21 +311,7 @@ void FFT_CPU::clear() s_rspace = nullptr; } -template <> -float* FFT_CPU::get_rspace_data() const -{ - return s_rspace; -} -template <> -std::complex* FFT_CPU::get_auxr_data() const -{ - return c_auxr; -} -template <> -std::complex* FFT_CPU::get_auxg_data() const -{ - return c_auxg; -} + template <> void FFT_CPU::fftxyfor(std::complex* in, std::complex* out) const { @@ -438,4 +424,10 @@ void FFT_CPU::fftxyc2r(std::complex* in, float* out) const } } } +template <> float* +FFT_CPU::get_rspace_data() const {return s_rspace;} +template <> std::complex* +FFT_CPU::get_auxr_data() const {return c_auxr;} +template <> std::complex* +FFT_CPU::get_auxg_data() const {return c_auxg;} } \ No newline at end of file diff --git a/source/module_elecstate/test/charge_extra_test.cpp b/source/module_elecstate/test/charge_extra_test.cpp index f52b034e4c..fadacdb327 100644 --- a/source/module_elecstate/test/charge_extra_test.cpp +++ b/source/module_elecstate/test/charge_extra_test.cpp @@ -70,12 +70,6 @@ FFT::FFT() FFT::~FFT() { } -FFT_Bundle::FFT_Bundle() -{ -} -FFT_Bundle::~FFT_Bundle() -{ -} void PW_Basis::initgrids(const double lat0_in, const ModuleBase::Matrix3 latvec_in, const double gridecut) { } diff --git a/source/module_elecstate/test/elecstate_base_test.cpp b/source/module_elecstate/test/elecstate_base_test.cpp index 95bb11949d..ea69f172df 100644 --- a/source/module_elecstate/test/elecstate_base_test.cpp +++ b/source/module_elecstate/test/elecstate_base_test.cpp @@ -56,12 +56,7 @@ ModulePW::FFT::FFT() ModulePW::FFT::~FFT() { } -ModulePW::FFT_Bundle::FFT_Bundle() -{ -} -ModulePW::FFT_Bundle::~FFT_Bundle() -{ -} + void ModulePW::PW_Basis::initgrids(double, ModuleBase::Matrix3, double) { } diff --git a/source/module_hsolver/test/hsolver_pw_sup.h b/source/module_hsolver/test/hsolver_pw_sup.h index 300492e5aa..c70025a2c2 100644 --- a/source/module_hsolver/test/hsolver_pw_sup.h +++ b/source/module_hsolver/test/hsolver_pw_sup.h @@ -4,8 +4,6 @@ namespace ModulePW { PW_Basis::PW_Basis(){}; PW_Basis::~PW_Basis(){}; -FFT_Bundle::FFT_Bundle(){}; -FFT_Bundle::~FFT_Bundle(){}; void PW_Basis::initgrids( const double lat0_in, // unit length (unit in bohr) const ModuleBase::Matrix3 From 00dfee58e1a3580cf2db88381040ca1cb1954362 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Mon, 11 Nov 2024 13:44:22 +0000 Subject: [PATCH 27/27] [pre-commit.ci lite] apply automatic fixes --- source/module_basis/module_pw/pw_basis_k.cpp | 44 +++++++++++++------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/source/module_basis/module_pw/pw_basis_k.cpp b/source/module_basis/module_pw/pw_basis_k.cpp index 3e6e6b4893..2361404d84 100644 --- a/source/module_basis/module_pw/pw_basis_k.cpp +++ b/source/module_basis/module_pw/pw_basis_k.cpp @@ -70,7 +70,8 @@ void PW_Basis_K:: initparameters( this->kvec_d[ik] = kvec_d_in[ik]; this->kvec_c[ik] = this->kvec_d[ik] * this->G; double kmod = sqrt(this->kvec_c[ik] * this->kvec_c[ik]); - if(kmod > kmaxmod) kmaxmod = kmod; + if(kmod > kmaxmod) { kmaxmod = kmod; +} } this->gk_ecut = gk_ecut_in/this->tpiba2; this->ggecut = pow(sqrt(this->gk_ecut) + kmaxmod, 2); @@ -81,14 +82,16 @@ void PW_Basis_K:: initparameters( } this->gamma_only = gamma_only_in; - if(kmaxmod > 0) this->gamma_only = false; //if it is not the gamma point, we do not use gamma_only + if(kmaxmod > 0) { this->gamma_only = false; //if it is not the gamma point, we do not use gamma_only +} this->xprime = xprime_in; this->fftny = this->ny; this->fftnx = this->nx; if (this->gamma_only) { - if(this->xprime) this->fftnx = int(this->nx / 2) + 1; - else this->fftny = int(this->ny / 2) + 1; + if(this->xprime) { this->fftnx = int(this->nx / 2) + 1; + } else { this->fftny = int(this->ny / 2) + 1; +} } this->fftnz = this->nz; this->fftnxy = this->fftnx * this->fftny; @@ -142,7 +145,8 @@ void PW_Basis_K::setupIndGk() //get igl2isz_k and igl2ig_k - if(this->npwk_max <= 0) return; + if(this->npwk_max <= 0) { return; +} delete[] igl2isz_k; this->igl2isz_k = new int [this->nks * this->npwk_max]; delete[] igl2ig_k; this->igl2ig_k = new int [this->nks * this->npwk_max]; for (int ik = 0; ik < this->nks; ik++) @@ -199,7 +203,8 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h this->erf_ecut = erf_ecut_in; this->erf_height = erf_height_in; this->erf_sigma = erf_sigma_in; - if(this->npwk_max <= 0) return; + if(this->npwk_max <= 0) { return; +} delete[] gk2; delete[] gcar; this->gk2 = new double[this->npwk_max * this->nks]; @@ -219,9 +224,12 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h int ixy = this->is2fftixy[is]; int ix = ixy / this->fftny; int iy = ixy % this->fftny; - if (ix >= int(this->nx/2) + 1) ix -= this->nx; - if (iy >= int(this->ny/2) + 1) iy -= this->ny; - if (iz >= int(this->nz/2) + 1) iz -= this->nz; + if (ix >= int(this->nx/2) + 1) { ix -= this->nx; +} + if (iy >= int(this->ny/2) + 1) { iy -= this->ny; +} + if (iz >= int(this->nz/2) + 1) { iz -= this->nz; +} f.x = ix; f.y = iy; f.z = iz; @@ -278,9 +286,12 @@ ModuleBase::Vector3 PW_Basis_K:: cal_GplusK_cartesian(const int ik, cons int is = isz / this->nz; int ix = this->is2fftixy[is] / this->fftny; int iy = this->is2fftixy[is] % this->fftny; - if (ix >= int(this->nx/2) + 1) ix -= this->nx; - if (iy >= int(this->ny/2) + 1) iy -= this->ny; - if (iz >= int(this->nz/2) + 1) iz -= this->nz; + if (ix >= int(this->nx/2) + 1) { ix -= this->nx; +} + if (iy >= int(this->ny/2) + 1) { iy -= this->ny; +} + if (iz >= int(this->nz/2) + 1) { iz -= this->nz; +} ModuleBase::Vector3 f; f.x = ix; f.y = iy; @@ -362,7 +373,8 @@ std::vector PW_Basis_K::get_ig2ix(const int ik) const int is = isz / this->nz; int ixy = this->is2fftixy[is]; int ix = ixy / this->ny; - if (ix < (nx / 2) + 1) ix += nx; + if (ix < (nx / 2) + 1) { ix += nx; +} ig_to_ix[ig] = ix; } return ig_to_ix; @@ -379,7 +391,8 @@ std::vector PW_Basis_K::get_ig2iy(const int ik) const int is = isz / this->nz; int ixy = this->is2fftixy[is]; int iy = ixy % this->ny; - if (iy < (ny / 2) + 1) iy += ny; + if (iy < (ny / 2) + 1) { iy += ny; +} ig_to_iy[ig] = iy; } return ig_to_iy; @@ -394,7 +407,8 @@ std::vector PW_Basis_K::get_ig2iz(const int ik) const { int isz = this->igl2isz_k[ig + ik * npwk_max]; int iz = isz % this->nz; - if (iz < (nz / 2) + 1) iz += nz; + if (iz < (nz / 2) + 1) { iz += nz; +} ig_to_iz[ig] = iz; } return ig_to_iz;