From c1dac217884a492000bbabe8c82872c7efa47790 Mon Sep 17 00:00:00 2001 From: liutao <3158793232@qq.com> Date: Thu, 23 Jan 2025 09:04:20 +0000 Subject: [PATCH 01/14] set fft_dsp --- source/module_basis/module_pw/CMakeLists.txt | 5 + .../module_pw/module_fft/fft_bundle.cpp | 6 + .../module_pw/module_fft/fft_dsp.cpp | 124 ++++++++++++++++++ .../module_pw/module_fft/fft_dsp.h | 82 ++++++++++++ .../module_pw/module_fft/fft_dsp_float.cpp | 20 +++ .../module_pw/pw_transform_k_dsp.cpp | 0 6 files changed, 237 insertions(+) create mode 100644 source/module_basis/module_pw/module_fft/fft_dsp.cpp create mode 100644 source/module_basis/module_pw/module_fft/fft_dsp.h create mode 100644 source/module_basis/module_pw/module_fft/fft_dsp_float.cpp create mode 100644 source/module_basis/module_pw/pw_transform_k_dsp.cpp diff --git a/source/module_basis/module_pw/CMakeLists.txt b/source/module_basis/module_pw/CMakeLists.txt index 549e41c93c..e8541c5150 100644 --- a/source/module_basis/module_pw/CMakeLists.txt +++ b/source/module_basis/module_pw/CMakeLists.txt @@ -13,6 +13,11 @@ if (USE_ROCM) module_fft/fft_rocm.cpp ) endif() +if (USE_DSP) + list (APPEND FFT_SRC + module_fft/fft_dsp.cpp + pw_transform_k_dsp.cpp) +endif() list(APPEND objects pw_basis.cpp diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index c2718abf5d..d292eb79b2 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -65,6 +65,12 @@ void FFT_Bundle::initfft(int nx_in, if (device=="cpu") { + #if defined(__DSP) + if (float_flag==true) + ModuleBase::WARNING_QUT("device","now dsp is not support for the float type"); + fft_double=make_unique>(); + fft_double->initfft(nx_in,ny_in,nz_in); + #endif fft_float = make_unique>(this->fft_mode); fft_double = make_unique>(this->fft_mode); if (float_flag) diff --git a/source/module_basis/module_pw/module_fft/fft_dsp.cpp b/source/module_basis/module_pw/module_fft/fft_dsp.cpp new file mode 100644 index 0000000000..95dde0a9e2 --- /dev/null +++ b/source/module_basis/module_pw/module_fft/fft_dsp.cpp @@ -0,0 +1,124 @@ +// #define protected public +// #define private public +#include "fft_dsp.h" +#include +#include +#include +// #undef private +// #undef protected +namespace ModulePW +{ +template<> +void FFT_DSP::initfft(int nx_in,int ny_in,int nz_in) +{ + this->nx=nx_in; + this->ny=ny_in; + this->nz=nz_in; + cluster_id = 1; +} +template<> +void FFT_DSP::setupFFT() +{ + PROBLEM pbm_forward; + PROBLEM pbm_backward; + PLAN* ptr_plan_forward; + PLAN* ptr_plan_backward; + INT num_thread=8; + + INT size; + //open cluster id + hthread_dev_open(cluster_id); + //load mt.dat + hthread_dat_load(cluster_id, "mt_fft_device.dat"); + + thread_id_for = hthread_group_create(cluster_id, num_thread, NULL, 0, 0, NULL); + //create b_id for the barrier + b_id = hthread_barrier_create(cluster_id); + args_for[0] = b_id; + + //compute the size of and malloc thread + size = nx*ny*nz*2*sizeof(E); + forward_in = (E*)hthread_malloc((int)cluster_id, size, HT_MEM_RW); + + // //init 3d fft problem + pbm_forward.num_dim = 3; + pbm_forward.n[0] = nx; + pbm_forward.n[1] = ny; + pbm_forward.n[2] = nz; + pbm_forward.iFFT = 0; + pbm_forward.in = forward_in; + pbm_forward.out = forward_in; + + // //make ptr plan + make_plan(&pbm_forward, &ptr_plan_forward, cluster_id, num_thread); + ptr_plan_forward->in = forward_in; + ptr_plan_forward->out = forward_in; + args_for[1] = (unsigned long)ptr_plan_forward; + + //init 3d fft problem + pbm_backward.num_dim = 3; // dimensions of FFT + pbm_backward.n[0] = nx; // first dimension + pbm_backward.n[1] = ny; // second dimension + pbm_backward.n[2] = nz; // third dimension + pbm_backward.iFFT = 1; // 0 stand for forward,1 stand for backward + pbm_backward.in = forward_in; // the input data + pbm_backward.out = forward_in; // the output data + + make_plan(&pbm_backward, &ptr_plan_backward, cluster_id, num_thread); + ptr_plan_backward->in = forward_in; + ptr_plan_backward->out = forward_in; + args_back[0]=b_id; + args_back[1]=(unsigned long)ptr_plan_backward; +} + +template<> +void FFT_DSP::fft3D_forward(std::complex* in, + std::complex* out) const +{ + hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_for); + hthread_group_wait(thread_id_for); +} + +template<> +void FFT_DSP::fft3D_backward(std::complex * in, + std::complex* out) const +{ + hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_back); + hthread_group_wait(thread_id_for); + +} +template<> +void FFT_DSP::cleanFFT() +{ + if (ptr_plan_forward!=nullptr) + { + destroy_plan(ptr_plan_forward); + ptr_plan_forward=nullptr; + } + if (ptr_plan_backward!=nullptr) + { + destroy_plan(ptr_plan_backward); + ptr_plan_backward=nullptr; + } +} + +template<> +void FFT_DSP::clear() +{ + this->cleanFFT(); + hthread_free(forward_in); + hthread_barrier_destroy(b_id); + hthread_group_destroy(thread_id_for); +} + +template<> std::complex* +FFT_DSP::get_auxr_3d_data() const +{ + return reinterpret_cast*>(this->forward_in); +} +template FFT_DSP::FFT_DSP(); +template FFT_DSP::~FFT_DSP(); +template FFT_DSP::FFT_DSP(); +template FFT_DSP::~FFT_DSP(); + +} diff --git a/source/module_basis/module_pw/module_fft/fft_dsp.h b/source/module_basis/module_pw/module_fft/fft_dsp.h new file mode 100644 index 0000000000..a86aa13cab --- /dev/null +++ b/source/module_basis/module_pw/module_fft/fft_dsp.h @@ -0,0 +1,82 @@ +#ifndef FFT_CUDA_H +#define FFT_CUDA_H + +#include "fft_base.h" +#include +#include +#include + +#include "hthread_host.h" +#include "mtfft.h" +#include "fftw3.h" + +namespace ModulePW +{ +template +class FFT_DSP : public FFT_BASE +{ + public: + FFT_DSP(){}; + ~FFT_DSP(){}; + + void setupFFT() override; + + void clear() override; + + void cleanFFT() override; + + /** + * @brief Initialize the fft parameters + * @param nx_in number of grid points in x direction + * @param ny_in number of grid points in y direction + * @param nz_in number of grid points in z direction + * + */ + virtual __attribute__((weak)) + void initfft(int nx_in, + int ny_in, + int nz_in) override; + + /** + * @brief Get the real space data + * @return real space data + */ + virtual __attribute__((weak)) + std::complex* get_auxr_3d_data() const override; + + /** + * @brief Forward FFT in 3D + * @param in input data, complex FPTYPE + * @param out output data, complex FPTYPE + * + * This function performs the forward FFT in 3D. + */ + virtual __attribute__((weak)) + void fft3D_forward(std::complex* in, + std::complex* out) const override; + /** + * @brief Backward FFT in 3D + * @param in input data, complex FPTYPE + * @param out output data, complex FPTYPE + * + * This function performs the backward FFT in 3D. + */ + virtual __attribute__((weak)) + void fft3D_backward(std::complex* in, + std::complex* out) const override; + public: + INT cluster_id=0; + INT b_id; + INT thread_id_for=0; + PLAN* ptr_plan_forward=nullptr; + PLAN* ptr_plan_backward=nullptr; + mutable unsigned long args_for[2]; + mutable unsigned long args_back[2]; + mutable E * forward_in; + std::complex* c_auxr_3d = nullptr; // fft space + std::complex* z_auxr_3d = nullptr; // fft space + +}; +void test_fft_dsp(); +} // namespace ModulePW +#endif \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_dsp_float.cpp b/source/module_basis/module_pw/module_fft/fft_dsp_float.cpp new file mode 100644 index 0000000000..2a17bacd02 --- /dev/null +++ b/source/module_basis/module_pw/module_fft/fft_dsp_float.cpp @@ -0,0 +1,20 @@ +#include "fft_dsp.h" +namespace ModulePW +{ + +template<> +void FFT_DSP::setupFFT() +{ + +} +template<> +void FFT_DSP::clear() +{ + +} +template<> +void FFT_DSP::cleanFFT() +{ + +} +} \ No newline at end of file diff --git a/source/module_basis/module_pw/pw_transform_k_dsp.cpp b/source/module_basis/module_pw/pw_transform_k_dsp.cpp new file mode 100644 index 0000000000..e69de29bb2 From 4a2894bb23592d39aa33b0636fcf387d421085fb Mon Sep 17 00:00:00 2001 From: liutao <3158793232@qq.com> Date: Thu, 23 Jan 2025 11:59:51 +0000 Subject: [PATCH 02/14] add information in map --- CMakeLists.txt | 3 + .../module_pw/module_fft/fft_bundle.cpp | 6 +- .../module_pw/module_fft/fft_dsp.cpp | 4 - source/module_basis/module_pw/pw_basis_k.cpp | 7 +- source/module_basis/module_pw/pw_basis_k.h | 16 +++- .../module_basis/module_pw/pw_transform_k.cpp | 8 ++ .../module_pw/pw_transform_k_dsp.cpp | 82 +++++++++++++++++++ 7 files changed, 118 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7d6e74f898..92b60b69d9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -256,6 +256,9 @@ endif() if (USE_DSP) target_link_libraries(${ABACUS_BIN_NAME} ${DIR_MTBLAS_LIBRARY}) add_compile_definitions(__DSP) + target_include_directories(${ABACUS_BIN_NAME} ${DIR_MTFFT_INCLUDES}) + target_link_libraries(${ABACUS_BIN_NAME} ${DIR_MTFFT_LIBRARY}) + target_include_directories(${ABACUS_BIN_NAME} ${DIR_HTRERAD_INLCUDES}) endif() find_package(Threads REQUIRED) diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index d292eb79b2..7da4699058 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -9,7 +9,9 @@ #if defined(__ROCM) #include "fft_rocm.h" #endif - +#if defined(__DSP) +#include "fft_dsp.h" +#endif template std::unique_ptr make_unique(Args &&... args) { @@ -67,7 +69,7 @@ void FFT_Bundle::initfft(int nx_in, { #if defined(__DSP) if (float_flag==true) - ModuleBase::WARNING_QUT("device","now dsp is not support for the float type"); + ModuleBase::WARNING_QUT("device","now dsp fft is not support for the float type"); fft_double=make_unique>(); fft_double->initfft(nx_in,ny_in,nz_in); #endif diff --git a/source/module_basis/module_pw/module_fft/fft_dsp.cpp b/source/module_basis/module_pw/module_fft/fft_dsp.cpp index 95dde0a9e2..f7eb17a4fb 100644 --- a/source/module_basis/module_pw/module_fft/fft_dsp.cpp +++ b/source/module_basis/module_pw/module_fft/fft_dsp.cpp @@ -1,11 +1,7 @@ -// #define protected public -// #define private public #include "fft_dsp.h" #include #include #include -// #undef private -// #undef protected namespace ModulePW { template<> diff --git a/source/module_basis/module_pw/pw_basis_k.cpp b/source/module_basis/module_pw/pw_basis_k.cpp index 2e0f85372d..5f80191251 100644 --- a/source/module_basis/module_pw/pw_basis_k.cpp +++ b/source/module_basis/module_pw/pw_basis_k.cpp @@ -22,6 +22,9 @@ PW_Basis_K::~PW_Basis_K() delete[] igl2isz_k; delete[] igl2ig_k; delete[] gk2; +#if defined(__DSP) + delete[] ig2ixyz_k_cpu; +#endif #if defined(__CUDA) || defined(__ROCM) if (this->device == "gpu") { if (this->precision == "single") { @@ -357,7 +360,9 @@ void PW_Basis_K::get_ig2ixyz_k() } resmem_int_op()(ig2ixyz_k, this->npwk_max * this->nks); syncmem_int_h2d_op()(this->ig2ixyz_k, ig2ixyz_k_cpu, this->npwk_max * this->nks); - delete[] ig2ixyz_k_cpu; + #if not defined (__DSP) + delete[] this->ig2ixyz_k_cpu; + #endif } std::vector PW_Basis_K::get_ig2ix(const int ik) const diff --git a/source/module_basis/module_pw/pw_basis_k.h b/source/module_basis/module_pw/pw_basis_k.h index f5be09cfbd..8f03cc9ce4 100644 --- a/source/module_basis/module_pw/pw_basis_k.h +++ b/source/module_basis/module_pw/pw_basis_k.h @@ -87,7 +87,7 @@ class PW_Basis_K : public PW_Basis int *igl2isz_k=nullptr, * d_igl2isz_k = nullptr; //[npwk_max*nks] map (igl,ik) to (is,iz) int *igl2ig_k=nullptr;//[npwk_max*nks] map (igl,ik) to ig int *ig2ixyz_k=nullptr; ///< [npw] map ig to ixyz - + int *ig2ixyz_k_cpu = nullptr; /// [npw] map ig to ixyz,which is used in dsp fft. double *gk2=nullptr; // modulus (G+K)^2 of G vectors [npwk_max*nks] // liuyu add 2023-09-06 @@ -135,6 +135,20 @@ class PW_Basis_K : public PW_Basis const int ik, const bool add = false, const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) + #if defined(__DSP) + template + void real2recip_3d(const std::complex* in, + std::complex* out, + const int ik, + const bool add = false, + const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns) + template + void recip2real_3d(const std::complex* in, + std::complex* out, + const int ik, + const bool add = false, + const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) + #endif template void real_to_recip(const Device* ctx, diff --git a/source/module_basis/module_pw/pw_transform_k.cpp b/source/module_basis/module_pw/pw_transform_k.cpp index e230066c8f..75f821132e 100644 --- a/source/module_basis/module_pw/pw_transform_k.cpp +++ b/source/module_basis/module_pw/pw_transform_k.cpp @@ -307,7 +307,11 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_CPU* /*dev*/, const bool add, const double factor) const { + #if defined(__DSP) + this->real2recip_3d(in,out,ik,add,factor); + #else this->real2recip(in, out, ik, add, factor); + #endif } template <> @@ -318,7 +322,11 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/, const bool add, const float factor) const { + #if defined(__DSP) + this->recip2real_3d(in,out,add,factor); + #else this->recip2real(in, out, ik, add, factor); + #endif } template <> void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/, diff --git a/source/module_basis/module_pw/pw_transform_k_dsp.cpp b/source/module_basis/module_pw/pw_transform_k_dsp.cpp index e69de29bb2..849a4b90af 100644 --- a/source/module_basis/module_pw/pw_transform_k_dsp.cpp +++ b/source/module_basis/module_pw/pw_transform_k_dsp.cpp @@ -0,0 +1,82 @@ +#include "module_base/timer.h" +#include "module_basis/module_pw/kernels/pw_op.h" +#include "pw_basis_k.h" + +#include +#include +#include +namespace ModulePW +{ + template + void PW_Basis_K::real2recip_3d(const std::complex* in, + std::complex* out, + const int ik, + const bool add, + const FPTYPE factor) const + { + ModuleBase::timer::tick(this->classname,"real2recip_3d"); + const base_device::DEVICE_CPU* ctx; + const base_device::DEVICE_GPU* gpux; + assert(this->gamma_only == false); + auto* auxr = this->fft_bundle.get_auxr_3d_data(); + + const int startig = ik * this->npwk_max; + const int npw_k = this->npwk[ik]; + memcpy(auxr,in,this->nrxx*2*8); + this->fft_bundle.fft3D_forward(gpux, + auxr, + auxr); + set_real_to_recip_output_op()(ctx, + npw_k, + this->nxyz, + add, + factor, + this->ig2ixyz_k_cpu + startig, + this->fft_bundle.get_auxr_3d_data(), + out); + ModuleBase::timer::tick(this->classname,"real2recip_3d"); + } + + template + void PW_Basis_K::recip2real_3d(const std::complex* in, + std::complex* out, + const int ik, + const bool add, + const FPTYPE factor) const + { + ModuleBase::timer::tick(this->classname,"recip2real_3d"); + + assert(this->gamma_only == false); + const base_device::DEVICE_CPU* ctx; + const base_device::DEVICE_GPU* gpux; + auto* auxr = this->fft_bundle.get_auxr_3d_data(); + memset(auxr,0,this->nrxx*2*8); + const int startig = ik * this->npwk_max; + const int npw_k = this->npwk[ik]; + + set_3d_fft_box_op()(ctx, + npw_k, + this->ig2ixyz_k_cpu + startig, + in, + auxr); + this->fft_bundle.fft3D_backward(gpux,auxr,auxr); + set_recip_to_real_output_op()(ctx, + this->nrxx, + add, + factor, + auxr, + out); + ModuleBase::timer::tick(this->classname,"recip2real_3d"); + } + +template void PW_Basis_K::real2recip_3d(const std::complex* in, + std::complex* out, + const int ik, + const bool add, + const double factor) const; // in:(nplane,nx*ny) ; out(nz, ns) +template void PW_Basis_K::recip2real_3d(const std::complex* in, + std::complex* out, + const int ik, + const bool add, + const double factor) const; // in:(nz, ns) ; out(nplane,nx*ny) +} \ No newline at end of file From 51756fd74000e54e4df5979c72186df98c707f7b Mon Sep 17 00:00:00 2001 From: liutao <3158793232@qq.com> Date: Thu, 23 Jan 2025 12:16:24 +0000 Subject: [PATCH 03/14] update Global_rank --- source/module_basis/module_pw/module_fft/fft_dsp.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/source/module_basis/module_pw/module_fft/fft_dsp.cpp b/source/module_basis/module_pw/module_fft/fft_dsp.cpp index f7eb17a4fb..d39d8822b6 100644 --- a/source/module_basis/module_pw/module_fft/fft_dsp.cpp +++ b/source/module_basis/module_pw/module_fft/fft_dsp.cpp @@ -1,4 +1,5 @@ #include "fft_dsp.h" +#include "module_base/global_variable.h" #include #include #include @@ -10,7 +11,7 @@ void FFT_DSP::initfft(int nx_in,int ny_in,int nz_in) this->nx=nx_in; this->ny=ny_in; this->nz=nz_in; - cluster_id = 1; + cluster_id = GlobalV::MY_RANK; } template<> void FFT_DSP::setupFFT() From 95490928b8697f8ee6aabf4ad0d2edf72589a296 Mon Sep 17 00:00:00 2001 From: liutao <3158793232@qq.com> Date: Thu, 23 Jan 2025 12:35:19 +0000 Subject: [PATCH 04/14] update control flow --- source/module_basis/module_pw/module_fft/fft_bundle.cpp | 3 ++- source/module_basis/module_pw/module_fft/fft_dsp.cpp | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index 7da4699058..7c064ef5eb 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -72,7 +72,7 @@ void FFT_Bundle::initfft(int nx_in, ModuleBase::WARNING_QUT("device","now dsp fft is not support for the float type"); fft_double=make_unique>(); fft_double->initfft(nx_in,ny_in,nz_in); - #endif + #else fft_float = make_unique>(this->fft_mode); fft_double = make_unique>(this->fft_mode); if (float_flag) @@ -101,6 +101,7 @@ void FFT_Bundle::initfft(int nx_in, gamma_only_in, xprime_in); } + #endif } if (device=="gpu") { diff --git a/source/module_basis/module_pw/module_fft/fft_dsp.cpp b/source/module_basis/module_pw/module_fft/fft_dsp.cpp index d39d8822b6..9aaf40c2f6 100644 --- a/source/module_basis/module_pw/module_fft/fft_dsp.cpp +++ b/source/module_basis/module_pw/module_fft/fft_dsp.cpp @@ -21,7 +21,6 @@ void FFT_DSP::setupFFT() PLAN* ptr_plan_forward; PLAN* ptr_plan_backward; INT num_thread=8; - INT size; //open cluster id hthread_dev_open(cluster_id); From 5d8dfe0400f09adb01fae2e06f046af21b9a5fe0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Thu, 23 Jan 2025 14:14:42 +0000 Subject: [PATCH 05/14] [pre-commit.ci lite] apply automatic fixes --- source/module_basis/module_pw/module_fft/fft_dsp.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/source/module_basis/module_pw/module_fft/fft_dsp.h b/source/module_basis/module_pw/module_fft/fft_dsp.h index a86aa13cab..9488c083ce 100644 --- a/source/module_basis/module_pw/module_fft/fft_dsp.h +++ b/source/module_basis/module_pw/module_fft/fft_dsp.h @@ -2,9 +2,9 @@ #define FFT_CUDA_H #include "fft_base.h" -#include -#include -#include +#include +#include +#include #include "hthread_host.h" #include "mtfft.h" From e5d9214751700c8e42838f7805bf40521a129548 Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Mon, 24 Feb 2025 11:47:07 +0800 Subject: [PATCH 06/14] add the fft_dsp in the fft_bundle --- CMakeLists.txt | 16 +- source/CMakeLists.txt | 7 + source/module_base/CMakeLists.txt | 5 +- .../module_base/kernels/dsp/dsp_connector.cpp | 201 ++++++++++++++++++ source/module_basis/module_pw/CMakeLists.txt | 5 + .../module_pw/module_fft/fft_bundle.cpp | 20 +- .../module_pw/module_fft/fft_dsp.cpp | 54 ++--- .../module_pw/module_fft/fft_dsp.h | 11 +- source/module_basis/module_pw/pw_basis_k.cpp | 6 +- source/module_basis/module_pw/pw_basis_k.h | 31 ++- .../module_basis/module_pw/pw_transform_k.cpp | 16 +- .../module_pw/pw_transform_k_dsp.cpp | 177 +++++++++++---- 12 files changed, 436 insertions(+), 113 deletions(-) create mode 100644 source/module_base/kernels/dsp/dsp_connector.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 92b60b69d9..7b4966409e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -252,13 +252,17 @@ if(ENABLE_MPI) add_compile_definitions(__MPI) list(APPEND math_libs MPI::MPI_CXX) endif() +target_link_libraries(${ABACUS_BIN_NAME} ${SCALAPACK_LIBRARY_DIR}) + if (USE_DSP) - target_link_libraries(${ABACUS_BIN_NAME} ${DIR_MTBLAS_LIBRARY}) add_compile_definitions(__DSP) - target_include_directories(${ABACUS_BIN_NAME} ${DIR_MTFFT_INCLUDES}) - target_link_libraries(${ABACUS_BIN_NAME} ${DIR_MTFFT_LIBRARY}) - target_include_directories(${ABACUS_BIN_NAME} ${DIR_HTRERAD_INLCUDES}) + target_link_libraries(${ABACUS_BIN_NAME} ${MTBLAS_FFT_LIBRARY_DIR}) + target_link_libraries(${ABACUS_BIN_NAME} ${OMPI_LIBRARY1}) + include_directories(${MTBLAS_FFT_LIBRARY_DIR}/include) + include_directories(${MT_HOST_DIR}/include) + target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_device.a) + target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_host.a) endif() find_package(Threads REQUIRED) @@ -429,10 +433,10 @@ else() find_package(Lapack REQUIRED) include_directories(${FFTW3_INCLUDE_DIRS}) list(APPEND math_libs FFTW3::FFTW3 LAPACK::LAPACK BLAS::BLAS) - + if (ENBALE_LCAO) find_package(ScaLAPACK REQUIRED) list(APPEND math_libs ScaLAPACK::ScaLAPACK) - + endif() if(USE_OPENMP) list(APPEND math_libs FFTW3::FFTW3_OMP) endif() diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index 769138b096..5acce9103f 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt @@ -104,6 +104,13 @@ if(USE_ROCM) ) endif() +if(USE_DSP) + list(APPEND device_srcs + module_base/kernels/dsp/dsp_connector.cpp + ) +endif() + + add_library(device OBJECT ${device_srcs}) if(USE_CUDA) diff --git a/source/module_base/CMakeLists.txt b/source/module_base/CMakeLists.txt index ecbdedcf6a..9f26791b25 100644 --- a/source/module_base/CMakeLists.txt +++ b/source/module_base/CMakeLists.txt @@ -65,7 +65,10 @@ add_library( ) target_link_libraries(base PUBLIC container) - +if (USE_DSP) + target_link_libraries(base PUBLIC ${MTBLAS_FFT_LIBRARY_DIR}/lib/libmtblas.a) + target_link_libraries(base PUBLIC ${MTBLAS_FFT_LIBRARY_DIR}/lib/libmtblasdev.a) +endif() add_subdirectory(module_container) if(ENABLE_COVERAGE) diff --git a/source/module_base/kernels/dsp/dsp_connector.cpp b/source/module_base/kernels/dsp/dsp_connector.cpp new file mode 100644 index 0000000000..5b62847ac6 --- /dev/null +++ b/source/module_base/kernels/dsp/dsp_connector.cpp @@ -0,0 +1,201 @@ +#include "dsp_connector.h" +#include +#include + +extern "C" +{ + #define complex_double ignore_complex_double + #include // MTBLAS_TRANSPOSE etc + #undef complex_double + #include // gemm +} + +void dspInitHandle(int id){ + mt_blas_init(id); + std::cout << " ** DSP inited on cluster "<< id << " **" << std::endl; +} // Use this at the beginning of the program to start a dsp cluster + +void dspDestoryHandle(int id){ + hthread_dev_close(id); + std::cout << " ** DSP closed on cluster "<< id << " **" << std::endl; +} // Close dsp cluster at the end + + +MTBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans) { + switch (blasTrans[0]) { + case 'N': + case 'n': + return MtblasNoTrans; + case 'T': + case 't': + return MtblasTrans; + case 'C': + case 'c': + return MtblasConjTrans; + default: + std::cout << "Invalid BLAS transpose parameter!! Use default instead." << std::endl; + return MtblasNoTrans; + } +} // Used to convert normal transpost char to mtblas transpose flag + + +void* malloc_ht(size_t bytes, int cluster_id) +{ + //std::cout << "MALLOC " << cluster_id; + void* ptr = hthread_malloc((int)cluster_id, bytes, HT_MEM_RW); + //std::cout << ptr << " SUCCEED" << std::endl;; + return ptr; +} + +// Used to replace original malloc + +void free_ht(void* ptr) +{ + //std::cout << "FREE " << ptr; + hthread_free(ptr); + //std::cout << " FREE SUCCEED" << std::endl; +} + +// Used to replace original free + +void sgemm_mt_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const float *alpha, const float *a, const int *lda, + const float *b, const int *ldb, const float *beta, + float *c, const int *ldc, int cluster_id) +{ + mtblas_sgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa),convertBLASTranspose(transb), + *m,*n,*k, + *alpha, a, *lda, + b, *ldb, *beta, + c, *ldc, cluster_id + ); +} // zgemm that needn't malloc_ht or free_ht + +void dgemm_mt_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const double *alpha, const double *a, const int *lda, + const double *b, const int *ldb, const double *beta, + double *c, const int *ldc, int cluster_id) +{ + mtblas_dgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa),convertBLASTranspose(transb), + *m,*n,*k, + *alpha, a, *lda, + b, *ldb, *beta, + c, *ldc, cluster_id + ); +} // cgemm that needn't malloc_ht or free_ht + +void zgemm_mt_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const std::complex *alpha, const std::complex *a, const int *lda, + const std::complex *b, const int *ldb, const std::complex *beta, + std::complex *c, const int *ldc, int cluster_id) +{ + mtblas_zgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa),convertBLASTranspose(transb), + *m,*n,*k, + (const void*)alpha, (const void*)a, *lda, + (const void*)b, *ldb, (const void*)beta, + (void*)c, *ldc, cluster_id + ); +} // zgemm that needn't malloc_ht or free_ht + +void cgemm_mt_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const std::complex *alpha, const std::complex *a, const int *lda, + const std::complex *b, const int *ldb, const std::complex *beta, + std::complex *c, const int *ldc, int cluster_id) +{ + mtblas_cgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa),convertBLASTranspose(transb), + *m,*n,*k, + (const void*)alpha, (const void*)a, *lda, + (const void*)b, *ldb, (const void*)beta, + (void*)c, *ldc, cluster_id + ); +} // cgemm that needn't malloc_ht or free_ht + +// Used to replace original free + +void sgemm_mth_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const float *alpha, const float *a, const int *lda, + const float *b, const int *ldb, const float *beta, + float *c, const int *ldc, int cluster_id) +{ + mt_hthread_sgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa),convertBLASTranspose(transb), + *m,*n,*k, + *alpha, a, *lda, + b, *ldb, *beta, + c, *ldc, cluster_id + ); +} // zgemm that needn't malloc_ht or free_ht + +void dgemm_mth_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const double *alpha, const double *a, const int *lda, + const double *b, const int *ldb, const double *beta, + double *c, const int *ldc, int cluster_id) +{ + mt_hthread_dgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa),convertBLASTranspose(transb), + *m,*n,*k, + *alpha, a, *lda, + b, *ldb, *beta, + c, *ldc, cluster_id + ); +} // cgemm that needn't malloc_ht or free_ht + +void zgemm_mth_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const std::complex *alpha, + const std::complex *a, + const int *lda, + const std::complex *b, + const int *ldb, + const std::complex *beta, + std::complex *c, + const int *ldc, + int cluster_id) +{ + std::complex* alp = (std::complex*) malloc_ht(sizeof(std::complex), cluster_id); + *alp = *alpha; + std::complex* bet = (std::complex*) malloc_ht(sizeof(std::complex), cluster_id); + *bet = *beta; + mt_hthread_zgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa),convertBLASTranspose(transb), + *m,*n,*k, + alp, a, *lda, + b, *ldb, bet, + c, *ldc, cluster_id + ); + + +} // zgemm that needn't malloc_ht or free_ht + +void cgemm_mth_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const std::complex *alpha, const std::complex *a, const int *lda, + const std::complex *b, const int *ldb, const std::complex *beta, + std::complex *c, const int *ldc, int cluster_id) +{ + std::complex* alp = (std::complex*) malloc_ht(sizeof(std::complex), cluster_id); + *alp = *alpha; + std::complex* bet = (std::complex*) malloc_ht(sizeof(std::complex), cluster_id); + *bet = *beta; + + mt_hthread_cgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa),convertBLASTranspose(transb), + *m,*n,*k, + (const void*)alp, (const void*)a, *lda, + (const void*)b, *ldb, (const void*)bet, + (void*)c, *ldc, cluster_id + ); + + free_ht(alp); + free_ht(bet); +} // cgemm that needn't malloc_ht or free_ht \ No newline at end of file diff --git a/source/module_basis/module_pw/CMakeLists.txt b/source/module_basis/module_pw/CMakeLists.txt index e8541c5150..54bd56c3d9 100644 --- a/source/module_basis/module_pw/CMakeLists.txt +++ b/source/module_basis/module_pw/CMakeLists.txt @@ -16,6 +16,7 @@ endif() if (USE_DSP) list (APPEND FFT_SRC module_fft/fft_dsp.cpp + module_fft/fft_dsp_float.cpp pw_transform_k_dsp.cpp) endif() @@ -41,6 +42,10 @@ add_library( ${objects} ) +if (USE_DSP) +target_link_libraries(planewave PRIVATE +${MTBLAS_FFT_LIBRARY_DIR}/lib/libmtfft.a) +endif() if(ENABLE_COVERAGE) add_coverage(planewave) endif() diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index 7c064ef5eb..a35005e48f 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -3,6 +3,7 @@ #include "module_base/module_device/device.h" #include "module_base/module_device/memory_op.h" +#include "module_base/tool_quit.h" #if defined(__CUDA) #include "fft_cuda.h" #endif @@ -42,7 +43,7 @@ void FFT_Bundle::initfft(int nx_in, bool xprime_in , bool mpifft_in) { - assert(this->device=="cpu" || this->device=="gpu"); + assert(this->device=="cpu" || this->device=="gpu" || this->device=="dsp"); assert(this->precision=="single" || this->precision=="double" || this->precision=="mixing"); if (this->precision=="single") @@ -64,15 +65,17 @@ void FFT_Bundle::initfft(int nx_in, { double_flag = true; } - + #if defined(__DSP) + if (device=="dsp") + { + if (float_flag) + ModuleBase::WARNING_QUIT("device","now dsp fft is not support for the float type"); + fft_double=make_unique>(); + fft_double->initfft(nx_in,ny_in,nz_in); + } + #endif if (device=="cpu") { - #if defined(__DSP) - if (float_flag==true) - ModuleBase::WARNING_QUT("device","now dsp fft is not support for the float type"); - fft_double=make_unique>(); - fft_double->initfft(nx_in,ny_in,nz_in); - #else fft_float = make_unique>(this->fft_mode); fft_double = make_unique>(this->fft_mode); if (float_flag) @@ -101,7 +104,6 @@ void FFT_Bundle::initfft(int nx_in, gamma_only_in, xprime_in); } - #endif } if (device=="gpu") { diff --git a/source/module_basis/module_pw/module_fft/fft_dsp.cpp b/source/module_basis/module_pw/module_fft/fft_dsp.cpp index 9aaf40c2f6..f705854372 100644 --- a/source/module_basis/module_pw/module_fft/fft_dsp.cpp +++ b/source/module_basis/module_pw/module_fft/fft_dsp.cpp @@ -1,8 +1,8 @@ #include "fft_dsp.h" -#include "module_base/global_variable.h" #include #include #include +#include "module_base/global_variable.h" namespace ModulePW { template<> @@ -12,6 +12,7 @@ void FFT_DSP::initfft(int nx_in,int ny_in,int nz_in) this->ny=ny_in; this->nz=nz_in; cluster_id = GlobalV::MY_RANK; + nxyz=this->nx*this->ny*this->nz; } template<> void FFT_DSP::setupFFT() @@ -22,21 +23,14 @@ void FFT_DSP::setupFFT() PLAN* ptr_plan_backward; INT num_thread=8; INT size; - //open cluster id - hthread_dev_open(cluster_id); - //load mt.dat - hthread_dat_load(cluster_id, "mt_fft_device.dat"); - - thread_id_for = hthread_group_create(cluster_id, num_thread, NULL, 0, 0, NULL); - //create b_id for the barrier - b_id = hthread_barrier_create(cluster_id); - args_for[0] = b_id; + hthread_dat_load(cluster_id, "/vol8/home/dptech_zyz1/develop/blasfft/mtfftblas/datfile/mt_fft_blas.dat"); + //compute the size of and malloc thread size = nx*ny*nz*2*sizeof(E); forward_in = (E*)hthread_malloc((int)cluster_id, size, HT_MEM_RW); - - // //init 3d fft problem + +// // //init 3d fft problem pbm_forward.num_dim = 3; pbm_forward.n[0] = nx; pbm_forward.n[1] = ny; @@ -45,25 +39,24 @@ void FFT_DSP::setupFFT() pbm_forward.in = forward_in; pbm_forward.out = forward_in; - // //make ptr plan +// // //make ptr plan make_plan(&pbm_forward, &ptr_plan_forward, cluster_id, num_thread); ptr_plan_forward->in = forward_in; ptr_plan_forward->out = forward_in; args_for[1] = (unsigned long)ptr_plan_forward; //init 3d fft problem - pbm_backward.num_dim = 3; // dimensions of FFT - pbm_backward.n[0] = nx; // first dimension - pbm_backward.n[1] = ny; // second dimension - pbm_backward.n[2] = nz; // third dimension - pbm_backward.iFFT = 1; // 0 stand for forward,1 stand for backward - pbm_backward.in = forward_in; // the input data - pbm_backward.out = forward_in; // the output data + pbm_backward.num_dim = 3; + pbm_backward.n[0] = nx; + pbm_backward.n[1] = ny; + pbm_backward.n[2] = nz; + pbm_backward.iFFT = 1; + pbm_backward.in = forward_in; + pbm_backward.out = forward_in; make_plan(&pbm_backward, &ptr_plan_backward, cluster_id, num_thread); ptr_plan_backward->in = forward_in; ptr_plan_backward->out = forward_in; - args_back[0]=b_id; args_back[1]=(unsigned long)ptr_plan_backward; } @@ -71,16 +64,30 @@ template<> void FFT_DSP::fft3D_forward(std::complex* in, std::complex* out) const { + INT num_thread=8; + thread_id_for = hthread_group_create(cluster_id, num_thread, NULL, 0, 0, NULL); + //create b_id for the barrier + b_id = hthread_barrier_create(cluster_id); + args_for[0] = b_id; hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_for); hthread_group_wait(thread_id_for); + hthread_barrier_destroy(b_id); + hthread_group_destroy(thread_id_for); } template<> void FFT_DSP::fft3D_backward(std::complex * in, std::complex* out) const { + INT num_thread=8; + thread_id_for = hthread_group_create(cluster_id, num_thread, NULL, 0, 0, NULL); + //create b_id for the barrier + b_id = hthread_barrier_create(cluster_id); + args_back[0] =b_id; hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_back); hthread_group_wait(thread_id_for); + hthread_barrier_destroy(b_id); + hthread_group_destroy(thread_id_for); } template<> @@ -103,8 +110,6 @@ void FFT_DSP::clear() { this->cleanFFT(); hthread_free(forward_in); - hthread_barrier_destroy(b_id); - hthread_group_destroy(thread_id_for); } template<> std::complex* @@ -116,5 +121,4 @@ template FFT_DSP::FFT_DSP(); template FFT_DSP::~FFT_DSP(); template FFT_DSP::FFT_DSP(); template FFT_DSP::~FFT_DSP(); - -} +} \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_dsp.h b/source/module_basis/module_pw/module_fft/fft_dsp.h index 9488c083ce..6cadef975f 100644 --- a/source/module_basis/module_pw/module_fft/fft_dsp.h +++ b/source/module_basis/module_pw/module_fft/fft_dsp.h @@ -1,5 +1,5 @@ -#ifndef FFT_CUDA_H -#define FFT_CUDA_H +#ifndef FFT_DSP_H +#define FFT_DSP_H #include "fft_base.h" #include @@ -65,18 +65,19 @@ class FFT_DSP : public FFT_BASE void fft3D_backward(std::complex* in, std::complex* out) const override; public: + int nxyz; INT cluster_id=0; - INT b_id; - INT thread_id_for=0; + mutable INT b_id; + mutable INT thread_id_for=0; PLAN* ptr_plan_forward=nullptr; PLAN* ptr_plan_backward=nullptr; mutable unsigned long args_for[2]; mutable unsigned long args_back[2]; mutable E * forward_in; + mutable E * convert2; std::complex* c_auxr_3d = nullptr; // fft space std::complex* z_auxr_3d = nullptr; // fft space }; -void test_fft_dsp(); } // namespace ModulePW #endif \ No newline at end of file diff --git a/source/module_basis/module_pw/pw_basis_k.cpp b/source/module_basis/module_pw/pw_basis_k.cpp index 5f80191251..9bb33ba8f8 100644 --- a/source/module_basis/module_pw/pw_basis_k.cpp +++ b/source/module_basis/module_pw/pw_basis_k.cpp @@ -147,8 +147,8 @@ void PW_Basis_K::setupIndGk() //get igl2isz_k and igl2ig_k - if(this->npwk_max <= 0) { return; -} + if(this->npwk_max <= 0) { return;} + delete[] igl2isz_k; this->igl2isz_k = new int [this->nks * this->npwk_max]; delete[] igl2ig_k; this->igl2ig_k = new int [this->nks * this->npwk_max]; for (int ik = 0; ik < this->nks; ik++) @@ -188,7 +188,7 @@ void PW_Basis_K::setuptransform() this->getstartgr(); this->setupIndGk(); this->fft_bundle.clear(); - this->fft_bundle.setfft(this->device,this->precision); + this->fft_bundle.setfft("dsp",this->precision); if(this->xprime){ this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); }else{ diff --git a/source/module_basis/module_pw/pw_basis_k.h b/source/module_basis/module_pw/pw_basis_k.h index 8f03cc9ce4..478be2e7ae 100644 --- a/source/module_basis/module_pw/pw_basis_k.h +++ b/source/module_basis/module_pw/pw_basis_k.h @@ -136,18 +136,29 @@ class PW_Basis_K : public PW_Basis const bool add = false, const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) #if defined(__DSP) + template + void convolution(const Device* ctx, + const int ik, + const int size, + const std::complex* input, + const FPTYPE* input1, + std::complex* output, + const bool add = false, + const FPTYPE factor =1.0) const ; + template - void real2recip_3d(const std::complex* in, - std::complex* out, - const int ik, - const bool add = false, - const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns) + void real2recip_dsp(const std::complex* in, + std::complex* out, + const int ik, + const bool add = false, + const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns) template - void recip2real_3d(const std::complex* in, - std::complex* out, - const int ik, - const bool add = false, - const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) + void recip2real_dsp(const std::complex* in, + std::complex* out, + const int ik, + const bool add = false, + const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) + #endif template diff --git a/source/module_basis/module_pw/pw_transform_k.cpp b/source/module_basis/module_pw/pw_transform_k.cpp index 75f821132e..1b8b6e3f76 100644 --- a/source/module_basis/module_pw/pw_transform_k.cpp +++ b/source/module_basis/module_pw/pw_transform_k.cpp @@ -308,9 +308,10 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_CPU* /*dev*/, const double factor) const { #if defined(__DSP) - this->real2recip_3d(in,out,ik,add,factor); + printf("beforce the real_to_recip\n"); + this->real2recip_dsp(in,out,ik,add,factor); #else - this->real2recip(in, out, ik, add, factor); + this->real2recip(in, out, ik, add, factor); #endif } @@ -322,11 +323,7 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/, const bool add, const float factor) const { - #if defined(__DSP) - this->recip2real_3d(in,out,add,factor); - #else this->recip2real(in, out, ik, add, factor); - #endif } template <> void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/, @@ -336,7 +333,12 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/, const bool add, const double factor) const { - this->recip2real(in, out, ik, add, factor); + #if defined(__DSP) + printf("beforce the recip_to_real\n"); + this->recip2real_dsp(in,out,ik,add,factor); + #else + this->recip2real(in, out, ik, add, factor); + #endif } #if (defined(__CUDA) || defined(__ROCM)) diff --git a/source/module_basis/module_pw/pw_transform_k_dsp.cpp b/source/module_basis/module_pw/pw_transform_k_dsp.cpp index 849a4b90af..f6de9d17a2 100644 --- a/source/module_basis/module_pw/pw_transform_k_dsp.cpp +++ b/source/module_basis/module_pw/pw_transform_k_dsp.cpp @@ -1,82 +1,165 @@ #include "module_base/timer.h" #include "module_basis/module_pw/kernels/pw_op.h" #include "pw_basis_k.h" - +#include "pw_gatherscatter.h" #include #include -#include + namespace ModulePW { template - void PW_Basis_K::real2recip_3d(const std::complex* in, - std::complex* out, - const int ik, - const bool add, - const FPTYPE factor) const + void PW_Basis_K::real2recip_dsp(const std::complex* in, + std::complex* out, + const int ik, + const bool add , + const FPTYPE factor ) const { - ModuleBase::timer::tick(this->classname,"real2recip_3d"); const base_device::DEVICE_CPU* ctx; const base_device::DEVICE_GPU* gpux; - assert(this->gamma_only == false); + assert(this->gamma_only==false); auto* auxr = this->fft_bundle.get_auxr_3d_data(); const int startig = ik * this->npwk_max; const int npw_k = this->npwk[ik]; + // copy the in into the auxr with complex memcpy(auxr,in,this->nrxx*2*8); - this->fft_bundle.fft3D_forward(gpux, - auxr, - auxr); - set_real_to_recip_output_op()(ctx, - npw_k, - this->nxyz, - add, - factor, - this->ig2ixyz_k_cpu + startig, - this->fft_bundle.get_auxr_3d_data(), - out); - ModuleBase::timer::tick(this->classname,"real2recip_3d"); + + // 3d fft + this->fft_bundle.fft3D_forward(gpux, + auxr, + auxr); + + // copy the result from the auxr to the out ,while consider the add + set_real_to_recip_output_op()(ctx, + npw_k, + this->nxyz, + add, + factor, + this->ig2ixyz_k_cpu + startig, + auxr, + out); } + template + void PW_Basis_K::recip2real_dsp(const std::complex* in, + std::complex* out, + const int ik, + const bool add , + const FPTYPE factor ) const + { + assert(this->gamma_only == false); + const base_device::DEVICE_CPU* ctx; + const base_device::DEVICE_GPU* gpux; + printf("beforce the recip2real_dsp\n"); + // memset the auxr of 0 in the auxr,here the len of the auxr is nxyz + auto * auxr = this->fft_bundle.get_auxr_3d_data(); + memset(auxr,0,this->nxyz*2*8); - template - void PW_Basis_K::recip2real_3d(const std::complex* in, - std::complex* out, - const int ik, - const bool add, - const FPTYPE factor) const + const int startig = ik * this->npwk_max; + const int npw_k = this->npwk[ik]; + printf("beforce the set_3d_fft_box_op\n"); + //copy the mapping form the type of stick to the 3dfft + set_3d_fft_box_op() + ( + ctx,npw_k,this->ig2ixyz_k_cpu+startig,in,auxr + ); + printf("beforce the fft3D_backward\n"); + // use 3d fft backward + this->fft_bundle.fft3D_backward(gpux,auxr,auxr); + printf("beforce the add\n"); + if(add) + { + const int one =1; + const std::complex factor1=std::complex(factor,0); + zaxpy_(&nrxx,&factor1,auxr,&one,out,&one); + } + else + { + memcpy(out,auxr,nrxx*2*8); + } + printf("after the add\n"); + } + template <> + void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, + const int ik, + const int size, + const std::complex* input, + const float* input1, + std::complex* output, + const bool add , + const float factor ) const + { + + } + + template <> + void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, + const int ik, + const int size, + const std::complex* input, + const double* input1, + std::complex* output, + const bool add , + const double factor ) const { - ModuleBase::timer::tick(this->classname,"recip2real_3d"); + ModuleBase::timer::tick(this->classname,"convolution"); assert(this->gamma_only == false); - const base_device::DEVICE_CPU* ctx; const base_device::DEVICE_GPU* gpux; - auto* auxr = this->fft_bundle.get_auxr_3d_data(); - memset(auxr,0,this->nrxx*2*8); + // memset the auxr of 0 in the auxr,here the len of the auxr is nxyz + auto * auxr = this->fft_bundle.get_auxr_3d_data(); + memset(auxr,0,this->nxyz*2*8); const int startig = ik * this->npwk_max; const int npw_k = this->npwk[ik]; + + //copy the mapping form the type of stick to the 3dfft + set_3d_fft_box_op() + ( + ctx,npw_k,this->ig2ixyz_k_cpu+startig,input,auxr + ); - set_3d_fft_box_op()(ctx, - npw_k, - this->ig2ixyz_k_cpu + startig, - in, - auxr); + // use 3d fft backward this->fft_bundle.fft3D_backward(gpux,auxr,auxr); - set_recip_to_real_output_op()(ctx, - this->nrxx, - add, - factor, - auxr, - out); - ModuleBase::timer::tick(this->classname,"recip2real_3d"); + + for (int ir=0;irfft_bundle.fft3D_forward(gpux, + auxr, + auxr); + // copy the result from the auxr to the out ,while consider the add + set_real_to_recip_output_op()(ctx, + npw_k, + this->nxyz, + add, + factor, + this->ig2ixyz_k_cpu + startig, + auxr, + output); + ModuleBase::timer::tick(this->classname,"convolution"); } + +// template void PW_Basis_K::real2recip_dsp(const std::complex* in, +// std::complex* out, +// const int ik, +// const bool add, +// const float factor) const; // in:(nplane,nx*ny) ; out(nz, ns) +// template void PW_Basis_K::recip2real_dsp(const std::complex* in, +// std::complex* out, +// const int ik, +// const bool add, +// const float factor) const; // in:(nz, ns) ; out(nplane,nx*ny) -template void PW_Basis_K::real2recip_3d(const std::complex* in, +template void PW_Basis_K::real2recip_dsp(const std::complex* in, std::complex* out, const int ik, const bool add, const double factor) const; // in:(nplane,nx*ny) ; out(nz, ns) -template void PW_Basis_K::recip2real_3d(const std::complex* in, +template void PW_Basis_K::recip2real_dsp(const std::complex* in, std::complex* out, const int ik, const bool add, - const double factor) const; // in:(nz, ns) ; out(nplane,nx*ny) -} \ No newline at end of file + const double factor) const; +} From 88c25d7e86481236279e744d8ae8ee47d47f8c61 Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Mon, 24 Feb 2025 11:47:38 +0800 Subject: [PATCH 07/14] change teh cmake file --- CMakeLists.txt | 3 +- source/module_base/CMakeLists.txt | 4 +-- source/module_basis/module_pw/CMakeLists.txt | 4 ++- .../module_pw/module_fft/fft_dsp.cpp | 3 +- source/module_basis/module_pw/pw_basis.cpp | 2 +- source/module_basis/module_pw/pw_basis_k.cpp | 28 +++++++++---------- source/module_basis/module_pw/pw_basis_k.h | 2 +- .../module_basis/module_pw/pw_transform_k.cpp | 2 -- .../module_pw/pw_transform_k_dsp.cpp | 13 +++------ 9 files changed, 26 insertions(+), 35 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7b4966409e..c7f468687a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -257,9 +257,8 @@ target_link_libraries(${ABACUS_BIN_NAME} ${SCALAPACK_LIBRARY_DIR}) if (USE_DSP) add_compile_definitions(__DSP) - target_link_libraries(${ABACUS_BIN_NAME} ${MTBLAS_FFT_LIBRARY_DIR}) target_link_libraries(${ABACUS_BIN_NAME} ${OMPI_LIBRARY1}) - include_directories(${MTBLAS_FFT_LIBRARY_DIR}/include) + include_directories(${MTBLAS_FFT_DIR}/libmtblas/include) include_directories(${MT_HOST_DIR}/include) target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_device.a) target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_host.a) diff --git a/source/module_base/CMakeLists.txt b/source/module_base/CMakeLists.txt index 9f26791b25..e6b016b311 100644 --- a/source/module_base/CMakeLists.txt +++ b/source/module_base/CMakeLists.txt @@ -66,8 +66,8 @@ add_library( target_link_libraries(base PUBLIC container) if (USE_DSP) - target_link_libraries(base PUBLIC ${MTBLAS_FFT_LIBRARY_DIR}/lib/libmtblas.a) - target_link_libraries(base PUBLIC ${MTBLAS_FFT_LIBRARY_DIR}/lib/libmtblasdev.a) + target_link_libraries(base PUBLIC ${MTBLAS_FFT_DIR}/libmtblas/lib/libmtblas.a) + target_link_libraries(base PUBLIC ${MTBLAS_FFT_DIR}/libmtblas/lib/libmtblasdev.a) endif() add_subdirectory(module_container) diff --git a/source/module_basis/module_pw/CMakeLists.txt b/source/module_basis/module_pw/CMakeLists.txt index 54bd56c3d9..e365e12b5e 100644 --- a/source/module_basis/module_pw/CMakeLists.txt +++ b/source/module_basis/module_pw/CMakeLists.txt @@ -44,7 +44,9 @@ add_library( if (USE_DSP) target_link_libraries(planewave PRIVATE -${MTBLAS_FFT_LIBRARY_DIR}/lib/libmtfft.a) +${MTBLAS_FFT_DIR}/libmtblas/lib/libmtfft.a) +target_compile_definitions( planewave PUBLIC +FFT_DAT_DIR="${MTBLAS_FFT_DIR}/datfile/mt_fft_blas.dat") endif() if(ENABLE_COVERAGE) add_coverage(planewave) diff --git a/source/module_basis/module_pw/module_fft/fft_dsp.cpp b/source/module_basis/module_pw/module_fft/fft_dsp.cpp index f705854372..4ff838cd52 100644 --- a/source/module_basis/module_pw/module_fft/fft_dsp.cpp +++ b/source/module_basis/module_pw/module_fft/fft_dsp.cpp @@ -23,8 +23,7 @@ void FFT_DSP::setupFFT() PLAN* ptr_plan_backward; INT num_thread=8; INT size; - - hthread_dat_load(cluster_id, "/vol8/home/dptech_zyz1/develop/blasfft/mtfftblas/datfile/mt_fft_blas.dat"); + hthread_dat_load(cluster_id, FFT_DAT_DIR); //compute the size of and malloc thread size = nx*ny*nz*2*sizeof(E); diff --git a/source/module_basis/module_pw/pw_basis.cpp b/source/module_basis/module_pw/pw_basis.cpp index 5fbff68f0c..cc9eb37771 100644 --- a/source/module_basis/module_pw/pw_basis.cpp +++ b/source/module_basis/module_pw/pw_basis.cpp @@ -15,7 +15,7 @@ PW_Basis::PW_Basis() PW_Basis::PW_Basis(std::string device_, std::string precision_) : device(std::move(device_)), precision(std::move(precision_)) { classname="PW_Basis"; - this->fft_bundle.setfft("cpu",this->precision); + this->fft_bundle.setfft(this->device,this->precision); } PW_Basis:: ~PW_Basis() diff --git a/source/module_basis/module_pw/pw_basis_k.cpp b/source/module_basis/module_pw/pw_basis_k.cpp index 9bb33ba8f8..b005691c1e 100644 --- a/source/module_basis/module_pw/pw_basis_k.cpp +++ b/source/module_basis/module_pw/pw_basis_k.cpp @@ -22,9 +22,6 @@ PW_Basis_K::~PW_Basis_K() delete[] igl2isz_k; delete[] igl2ig_k; delete[] gk2; -#if defined(__DSP) - delete[] ig2ixyz_k_cpu; -#endif #if defined(__CUDA) || defined(__ROCM) if (this->device == "gpu") { if (this->precision == "single") { @@ -148,7 +145,7 @@ void PW_Basis_K::setupIndGk() //get igl2isz_k and igl2ig_k if(this->npwk_max <= 0) { return;} - + delete[] igl2isz_k; this->igl2isz_k = new int [this->nks * this->npwk_max]; delete[] igl2ig_k; this->igl2ig_k = new int [this->nks * this->npwk_max]; for (int ik = 0; ik < this->nks; ik++) @@ -188,7 +185,11 @@ void PW_Basis_K::setuptransform() this->getstartgr(); this->setupIndGk(); this->fft_bundle.clear(); - this->fft_bundle.setfft("dsp",this->precision); + #if defined(__DSP) + this->fft_bundle.setfft("dsp",this->precision); + #else + this->fft_bundle.setfft("cpu",this->precision); + #endif if(this->xprime){ this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); }else{ @@ -337,12 +338,12 @@ int& PW_Basis_K::getigl2ig(const int ik, const int igl) const void PW_Basis_K::get_ig2ixyz_k() { - if (this->device != "gpu") - { - //only GPU need to get ig2ixyz_k - return; - } - int * ig2ixyz_k_cpu = new int [this->npwk_max * this->nks]; + // if (this->device != "gpu") + // { + // //only GPU need to get ig2ixyz_k + // return; + // } + ig2ixyz_k_cpu.resize(this->npwk_max * this->nks); ModuleBase::Memory::record("PW_B_K::ig2ixyz", sizeof(int) * this->npwk_max * this->nks); assert(gamma_only == false); //We only finish non-gamma_only fft on GPU temperarily. for(int ik = 0; ik < this->nks; ++ik) @@ -359,10 +360,7 @@ void PW_Basis_K::get_ig2ixyz_k() } } resmem_int_op()(ig2ixyz_k, this->npwk_max * this->nks); - syncmem_int_h2d_op()(this->ig2ixyz_k, ig2ixyz_k_cpu, this->npwk_max * this->nks); - #if not defined (__DSP) - delete[] this->ig2ixyz_k_cpu; - #endif + syncmem_int_h2d_op()(this->ig2ixyz_k, ig2ixyz_k_cpu.data(), this->npwk_max * this->nks); } std::vector PW_Basis_K::get_ig2ix(const int ik) const diff --git a/source/module_basis/module_pw/pw_basis_k.h b/source/module_basis/module_pw/pw_basis_k.h index 478be2e7ae..ae5076bba9 100644 --- a/source/module_basis/module_pw/pw_basis_k.h +++ b/source/module_basis/module_pw/pw_basis_k.h @@ -87,7 +87,7 @@ class PW_Basis_K : public PW_Basis int *igl2isz_k=nullptr, * d_igl2isz_k = nullptr; //[npwk_max*nks] map (igl,ik) to (is,iz) int *igl2ig_k=nullptr;//[npwk_max*nks] map (igl,ik) to ig int *ig2ixyz_k=nullptr; ///< [npw] map ig to ixyz - int *ig2ixyz_k_cpu = nullptr; /// [npw] map ig to ixyz,which is used in dsp fft. + std::vector ig2ixyz_k_cpu; /// [npw] map ig to ixyz,which is used in dsp fft. double *gk2=nullptr; // modulus (G+K)^2 of G vectors [npwk_max*nks] // liuyu add 2023-09-06 diff --git a/source/module_basis/module_pw/pw_transform_k.cpp b/source/module_basis/module_pw/pw_transform_k.cpp index 1b8b6e3f76..3d75f07f6f 100644 --- a/source/module_basis/module_pw/pw_transform_k.cpp +++ b/source/module_basis/module_pw/pw_transform_k.cpp @@ -308,7 +308,6 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_CPU* /*dev*/, const double factor) const { #if defined(__DSP) - printf("beforce the real_to_recip\n"); this->real2recip_dsp(in,out,ik,add,factor); #else this->real2recip(in, out, ik, add, factor); @@ -334,7 +333,6 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/, const double factor) const { #if defined(__DSP) - printf("beforce the recip_to_real\n"); this->recip2real_dsp(in,out,ik,add,factor); #else this->recip2real(in, out, ik, add, factor); diff --git a/source/module_basis/module_pw/pw_transform_k_dsp.cpp b/source/module_basis/module_pw/pw_transform_k_dsp.cpp index f6de9d17a2..59485d599a 100644 --- a/source/module_basis/module_pw/pw_transform_k_dsp.cpp +++ b/source/module_basis/module_pw/pw_transform_k_dsp.cpp @@ -35,7 +35,7 @@ namespace ModulePW this->nxyz, add, factor, - this->ig2ixyz_k_cpu + startig, + this->ig2ixyz_k_cpu.data() + startig, auxr, out); } @@ -49,23 +49,19 @@ namespace ModulePW assert(this->gamma_only == false); const base_device::DEVICE_CPU* ctx; const base_device::DEVICE_GPU* gpux; - printf("beforce the recip2real_dsp\n"); // memset the auxr of 0 in the auxr,here the len of the auxr is nxyz auto * auxr = this->fft_bundle.get_auxr_3d_data(); memset(auxr,0,this->nxyz*2*8); const int startig = ik * this->npwk_max; const int npw_k = this->npwk[ik]; - printf("beforce the set_3d_fft_box_op\n"); //copy the mapping form the type of stick to the 3dfft set_3d_fft_box_op() ( - ctx,npw_k,this->ig2ixyz_k_cpu+startig,in,auxr + ctx,npw_k,this->ig2ixyz_k_cpu.data()+startig,in,auxr ); - printf("beforce the fft3D_backward\n"); // use 3d fft backward this->fft_bundle.fft3D_backward(gpux,auxr,auxr); - printf("beforce the add\n"); if(add) { const int one =1; @@ -76,7 +72,6 @@ namespace ModulePW { memcpy(out,auxr,nrxx*2*8); } - printf("after the add\n"); } template <> void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, @@ -114,7 +109,7 @@ namespace ModulePW //copy the mapping form the type of stick to the 3dfft set_3d_fft_box_op() ( - ctx,npw_k,this->ig2ixyz_k_cpu+startig,input,auxr + ctx,npw_k,this->ig2ixyz_k_cpu.data()+startig,input,auxr ); // use 3d fft backward @@ -135,7 +130,7 @@ namespace ModulePW this->nxyz, add, factor, - this->ig2ixyz_k_cpu + startig, + this->ig2ixyz_k_cpu.data() + startig, auxr, output); ModuleBase::timer::tick(this->classname,"convolution"); From 18a64b6c17502274e4957f86bc490f13c95d0639 Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Mon, 24 Feb 2025 11:51:56 +0800 Subject: [PATCH 08/14] modify back scalapck --- CMakeLists.txt | 4 ---- 1 file changed, 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 82e9a9ded9..fb02f66809 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -255,8 +255,6 @@ if(ENABLE_MPI) add_compile_definitions(__MPI) list(APPEND math_libs MPI::MPI_CXX) endif() -target_link_libraries(${ABACUS_BIN_NAME} ${SCALAPACK_LIBRARY_DIR}) - if (USE_DSP) add_compile_definitions(__DSP) @@ -435,10 +433,8 @@ else() find_package(Lapack REQUIRED) include_directories(${FFTW3_INCLUDE_DIRS}) list(APPEND math_libs FFTW3::FFTW3 LAPACK::LAPACK BLAS::BLAS) - if (ENBALE_LCAO) find_package(ScaLAPACK REQUIRED) list(APPEND math_libs ScaLAPACK::ScaLAPACK) - endif() if(USE_OPENMP) list(APPEND math_libs FFTW3::FFTW3_OMP) endif() From 6e151a2adc22c13360a588e469149fa8dfc1453c Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Mon, 24 Feb 2025 11:57:14 +0800 Subject: [PATCH 09/14] set the dsp ig2ixyz_k_cpu --- source/module_basis/module_pw/pw_basis_k.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/source/module_basis/module_pw/pw_basis_k.cpp b/source/module_basis/module_pw/pw_basis_k.cpp index b005691c1e..d9eecf8ae1 100644 --- a/source/module_basis/module_pw/pw_basis_k.cpp +++ b/source/module_basis/module_pw/pw_basis_k.cpp @@ -188,7 +188,7 @@ void PW_Basis_K::setuptransform() #if defined(__DSP) this->fft_bundle.setfft("dsp",this->precision); #else - this->fft_bundle.setfft("cpu",this->precision); + this->fft_bundle.setfft(this->device,this->precision); #endif if(this->xprime){ this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); @@ -338,11 +338,13 @@ int& PW_Basis_K::getigl2ig(const int ik, const int igl) const void PW_Basis_K::get_ig2ixyz_k() { - // if (this->device != "gpu") - // { - // //only GPU need to get ig2ixyz_k - // return; - // } + #if not defined(__DSP) + if (this->device != "gpu") + { + //only GPU need to get ig2ixyz_k + return; + } + #endif ig2ixyz_k_cpu.resize(this->npwk_max * this->nks); ModuleBase::Memory::record("PW_B_K::ig2ixyz", sizeof(int) * this->npwk_max * this->nks); assert(gamma_only == false); //We only finish non-gamma_only fft on GPU temperarily. From 813bf054eccee2a324954ae78b3d05e9e9655b38 Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Mon, 24 Feb 2025 15:05:03 +0800 Subject: [PATCH 10/14] modify the pw_basis --- source/module_basis/module_pw/pw_basis.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_basis/module_pw/pw_basis.cpp b/source/module_basis/module_pw/pw_basis.cpp index cbb7f6ca8c..d5183db90c 100644 --- a/source/module_basis/module_pw/pw_basis.cpp +++ b/source/module_basis/module_pw/pw_basis.cpp @@ -15,7 +15,7 @@ PW_Basis::PW_Basis() PW_Basis::PW_Basis(std::string device_, std::string precision_) : device(std::move(device_)), precision(std::move(precision_)) { classname="PW_Basis"; - this->fft_bundle.setfft(this->device,this->precision); + this->fft_bundle.setfft("cpu",this->precision); } PW_Basis:: ~PW_Basis() From 50a83660361505fbfad5df36fd898d3c38141794 Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Mon, 24 Feb 2025 22:30:16 +0800 Subject: [PATCH 11/14] add the namespace --- source/module_base/blas_connector.cpp | 394 ++++++++++------ .../module_base/kernels/dsp/dsp_connector.cpp | 436 ++++++++++++------ .../module_base/kernels/dsp/dsp_connector.h | 218 +++++---- .../module_base/module_device/memory_op.cpp | 6 +- .../module_pw/module_fft/fft_bundle.cpp | 387 +++++++++------- .../module_pw/module_fft/fft_bundle.h | 388 ++++++++-------- .../module_pw/module_fft/fft_dsp.cpp | 111 +++-- source/module_basis/module_pw/pw_basis_k.cpp | 323 ++++++++----- .../module_pw/pw_transform_k_dsp.cpp | 245 +++++----- source/module_esolver/esolver_ks_pw.cpp | 4 +- source/module_hsolver/diago_dav_subspace.cpp | 2 +- 11 files changed, 1464 insertions(+), 1050 deletions(-) diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index b422969ac5..5ccb7fc369 100644 --- a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -226,7 +226,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice){ - sgemm_mth_(&transb, &transa, &n, &m, &k, + mtfunc::sgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } @@ -240,79 +240,136 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } } -void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, - const double alpha, const double *a, const int lda, const double *b, const int ldb, - const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type) +void BlasConnector::gemm(const char transa, + const char transb, + const int m, + const int n, + const int k, + const double alpha, + const double* a, + const int lda, + const double* b, + const int ldb, + const double beta, + double* c, + const int ldc, + base_device::AbacusDevice_t device_type) { - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - dgemm_(&transb, &transa, &n, &m, &k, - &alpha, b, &ldb, a, &lda, - &beta, c, &ldc); - } + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + dgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); + } #ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice){ - dgemm_mth_(&transb, &transa, &n, &m, &k, - &alpha, b, &ldb, a, &lda, - &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif - else if (device_type == base_device::AbacusDevice_t::GpuDevice){ -#ifdef __CUDA - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - cublasErrcheck(cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc)); -#endif - } -} - -void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, - const std::complex alpha, const std::complex *a, const int lda, const std::complex *b, const int ldb, - const std::complex beta, std::complex *c, const int ldc, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - cgemm_(&transb, &transa, &n, &m, &k, - &alpha, b, &ldb, a, &lda, - &beta, c, &ldc); - } + else if (device_type == base_device::AbacusDevice_t::DspDevice) + { + mtfunc::dgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice) + { +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck( + cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc)); +#endif + } +} + +void BlasConnector::gemm(const char transa, + const char transb, + const int m, + const int n, + const int k, + const std::complex alpha, + const std::complex* a, + const int lda, + const std::complex* b, + const int ldb, + const std::complex beta, + std::complex* c, + const int ldc, + base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + cgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); + } #ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice) { - cgemm_mth_(&transb, &transa, &n, &m, &k, - &alpha, b, &ldb, a, &lda, - &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif - else if (device_type == base_device::AbacusDevice_t::GpuDevice){ -#ifdef __CUDA - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, (float2*)&alpha, (float2*)b, ldb, (float2*)a, lda, (float2*)&beta, (float2*)c, ldc)); -#endif - } -} - -void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, - const std::complex alpha, const std::complex *a, const int lda, const std::complex *b, const int ldb, - const std::complex beta, std::complex *c, const int ldc, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - zgemm_(&transb, &transa, &n, &m, &k, - &alpha, b, &ldb, a, &lda, - &beta, c, &ldc); - } + else if (device_type == base_device::AbacusDevice_t::DspDevice) + { + mtfunc::cgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice) + { +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, + cutransA, + cutransB, + n, + m, + k, + (float2*)&alpha, + (float2*)b, + ldb, + (float2*)a, + lda, + (float2*)&beta, + (float2*)c, + ldc)); +#endif + } +} + +void BlasConnector::gemm(const char transa, + const char transb, + const int m, + const int n, + const int k, + const std::complex alpha, + const std::complex* a, + const int lda, + const std::complex* b, + const int ldb, + const std::complex beta, + std::complex* c, + const int ldc, + base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + zgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); + } #ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice) { - zgemm_mth_(&transb, &transa, &n, &m, &k, - &alpha, b, &ldb, a, &lda, - &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif - else if (device_type == base_device::AbacusDevice_t::GpuDevice){ -#ifdef __CUDA - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, (double2*)&alpha, (double2*)b, ldb, (double2*)a, lda, (double2*)&beta, (double2*)c, ldc)); -#endif - } + else if (device_type == base_device::AbacusDevice_t::DspDevice) + { + mtfunc::zgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice) + { +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, + cutransA, + cutransB, + n, + m, + k, + (double2*)&alpha, + (double2*)b, + ldb, + (double2*)a, + lda, + (double2*)&beta, + (double2*)c, + ldc)); +#endif + } } // Col-Major part @@ -327,7 +384,7 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice){ - sgemm_mth_(&transb, &transa, &m, &n, &k, + mtfunc::sgemm_mth_(&transb, &transa, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK); } @@ -341,79 +398,136 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c } } -void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k, - const double alpha, const double *a, const int lda, const double *b, const int ldb, - const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type) +void BlasConnector::gemm_cm(const char transa, + const char transb, + const int m, + const int n, + const int k, + const double alpha, + const double* a, + const int lda, + const double* b, + const int ldb, + const double beta, + double* c, + const int ldc, + base_device::AbacusDevice_t device_type) { - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - dgemm_(&transa, &transb, &m, &n, &k, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc); - } + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + dgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); + } #ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice){ - dgemm_mth_(&transa, &transb, &m, &n, &k, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif - else if (device_type == base_device::AbacusDevice_t::GpuDevice){ -#ifdef __CUDA - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - cublasErrcheck(cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)); -#endif - } -} - -void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k, - const std::complex alpha, const std::complex *a, const int lda, const std::complex *b, const int ldb, - const std::complex beta, std::complex *c, const int ldc, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - cgemm_(&transa, &transb, &m, &n, &k, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc); - } + else if (device_type == base_device::AbacusDevice_t::DspDevice) + { + mtfunc::dgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice) + { +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck( + cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)); +#endif + } +} + +void BlasConnector::gemm_cm(const char transa, + const char transb, + const int m, + const int n, + const int k, + const std::complex alpha, + const std::complex* a, + const int lda, + const std::complex* b, + const int ldb, + const std::complex beta, + std::complex* c, + const int ldc, + base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + cgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); + } #ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice) { - cgemm_mth_(&transa, &transb, &m, &n, &k, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif - else if (device_type == base_device::AbacusDevice_t::GpuDevice){ -#ifdef __CUDA - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc)); -#endif - } -} - -void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k, - const std::complex alpha, const std::complex *a, const int lda, const std::complex *b, const int ldb, - const std::complex beta, std::complex *c, const int ldc, base_device::AbacusDevice_t device_type) -{ - if (device_type == base_device::AbacusDevice_t::CpuDevice) { - zgemm_(&transa, &transb, &m, &n, &k, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc); - } + else if (device_type == base_device::AbacusDevice_t::DspDevice) + { + mtfunc::cgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice) + { +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, + cutransA, + cutransB, + m, + n, + k, + (float2*)&alpha, + (float2*)a, + lda, + (float2*)b, + ldb, + (float2*)&beta, + (float2*)c, + ldc)); +#endif + } +} + +void BlasConnector::gemm_cm(const char transa, + const char transb, + const int m, + const int n, + const int k, + const std::complex alpha, + const std::complex* a, + const int lda, + const std::complex* b, + const int ldb, + const std::complex beta, + std::complex* c, + const int ldc, + base_device::AbacusDevice_t device_type) +{ + if (device_type == base_device::AbacusDevice_t::CpuDevice) + { + zgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); + } #ifdef __DSP - else if (device_type == base_device::AbacusDevice_t::DspDevice) { - zgemm_mth_(&transa, &transb, &m, &n, &k, - &alpha, a, &lda, b, &ldb, - &beta, c, &ldc, GlobalV::MY_RANK); - } -#endif - else if (device_type == base_device::AbacusDevice_t::GpuDevice){ -#ifdef __CUDA - cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); - cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); - cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc)); -#endif - } + else if (device_type == base_device::AbacusDevice_t::DspDevice) + { + mtfunc::zgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK); + } +#endif + else if (device_type == base_device::AbacusDevice_t::GpuDevice) + { +#ifdef __CUDA + cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op"); + cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op"); + cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, + cutransA, + cutransB, + m, + n, + k, + (double2*)&alpha, + (double2*)a, + lda, + (double2*)b, + ldb, + (double2*)&beta, + (double2*)c, + ldc)); +#endif + } } // Symm and Hemm part. Only col-major is supported. diff --git a/source/module_base/kernels/dsp/dsp_connector.cpp b/source/module_base/kernels/dsp/dsp_connector.cpp index 5b62847ac6..a3c5f6d897 100644 --- a/source/module_base/kernels/dsp/dsp_connector.cpp +++ b/source/module_base/kernels/dsp/dsp_connector.cpp @@ -1,201 +1,335 @@ #include "dsp_connector.h" -#include + #include +#include extern "C" { - #define complex_double ignore_complex_double - #include // MTBLAS_TRANSPOSE etc - #undef complex_double - #include // gemm +#define complex_double ignore_complex_double +#include // MTBLAS_TRANSPOSE etc +#undef complex_double +#include // gemm } - -void dspInitHandle(int id){ - mt_blas_init(id); - std::cout << " ** DSP inited on cluster "<< id << " **" << std::endl; +namespace mtfunc +{ +void dspInitHandle(int id) +{ + mt_blas_init(id); + std::cout << " ** DSP inited on cluster " << id << " **" << std::endl; } // Use this at the beginning of the program to start a dsp cluster -void dspDestoryHandle(int id){ - hthread_dev_close(id); - std::cout << " ** DSP closed on cluster "<< id << " **" << std::endl; +void dspDestoryHandle(int id) +{ + hthread_dev_close(id); + std::cout << " ** DSP closed on cluster " << id << " **" << std::endl; } // Close dsp cluster at the end - -MTBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans) { - switch (blasTrans[0]) { - case 'N': - case 'n': - return MtblasNoTrans; - case 'T': - case 't': - return MtblasTrans; - case 'C': - case 'c': - return MtblasConjTrans; - default: - std::cout << "Invalid BLAS transpose parameter!! Use default instead." << std::endl; - return MtblasNoTrans; +MTBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans) +{ + switch (blasTrans[0]) + { + case 'N': + case 'n': + return MtblasNoTrans; + case 'T': + case 't': + return MtblasTrans; + case 'C': + case 'c': + return MtblasConjTrans; + default: + std::cout << "Invalid BLAS transpose parameter!! Use default instead." << std::endl; + return MtblasNoTrans; } } // Used to convert normal transpost char to mtblas transpose flag - void* malloc_ht(size_t bytes, int cluster_id) { - //std::cout << "MALLOC " << cluster_id; - void* ptr = hthread_malloc((int)cluster_id, bytes, HT_MEM_RW); - //std::cout << ptr << " SUCCEED" << std::endl;; - return ptr; + // std::cout << "MALLOC " << cluster_id; + void* ptr = hthread_malloc((int)cluster_id, bytes, HT_MEM_RW); + // std::cout << ptr << " SUCCEED" << std::endl;; + return ptr; } // Used to replace original malloc void free_ht(void* ptr) { - //std::cout << "FREE " << ptr; - hthread_free(ptr); - //std::cout << " FREE SUCCEED" << std::endl; + // std::cout << "FREE " << ptr; + hthread_free(ptr); + // std::cout << " FREE SUCCEED" << std::endl; } // Used to replace original free -void sgemm_mt_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const float *alpha, const float *a, const int *lda, - const float *b, const int *ldb, const float *beta, - float *c, const int *ldc, int cluster_id) +void sgemm_mt_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const float* alpha, + const float* a, + const int* lda, + const float* b, + const int* ldb, + const float* beta, + float* c, + const int* ldc, + int cluster_id) { - mtblas_sgemm(MTBLAS_ORDER::MtblasColMajor, - convertBLASTranspose(transa),convertBLASTranspose(transb), - *m,*n,*k, - *alpha, a, *lda, - b, *ldb, *beta, - c, *ldc, cluster_id - ); + mtblas_sgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa), + convertBLASTranspose(transb), + *m, + *n, + *k, + *alpha, + a, + *lda, + b, + *ldb, + *beta, + c, + *ldc, + cluster_id); } // zgemm that needn't malloc_ht or free_ht -void dgemm_mt_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const double *alpha, const double *a, const int *lda, - const double *b, const int *ldb, const double *beta, - double *c, const int *ldc, int cluster_id) +void dgemm_mt_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const double* alpha, + const double* a, + const int* lda, + const double* b, + const int* ldb, + const double* beta, + double* c, + const int* ldc, + int cluster_id) { - mtblas_dgemm(MTBLAS_ORDER::MtblasColMajor, - convertBLASTranspose(transa),convertBLASTranspose(transb), - *m,*n,*k, - *alpha, a, *lda, - b, *ldb, *beta, - c, *ldc, cluster_id - ); + mtblas_dgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa), + convertBLASTranspose(transb), + *m, + *n, + *k, + *alpha, + a, + *lda, + b, + *ldb, + *beta, + c, + *ldc, + cluster_id); } // cgemm that needn't malloc_ht or free_ht -void zgemm_mt_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const std::complex *alpha, const std::complex *a, const int *lda, - const std::complex *b, const int *ldb, const std::complex *beta, - std::complex *c, const int *ldc, int cluster_id) +void zgemm_mt_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* b, + const int* ldb, + const std::complex* beta, + std::complex* c, + const int* ldc, + int cluster_id) { - mtblas_zgemm(MTBLAS_ORDER::MtblasColMajor, - convertBLASTranspose(transa),convertBLASTranspose(transb), - *m,*n,*k, - (const void*)alpha, (const void*)a, *lda, - (const void*)b, *ldb, (const void*)beta, - (void*)c, *ldc, cluster_id - ); + mtblas_zgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa), + convertBLASTranspose(transb), + *m, + *n, + *k, + (const void*)alpha, + (const void*)a, + *lda, + (const void*)b, + *ldb, + (const void*)beta, + (void*)c, + *ldc, + cluster_id); } // zgemm that needn't malloc_ht or free_ht -void cgemm_mt_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const std::complex *alpha, const std::complex *a, const int *lda, - const std::complex *b, const int *ldb, const std::complex *beta, - std::complex *c, const int *ldc, int cluster_id) +void cgemm_mt_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* b, + const int* ldb, + const std::complex* beta, + std::complex* c, + const int* ldc, + int cluster_id) { - mtblas_cgemm(MTBLAS_ORDER::MtblasColMajor, - convertBLASTranspose(transa),convertBLASTranspose(transb), - *m,*n,*k, - (const void*)alpha, (const void*)a, *lda, - (const void*)b, *ldb, (const void*)beta, - (void*)c, *ldc, cluster_id - ); + mtblas_cgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa), + convertBLASTranspose(transb), + *m, + *n, + *k, + (const void*)alpha, + (const void*)a, + *lda, + (const void*)b, + *ldb, + (const void*)beta, + (void*)c, + *ldc, + cluster_id); } // cgemm that needn't malloc_ht or free_ht // Used to replace original free -void sgemm_mth_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const float *alpha, const float *a, const int *lda, - const float *b, const int *ldb, const float *beta, - float *c, const int *ldc, int cluster_id) +void sgemm_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const float* alpha, + const float* a, + const int* lda, + const float* b, + const int* ldb, + const float* beta, + float* c, + const int* ldc, + int cluster_id) { - mt_hthread_sgemm(MTBLAS_ORDER::MtblasColMajor, - convertBLASTranspose(transa),convertBLASTranspose(transb), - *m,*n,*k, - *alpha, a, *lda, - b, *ldb, *beta, - c, *ldc, cluster_id - ); + mt_hthread_sgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa), + convertBLASTranspose(transb), + *m, + *n, + *k, + *alpha, + a, + *lda, + b, + *ldb, + *beta, + c, + *ldc, + cluster_id); } // zgemm that needn't malloc_ht or free_ht -void dgemm_mth_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const double *alpha, const double *a, const int *lda, - const double *b, const int *ldb, const double *beta, - double *c, const int *ldc, int cluster_id) +void dgemm_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const double* alpha, + const double* a, + const int* lda, + const double* b, + const int* ldb, + const double* beta, + double* c, + const int* ldc, + int cluster_id) { - mt_hthread_dgemm(MTBLAS_ORDER::MtblasColMajor, - convertBLASTranspose(transa),convertBLASTranspose(transb), - *m,*n,*k, - *alpha, a, *lda, - b, *ldb, *beta, - c, *ldc, cluster_id - ); + mt_hthread_dgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa), + convertBLASTranspose(transb), + *m, + *n, + *k, + *alpha, + a, + *lda, + b, + *ldb, + *beta, + c, + *ldc, + cluster_id); } // cgemm that needn't malloc_ht or free_ht -void zgemm_mth_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const std::complex *alpha, - const std::complex *a, - const int *lda, - const std::complex *b, - const int *ldb, - const std::complex *beta, - std::complex *c, - const int *ldc, - int cluster_id) +void zgemm_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* b, + const int* ldb, + const std::complex* beta, + std::complex* c, + const int* ldc, + int cluster_id) { - std::complex* alp = (std::complex*) malloc_ht(sizeof(std::complex), cluster_id); - *alp = *alpha; - std::complex* bet = (std::complex*) malloc_ht(sizeof(std::complex), cluster_id); - *bet = *beta; - mt_hthread_zgemm(MTBLAS_ORDER::MtblasColMajor, - convertBLASTranspose(transa),convertBLASTranspose(transb), - *m,*n,*k, - alp, a, *lda, - b, *ldb, bet, - c, *ldc, cluster_id - ); - + std::complex* alp = (std::complex*)malloc_ht(sizeof(std::complex), cluster_id); + *alp = *alpha; + std::complex* bet = (std::complex*)malloc_ht(sizeof(std::complex), cluster_id); + *bet = *beta; + mt_hthread_zgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa), + convertBLASTranspose(transb), + *m, + *n, + *k, + alp, + a, + *lda, + b, + *ldb, + bet, + c, + *ldc, + cluster_id); } // zgemm that needn't malloc_ht or free_ht -void cgemm_mth_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const std::complex *alpha, const std::complex *a, const int *lda, - const std::complex *b, const int *ldb, const std::complex *beta, - std::complex *c, const int *ldc, int cluster_id) +void cgemm_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* b, + const int* ldb, + const std::complex* beta, + std::complex* c, + const int* ldc, + int cluster_id) { - std::complex* alp = (std::complex*) malloc_ht(sizeof(std::complex), cluster_id); - *alp = *alpha; - std::complex* bet = (std::complex*) malloc_ht(sizeof(std::complex), cluster_id); - *bet = *beta; - - mt_hthread_cgemm(MTBLAS_ORDER::MtblasColMajor, - convertBLASTranspose(transa),convertBLASTranspose(transb), - *m,*n,*k, - (const void*)alp, (const void*)a, *lda, - (const void*)b, *ldb, (const void*)bet, - (void*)c, *ldc, cluster_id - ); - - free_ht(alp); - free_ht(bet); -} // cgemm that needn't malloc_ht or free_ht \ No newline at end of file + std::complex* alp = (std::complex*)malloc_ht(sizeof(std::complex), cluster_id); + *alp = *alpha; + std::complex* bet = (std::complex*)malloc_ht(sizeof(std::complex), cluster_id); + *bet = *beta; + + mt_hthread_cgemm(MTBLAS_ORDER::MtblasColMajor, + convertBLASTranspose(transa), + convertBLASTranspose(transb), + *m, + *n, + *k, + (const void*)alp, + (const void*)a, + *lda, + (const void*)b, + *ldb, + (const void*)bet, + (void*)c, + *ldc, + cluster_id); + + free_ht(alp); + free_ht(bet); +} // cgemm that needn't malloc_ht or free_ht +} // namespace mtfunc \ No newline at end of file diff --git a/source/module_base/kernels/dsp/dsp_connector.h b/source/module_base/kernels/dsp/dsp_connector.h index ea0d17749e..bbda25f798 100644 --- a/source/module_base/kernels/dsp/dsp_connector.h +++ b/source/module_base/kernels/dsp/dsp_connector.h @@ -6,95 +6,157 @@ #include "module_base/module_device/memory_op.h" #include "module_hsolver/diag_comm_info.h" +namespace mtfunc +{ // Base dsp functions void dspInitHandle(int id); void dspDestoryHandle(int id); -void *malloc_ht(size_t bytes, int cluster_id); +void* malloc_ht(size_t bytes, int cluster_id); void free_ht(void* ptr); - // mtblas functions -void sgemm_mt_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const float *alpha, const float *a, const int *lda, - const float *b, const int *ldb, const float *beta, - float *c, const int *ldc, int cluster_id); - -void dgemm_mt_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const double *alpha,const double *a, const int *lda, - const double *b, const int *ldb, const double *beta, - double *c, const int *ldc, int cluster_id); - -void zgemm_mt_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const std::complex *alpha, const std::complex *a, const int *lda, - const std::complex *b, const int *ldb, const std::complex *beta, - std::complex *c, const int *ldc, int cluster_id); - -void cgemm_mt_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const std::complex *alpha, const std::complex *a, const int *lda, - const std::complex *b, const int *ldb, const std::complex *beta, - std::complex *c, const int *ldc, int cluster_id); - - -void sgemm_mth_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const float *alpha, const float *a, const int *lda, - const float *b, const int *ldb, const float *beta, - float *c, const int *ldc, int cluster_id); - -void dgemm_mth_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const double *alpha,const double *a, const int *lda, - const double *b, const int *ldb, const double *beta, - double *c, const int *ldc, int cluster_id); - -void zgemm_mth_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const std::complex *alpha, const std::complex *a, const int *lda, - const std::complex *b, const int *ldb, const std::complex *beta, - std::complex *c, const int *ldc, int cluster_id); - -void cgemm_mth_(const char *transa, const char *transb, - const int *m, const int *n, const int *k, - const std::complex *alpha, const std::complex *a, const int *lda, - const std::complex *b, const int *ldb, const std::complex *beta, - std::complex *c, const int *ldc, int cluster_id); - -//#define zgemm_ zgemm_mt +void sgemm_mt_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const float* alpha, + const float* a, + const int* lda, + const float* b, + const int* ldb, + const float* beta, + float* c, + const int* ldc, + int cluster_id); + +void dgemm_mt_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const double* alpha, + const double* a, + const int* lda, + const double* b, + const int* ldb, + const double* beta, + double* c, + const int* ldc, + int cluster_id); + +void zgemm_mt_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* b, + const int* ldb, + const std::complex* beta, + std::complex* c, + const int* ldc, + int cluster_id); + +void cgemm_mt_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* b, + const int* ldb, + const std::complex* beta, + std::complex* c, + const int* ldc, + int cluster_id); + +void sgemm_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const float* alpha, + const float* a, + const int* lda, + const float* b, + const int* ldb, + const float* beta, + float* c, + const int* ldc, + int cluster_id); + +void dgemm_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const double* alpha, + const double* a, + const int* lda, + const double* b, + const int* ldb, + const double* beta, + double* c, + const int* ldc, + int cluster_id); + +void zgemm_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* b, + const int* ldb, + const std::complex* beta, + std::complex* c, + const int* ldc, + int cluster_id); + +void cgemm_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* b, + const int* ldb, + const std::complex* beta, + std::complex* c, + const int* ldc, + int cluster_id); + +// #define zgemm_ zgemm_mt // The next is dsp utils. It may be moved to other files if this file get too huge template -void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv, MPI_Comm diag_comm){ +void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv, MPI_Comm diag_comm) +{ - using syncmem_complex_op = base_device::memory::synchronize_memory_op; + using syncmem_complex_op + = base_device::memory::synchronize_memory_op; - auto* swap = new T[notconv * nbase_x]; + auto* swap = new T[notconv * nbase_x]; auto* target = new T[notconv * nbase_x]; syncmem_complex_op()(swap, hcc + nbase * nbase_x, notconv * nbase_x); if (base_device::get_current_precision(swap) == "single") { - MPI_Reduce(swap, - target, - notconv * nbase_x, - MPI_COMPLEX, - MPI_SUM, - 0, - diag_comm); + MPI_Reduce(swap, target, notconv * nbase_x, MPI_COMPLEX, MPI_SUM, 0, diag_comm); } else { - MPI_Reduce(swap, - target, - notconv * nbase_x, - MPI_DOUBLE_COMPLEX, - MPI_SUM, - 0, - diag_comm); + MPI_Reduce(swap, target, notconv * nbase_x, MPI_DOUBLE_COMPLEX, MPI_SUM, 0, diag_comm); } syncmem_complex_op()(hcc + nbase * nbase_x, target, notconv * nbase_x); @@ -102,30 +164,18 @@ void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv if (base_device::get_current_precision(swap) == "single") { - MPI_Reduce(swap, - target, - notconv * nbase_x, - MPI_COMPLEX, - MPI_SUM, - 0, - diag_comm); + MPI_Reduce(swap, target, notconv * nbase_x, MPI_COMPLEX, MPI_SUM, 0, diag_comm); } else { - MPI_Reduce(swap, - target, - notconv * nbase_x, - MPI_DOUBLE_COMPLEX, - MPI_SUM, - 0, - diag_comm); + MPI_Reduce(swap, target, notconv * nbase_x, MPI_DOUBLE_COMPLEX, MPI_SUM, 0, diag_comm); } syncmem_complex_op()(scc + nbase * nbase_x, target, notconv * nbase_x); delete[] swap; delete[] target; } - +} // namespace mtfunc #endif #endif \ No newline at end of file diff --git a/source/module_base/module_device/memory_op.cpp b/source/module_base/module_device/memory_op.cpp index 525ecee89f..9af0ce5a79 100644 --- a/source/module_base/module_device/memory_op.cpp +++ b/source/module_base/module_device/memory_op.cpp @@ -340,9 +340,9 @@ struct resize_memory_op_mt { if (arr != nullptr) { - free_ht(arr); + mtfunc::free_ht(arr); } - arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size, GlobalV::MY_RANK); + arr = (FPTYPE*)mtfunc::malloc_ht(sizeof(FPTYPE) * size, GlobalV::MY_RANK); std::string record_string; if (record_in != nullptr) { @@ -365,7 +365,7 @@ struct delete_memory_op_mt { void operator()(FPTYPE* arr) { - free_ht(arr); + mtfunc::free_ht(arr); } }; diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index a35005e48f..83c2d02466 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -1,9 +1,10 @@ -#include #include "fft_bundle.h" #include "module_base/module_device/device.h" #include "module_base/module_device/memory_op.h" #include "module_base/tool_quit.h" + +#include #if defined(__CUDA) #include "fft_cuda.h" #endif @@ -13,8 +14,8 @@ #if defined(__DSP) #include "fft_dsp.h" #endif -template -std::unique_ptr make_unique(Args &&... args) +template +std::unique_ptr make_unique(Args&&... args) { return std::unique_ptr(new FFT_BASE(std::forward(args)...)); } @@ -25,216 +26,264 @@ FFT_Bundle::~FFT_Bundle() this->clear(); } -void FFT_Bundle::setfft(std::string device_in,std::string precision_in) +void FFT_Bundle::setfft(std::string device_in, std::string precision_in) { this->device = device_in; this->precision = precision_in; } -void FFT_Bundle::initfft(int nx_in, - int ny_in, - int nz_in, - int lixy_in, - int rixy_in, - int ns_in, - int nplane_in, - int nproc_in, - bool gamma_only_in, - bool xprime_in , +void FFT_Bundle::initfft(int nx_in, + int ny_in, + int nz_in, + int lixy_in, + int rixy_in, + int ns_in, + int nplane_in, + int nproc_in, + bool gamma_only_in, + bool xprime_in, bool mpifft_in) { - assert(this->device=="cpu" || this->device=="gpu" || this->device=="dsp"); - assert(this->precision=="single" || this->precision=="double" || this->precision=="mixing"); + assert(this->device == "cpu" || this->device == "gpu" || this->device == "dsp"); + assert(this->precision == "single" || this->precision == "double" || this->precision == "mixing"); - if (this->precision=="single") + if (this->precision == "single") { - #if not defined (__ENABLE_FLOAT_FFTW) - if (this->device == "cpu"){ +#if not defined(__ENABLE_FLOAT_FFTW) + if (this->device == "cpu") + { float_define = false; } - #endif - #if defined(__CUDA) || defined (__ROCM) - if (this->device == "gpu"){ +#endif +#if defined(__CUDA) || defined(__ROCM) + if (this->device == "gpu") + { float_flag = float_define; } - #endif +#endif float_flag = float_define; double_flag = true; } - if (this->precision=="double") + if (this->precision == "double") { double_flag = true; } - #if defined(__DSP) - if (device=="dsp") +#if defined(__DSP) + if (device == "dsp") + { + if (float_flag) { - if (float_flag) - ModuleBase::WARNING_QUIT("device","now dsp fft is not support for the float type"); - fft_double=make_unique>(); - fft_double->initfft(nx_in,ny_in,nz_in); + ModuleBase::WARNING_QUIT("device", "now dsp fft is not supported for the float type"); } - #endif - if (device=="cpu") + fft_double = make_unique>(); + fft_double->initfft(nx_in, ny_in, nz_in); + } +#endif + if (device == "cpu") { fft_float = make_unique>(this->fft_mode); fft_double = make_unique>(this->fft_mode); if (float_flag) { - fft_float->initfft(nx_in, - ny_in, - nz_in, - lixy_in, - rixy_in, - ns_in, - nplane_in, - nproc_in, - gamma_only_in, - xprime_in); + fft_float + ->initfft(nx_in, ny_in, nz_in, lixy_in, rixy_in, ns_in, nplane_in, nproc_in, gamma_only_in, xprime_in); } if (double_flag) { - fft_double->initfft(nx_in, - ny_in, - nz_in, - lixy_in, - rixy_in, - ns_in, - nplane_in, - nproc_in, - gamma_only_in, - xprime_in); + fft_double + ->initfft(nx_in, ny_in, nz_in, lixy_in, rixy_in, ns_in, nplane_in, nproc_in, gamma_only_in, xprime_in); } } - if (device=="gpu") + if (device == "gpu") { - #if defined(__ROCM) - fft_float = make_unique>(); - fft_float->initfft(nx_in,ny_in,nz_in); - fft_double = make_unique>(); - fft_double->initfft(nx_in,ny_in,nz_in); - #elif defined(__CUDA) - fft_float = make_unique>(); - fft_float->initfft(nx_in,ny_in,nz_in); - fft_double = make_unique>(); - fft_double->initfft(nx_in,ny_in,nz_in); - #endif +#if defined(__ROCM) + fft_float = make_unique>(); + fft_float->initfft(nx_in, ny_in, nz_in); + fft_double = make_unique>(); + fft_double->initfft(nx_in, ny_in, nz_in); +#elif defined(__CUDA) + fft_float = make_unique>(); + fft_float->initfft(nx_in, ny_in, nz_in); + fft_double = make_unique>(); + fft_double->initfft(nx_in, ny_in, nz_in); +#endif } - } void FFT_Bundle::setupFFT() { - if (double_flag){fft_double->setupFFT();} - if (float_flag) {fft_float->setupFFT();} + if (double_flag) + { + fft_double->setupFFT(); + } + if (float_flag) + { + fft_float->setupFFT(); + } } void FFT_Bundle::clearFFT() { - if (double_flag){fft_double->cleanFFT();} - if (float_flag) {fft_float->cleanFFT();} + if (double_flag) + { + fft_double->cleanFFT(); + } + if (float_flag) + { + fft_float->cleanFFT(); + } } void FFT_Bundle::clear() { this->clearFFT(); - if (double_flag){fft_double->clear();} - if (float_flag) {fft_float->clear();} -} - -template <> void -FFT_Bundle::fftxyfor(std::complex* in, - std::complex* out) -const {fft_float->fftxyfor(in,out);} -template <> void -FFT_Bundle::fftxyfor(std::complex* in, - std::complex* out) -const {fft_double->fftxyfor(in,out);} - - -template <> void -FFT_Bundle::fftzfor(std::complex* in, - std::complex* out) -const {fft_float->fftzfor(in,out);} -template <> void -FFT_Bundle::fftzfor(std::complex* in, - std::complex* out) -const {fft_double->fftzfor(in,out);} - -template <> void -FFT_Bundle::fftxybac(std::complex* in, - std::complex* out) -const {fft_float->fftxybac(in,out);} -template <> void -FFT_Bundle::fftxybac(std::complex* in, - std::complex* out) -const {fft_double->fftxybac(in,out);} - -template <> void -FFT_Bundle::fftzbac(std::complex* in, - std::complex* out) -const {fft_float->fftzbac(in,out);} -template <> void -FFT_Bundle::fftzbac(std::complex* in, - std::complex* out) -const {fft_double->fftzbac(in,out);} - -template <> void -FFT_Bundle::fftxyr2c(float* in, - std::complex* out) -const {fft_float->fftxyr2c(in,out);} -template <> void -FFT_Bundle::fftxyr2c(double* in, - std::complex* out) -const {fft_double->fftxyr2c(in,out);} - -template <> void -FFT_Bundle::fftxyc2r(std::complex* in, - float* out) -const {fft_float->fftxyc2r(in,out);} -template <> void -FFT_Bundle::fftxyc2r(std::complex* in, - double* out) -const {fft_double->fftxyc2r(in,out);} - -template <> void -FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, - std::complex* in, - std::complex* out) -const {fft_float->fft3D_forward(in, out);} -template <> void -FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, - std::complex* in, - std::complex* out) -const {fft_double->fft3D_forward(in, out);} - -template <> void -FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, - std::complex* in, - std::complex* out) -const {fft_float->fft3D_backward(in, out);} -template <> void -FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, - std::complex* in, - std::complex* out) -const {fft_double->fft3D_backward(in, out);} + if (double_flag) + { + fft_double->clear(); + } + if (float_flag) + { + fft_float->clear(); + } +} + +template <> +void FFT_Bundle::fftxyfor(std::complex* in, std::complex* out) const +{ + fft_float->fftxyfor(in, out); +} +template <> +void FFT_Bundle::fftxyfor(std::complex* in, std::complex* out) const +{ + fft_double->fftxyfor(in, out); +} + +template <> +void FFT_Bundle::fftzfor(std::complex* in, std::complex* out) const +{ + fft_float->fftzfor(in, out); +} +template <> +void FFT_Bundle::fftzfor(std::complex* in, std::complex* out) const +{ + fft_double->fftzfor(in, out); +} + +template <> +void FFT_Bundle::fftxybac(std::complex* in, std::complex* out) const +{ + fft_float->fftxybac(in, out); +} +template <> +void FFT_Bundle::fftxybac(std::complex* in, std::complex* out) const +{ + fft_double->fftxybac(in, out); +} + +template <> +void FFT_Bundle::fftzbac(std::complex* in, std::complex* out) const +{ + fft_float->fftzbac(in, out); +} +template <> +void FFT_Bundle::fftzbac(std::complex* in, std::complex* out) const +{ + fft_double->fftzbac(in, out); +} + +template <> +void FFT_Bundle::fftxyr2c(float* in, std::complex* out) const +{ + fft_float->fftxyr2c(in, out); +} +template <> +void FFT_Bundle::fftxyr2c(double* in, std::complex* out) const +{ + fft_double->fftxyr2c(in, out); +} + +template <> +void FFT_Bundle::fftxyc2r(std::complex* in, float* out) const +{ + fft_float->fftxyc2r(in, out); +} +template <> +void FFT_Bundle::fftxyc2r(std::complex* in, double* out) const +{ + fft_double->fftxyc2r(in, out); +} + +template <> +void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, + std::complex* in, + std::complex* out) const +{ + fft_float->fft3D_forward(in, out); +} +template <> +void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, + std::complex* in, + std::complex* out) const +{ + fft_double->fft3D_forward(in, out); +} + +template <> +void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, + std::complex* in, + std::complex* out) const +{ + fft_float->fft3D_backward(in, out); +} +template <> +void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, + std::complex* in, + std::complex* out) const +{ + fft_double->fft3D_backward(in, out); +} // access the real space data -template <> float* -FFT_Bundle::get_rspace_data() const {return fft_float->get_rspace_data();} -template <> double* -FFT_Bundle::get_rspace_data() const {return fft_double->get_rspace_data();} - -template <> std::complex* -FFT_Bundle::get_auxr_data() const {return fft_float->get_auxr_data();} -template <> std::complex* -FFT_Bundle::get_auxr_data() const {return fft_double->get_auxr_data();} - -template <> std::complex* -FFT_Bundle::get_auxg_data() const {return fft_float->get_auxg_data();} -template <> std::complex* -FFT_Bundle::get_auxg_data() const {return fft_double->get_auxg_data();} - -template <> std::complex* -FFT_Bundle::get_auxr_3d_data() const {return fft_float->get_auxr_3d_data();} -template <> std::complex* -FFT_Bundle::get_auxr_3d_data() const {return fft_double->get_auxr_3d_data();} -} \ No newline at end of file +template <> +float* FFT_Bundle::get_rspace_data() const +{ + return fft_float->get_rspace_data(); +} +template <> +double* FFT_Bundle::get_rspace_data() const +{ + return fft_double->get_rspace_data(); +} + +template <> +std::complex* FFT_Bundle::get_auxr_data() const +{ + return fft_float->get_auxr_data(); +} +template <> +std::complex* FFT_Bundle::get_auxr_data() const +{ + return fft_double->get_auxr_data(); +} + +template <> +std::complex* FFT_Bundle::get_auxg_data() const +{ + return fft_float->get_auxg_data(); +} +template <> +std::complex* FFT_Bundle::get_auxg_data() const +{ + return fft_double->get_auxg_data(); +} + +template <> +std::complex* FFT_Bundle::get_auxr_3d_data() const +{ + return fft_float->get_auxr_3d_data(); +} +template <> +std::complex* FFT_Bundle::get_auxr_3d_data() const +{ + return fft_double->get_auxr_3d_data(); +} +} // namespace ModulePW \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.h b/source/module_basis/module_pw/module_fft/fft_bundle.h index 71ce5192f3..58851e139d 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.h +++ b/source/module_basis/module_pw/module_fft/fft_bundle.h @@ -1,215 +1,207 @@ #ifndef FFT_TEMP_H #define FFT_TEMP_H -#include #include "fft_base.h" #include "fft_cpu.h" + +#include namespace ModulePW { class FFT_Bundle { - public: - FFT_Bundle(){}; - ~FFT_Bundle(); - /** - * @brief Constructor with device and precision. - * @param device_in device type, cpu or gpu. - * @param precision_in precision type, single or double. - * - * the function will check the input device and precision, - * and set the device and precision. - */ - FFT_Bundle(std::string device_in,std::string precision_in) - :device(device_in),precision(precision_in){}; - - /** - * @brief Set device and precision. - * @param device_in device type, cpu or gpu. - * @param precision_in precision type, single or double. - * - * the function will check the input device and precision, - * and set the device and precision. - */ - void setfft(std::string device_in,std::string precision_in); + public: + FFT_Bundle() {}; + ~FFT_Bundle(); + /** + * @brief Constructor with device and precision. + * @param device_in device type, cpu or gpu. + * @param precision_in precision type, single or double. + * + * the function will check the input device and precision, + * and set the device and precision. + */ + FFT_Bundle(std::string device_in, std::string precision_in) : device(device_in), precision(precision_in) {}; + + /** + * @brief Set device and precision. + * @param device_in device type, cpu or gpu. + * @param precision_in precision type, single or double. + * + * the function will check the input device and precision, + * and set the device and precision. + */ + void setfft(std::string device_in, std::string precision_in); + + /** + * @brief Initialize the fft parameters. + * @param nx_in number of grid points in x direction. + * @param ny_in number of grid points in y direction. + * @param nz_in number of grid points in z direction. + * @param lixy_in the position of the left boundary + * in the x-y plane. + * @param rixy_in the position of the right boundary + * in the x-y plane. + * @param ns_in number of stick whcih is used in the + * Z direction. + * @param nplane_in number of x-y planes. + * @param nproc_in number of processors. + * @param gamma_only_in whether only gamma point is used. + * @param xprime_in whether xprime is used. + * + * the function will initialize the many-fft parameters + * Wheatley in cpu or gpu device. + */ + void initfft(int nx_in, + int ny_in, + int nz_in, + int lixy_in, + int rixy_in, + int ns_in, + int nplane_in, + int nproc_in, + bool gamma_only_in, + bool xprime_in = true, + bool mpifft_in = false); + + /** + * @brief Initialize the fft mode. + * @param fft_mode_in fft mode. + * + * the function will initialize the fft mode. + */ - /** - * @brief Initialize the fft parameters. - * @param nx_in number of grid points in x direction. - * @param ny_in number of grid points in y direction. - * @param nz_in number of grid points in z direction. - * @param lixy_in the position of the left boundary - * in the x-y plane. - * @param rixy_in the position of the right boundary - * in the x-y plane. - * @param ns_in number of stick whcih is used in the - * Z direction. - * @param nplane_in number of x-y planes. - * @param nproc_in number of processors. - * @param gamma_only_in whether only gamma point is used. - * @param xprime_in whether xprime is used. - * - * the function will initialize the many-fft parameters - * Wheatley in cpu or gpu device. - */ - void initfft(int nx_in, - int ny_in, - int nz_in, - int lixy_in, - int rixy_in, - int ns_in, - int nplane_in, - int nproc_in, - bool gamma_only_in, - bool xprime_in = true, - bool mpifft_in = false); - - /** - * @brief Initialize the fft mode. - * @param fft_mode_in fft mode. - * - * the function will initialize the fft mode. - */ + void initfftmode(int fft_mode_in) + { + this->fft_mode = fft_mode_in; + } - void initfftmode(int fft_mode_in){this->fft_mode = fft_mode_in;} + void setupFFT(); - void setupFFT(); + void clearFFT(); - void clearFFT(); - - void clear(); - - /** - * @brief Get the real space data. - * @return FPTYPE* the real space data. - * - * the function will return the real space data, - * which is used in the cpu-like fft. - */ - template - FPTYPE* get_rspace_data() const; - /** - * @brief Get the auxr data. - * @return std::complex* the auxr data. - * - * the function will return the auxr data, - * which is used in the cpu-like fft. - */ - template - std::complex* get_auxr_data() const; - /** - * @brief Get the auxg data. - * @return std::complex* the auxg data. - * - * the function will return the auxg data, - * which is used in the cpu-like fft. - */ - template - std::complex* get_auxg_data() const; - /** - * @brief Get the auxr 3d data. - * @return std::complex* the auxr 3d data. - * - * the function will return the auxr 3d data, - * which is used in the gpu-like fft. - */ - template - std::complex* get_auxr_3d_data() const; - - /** - * @brief Forward fft in z direction. - * @param in input data. - * @param out output data. - * - * The function will do the forward many fft in z direction, - * As an interface, the function will call the fftzfor in the - * accurate fft class. - * which is used in the cpu-like fft. - */ - template - void fftzfor(std::complex* in, - std::complex* out) const; - /** - * @brief Forward fft in x-y direction. - * @param in input data. - * @param out output data. - * - * the function will do the forward fft in x and y direction, - * which is used in the cpu-like fft.As an interface, - * the function will call the fftxyfor in the accurate fft class. - */ - template - void fftxyfor(std::complex* in, - std::complex* out) const; - /** - * @brief Backward fft in z direction. - * @param in input data. - * @param out output data. - * - * the function will do the backward many fft in z direction, - * which is used in the cpu-like fft.As an interface, - * the function will call the fftzbac in the accurate fft class. - */ - template - void fftzbac(std::complex* in, - std::complex* out) const; - /** - * @brief Backward fft in x-y direction. - * @param in input data. - * @param out output data. - * - * the function will do the backward fft in x and y direction, - * which is used in the cpu-like fft.As an interface, - * the function will call the fftxybac in the accurate fft class. - */ - template - void fftxybac(std::complex* in, - std::complex* out) const; - - /** - * @brief Real to complex fft in x-y direction. - * @param in input data. - * @param out output data. - * - * the function will do the real to complex fft in x and y direction, - * which is used in the cpu-like fft.As an interface, - * the function will call the fftxyr2c in the accurate fft class. - */ - template - void fftxyr2c(FPTYPE* in, - std::complex* out) const; - /** - * @brief Complex to real fft in x-y direction. - * @param in input data. - * @param out output data. - * - * the function will do the complex to real fft in x and y direction, - * which is used in the cpu-like fft.As an interface, - * the function will call the fftxyc2r in the accurate fft class. - */ - template - void fftxyc2r(std::complex* in, - FPTYPE* out) const; + void clear(); - template - void fft3D_forward(const Device* ctx, - std::complex* in, - std::complex* out) const; - template - void fft3D_backward(const Device* ctx, - std::complex* in, - std::complex* out) const; + /** + * @brief Get the real space data. + * @return FPTYPE* the real space data. + * + * the function will return the real space data, + * which is used in the cpu-like fft. + */ + template + FPTYPE* get_rspace_data() const; + /** + * @brief Get the auxr data. + * @return std::complex* the auxr data. + * + * the function will return the auxr data, + * which is used in the cpu-like fft. + */ + template + std::complex* get_auxr_data() const; + /** + * @brief Get the auxg data. + * @return std::complex* the auxg data. + * + * the function will return the auxg data, + * which is used in the cpu-like fft. + */ + template + std::complex* get_auxg_data() const; + /** + * @brief Get the auxr 3d data. + * @return std::complex* the auxr 3d data. + * + * the function will return the auxr 3d data, + * which is used in the gpu-like fft. + */ + template + std::complex* get_auxr_3d_data() const; - private: - int fft_mode = 0; - bool float_flag=false; - bool float_define=true; - bool double_flag=false; - std::shared_ptr> fft_float=nullptr; - std::shared_ptr> fft_double=nullptr; - - std::string device = "cpu"; - std::string precision = "double"; -}; + /** + * @brief Forward fft in z direction. + * @param in input data. + * @param out output data. + * + * The function will do the forward many fft in z direction, + * As an interface, the function will call the fftzfor in the + * accurate fft class. + * which is used in the cpu-like fft. + */ + template + void fftzfor(std::complex* in, std::complex* out) const; + /** + * @brief Forward fft in x-y direction. + * @param in input data. + * @param out output data. + * + * the function will do the forward fft in x and y direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftxyfor in the accurate fft class. + */ + template + void fftxyfor(std::complex* in, std::complex* out) const; + /** + * @brief Backward fft in z direction. + * @param in input data. + * @param out output data. + * + * the function will do the backward many fft in z direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftzbac in the accurate fft class. + */ + template + void fftzbac(std::complex* in, std::complex* out) const; + /** + * @brief Backward fft in x-y direction. + * @param in input data. + * @param out output data. + * + * the function will do the backward fft in x and y direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftxybac in the accurate fft class. + */ + template + void fftxybac(std::complex* in, std::complex* out) const; + + /** + * @brief Real to complex fft in x-y direction. + * @param in input data. + * @param out output data. + * + * the function will do the real to complex fft in x and y direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftxyr2c in the accurate fft class. + */ + template + void fftxyr2c(FPTYPE* in, std::complex* out) const; + /** + * @brief Complex to real fft in x-y direction. + * @param in input data. + * @param out output data. + * + * the function will do the complex to real fft in x and y direction, + * which is used in the cpu-like fft.As an interface, + * the function will call the fftxyc2r in the accurate fft class. + */ + template + void fftxyc2r(std::complex* in, FPTYPE* out) const; + + template + void fft3D_forward(const Device* ctx, std::complex* in, std::complex* out) const; + template + void fft3D_backward(const Device* ctx, std::complex* in, std::complex* out) const; + + private: + int fft_mode = 0; + bool float_flag = false; + bool float_define = true; + bool double_flag = false; + std::shared_ptr> fft_float = nullptr; + std::shared_ptr> fft_double = nullptr; + + std::string device = "cpu"; + std::string precision = "double"; +}; } // namespace ModulePW #endif // FFT_H - diff --git a/source/module_basis/module_pw/module_fft/fft_dsp.cpp b/source/module_basis/module_pw/module_fft/fft_dsp.cpp index 4ff838cd52..3428af345f 100644 --- a/source/module_basis/module_pw/module_fft/fft_dsp.cpp +++ b/source/module_basis/module_pw/module_fft/fft_dsp.cpp @@ -1,71 +1,72 @@ #include "fft_dsp.h" -#include + +#include "module_base/global_variable.h" + #include +#include #include -#include "module_base/global_variable.h" namespace ModulePW { -template<> -void FFT_DSP::initfft(int nx_in,int ny_in,int nz_in) +template <> +void FFT_DSP::initfft(int nx_in, int ny_in, int nz_in) { - this->nx=nx_in; - this->ny=ny_in; - this->nz=nz_in; + this->nx = nx_in; + this->ny = ny_in; + this->nz = nz_in; cluster_id = GlobalV::MY_RANK; - nxyz=this->nx*this->ny*this->nz; + nxyz = this->nx * this->ny * this->nz; } -template<> +template <> void FFT_DSP::setupFFT() { PROBLEM pbm_forward; PROBLEM pbm_backward; PLAN* ptr_plan_forward; PLAN* ptr_plan_backward; - INT num_thread=8; - INT size; + INT num_thread = 8; + INT size; hthread_dat_load(cluster_id, FFT_DAT_DIR); - - //compute the size of and malloc thread - size = nx*ny*nz*2*sizeof(E); - forward_in = (E*)hthread_malloc((int)cluster_id, size, HT_MEM_RW); -// // //init 3d fft problem - pbm_forward.num_dim = 3; - pbm_forward.n[0] = nx; - pbm_forward.n[1] = ny; - pbm_forward.n[2] = nz; - pbm_forward.iFFT = 0; - pbm_forward.in = forward_in; - pbm_forward.out = forward_in; - -// // //make ptr plan + // compute the size of and malloc thread + size = nx * ny * nz * 2 * sizeof(E); + forward_in = (E*)hthread_malloc((int)cluster_id, size, HT_MEM_RW); + + // // //init 3d fft problem + pbm_forward.num_dim = 3; + pbm_forward.n[0] = nx; + pbm_forward.n[1] = ny; + pbm_forward.n[2] = nz; + pbm_forward.iFFT = 0; + pbm_forward.in = forward_in; + pbm_forward.out = forward_in; + + // // //make ptr plan make_plan(&pbm_forward, &ptr_plan_forward, cluster_id, num_thread); ptr_plan_forward->in = forward_in; ptr_plan_forward->out = forward_in; args_for[1] = (unsigned long)ptr_plan_forward; - //init 3d fft problem - pbm_backward.num_dim = 3; - pbm_backward.n[0] = nx; - pbm_backward.n[1] = ny; - pbm_backward.n[2] = nz; - pbm_backward.iFFT = 1; - pbm_backward.in = forward_in; - pbm_backward.out = forward_in; + // init 3d fft problem + pbm_backward.num_dim = 3; + pbm_backward.n[0] = nx; + pbm_backward.n[1] = ny; + pbm_backward.n[2] = nz; + pbm_backward.iFFT = 1; + pbm_backward.in = forward_in; + pbm_backward.out = forward_in; make_plan(&pbm_backward, &ptr_plan_backward, cluster_id, num_thread); ptr_plan_backward->in = forward_in; ptr_plan_backward->out = forward_in; - args_back[1]=(unsigned long)ptr_plan_backward; -} + args_back[1] = (unsigned long)ptr_plan_backward; +} -template<> -void FFT_DSP::fft3D_forward(std::complex* in, - std::complex* out) const +template <> +void FFT_DSP::fft3D_forward(std::complex* in, std::complex* out) const { - INT num_thread=8; + INT num_thread = 8; thread_id_for = hthread_group_create(cluster_id, num_thread, NULL, 0, 0, NULL); - //create b_id for the barrier + // create b_id for the barrier b_id = hthread_barrier_create(cluster_id); args_for[0] = b_id; hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_for); @@ -74,45 +75,43 @@ void FFT_DSP::fft3D_forward(std::complex* in, hthread_group_destroy(thread_id_for); } -template<> -void FFT_DSP::fft3D_backward(std::complex * in, - std::complex* out) const +template <> +void FFT_DSP::fft3D_backward(std::complex* in, std::complex* out) const { - INT num_thread=8; + INT num_thread = 8; thread_id_for = hthread_group_create(cluster_id, num_thread, NULL, 0, 0, NULL); - //create b_id for the barrier + // create b_id for the barrier b_id = hthread_barrier_create(cluster_id); - args_back[0] =b_id; + args_back[0] = b_id; hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_back); hthread_group_wait(thread_id_for); hthread_barrier_destroy(b_id); hthread_group_destroy(thread_id_for); - } -template<> +template <> void FFT_DSP::cleanFFT() { - if (ptr_plan_forward!=nullptr) + if (ptr_plan_forward != nullptr) { destroy_plan(ptr_plan_forward); - ptr_plan_forward=nullptr; + ptr_plan_forward = nullptr; } - if (ptr_plan_backward!=nullptr) + if (ptr_plan_backward != nullptr) { destroy_plan(ptr_plan_backward); - ptr_plan_backward=nullptr; + ptr_plan_backward = nullptr; } } -template<> +template <> void FFT_DSP::clear() { this->cleanFFT(); hthread_free(forward_in); } -template<> std::complex* -FFT_DSP::get_auxr_3d_data() const +template <> +std::complex* FFT_DSP::get_auxr_3d_data() const { return reinterpret_cast*>(this->forward_in); } @@ -120,4 +119,4 @@ template FFT_DSP::FFT_DSP(); template FFT_DSP::~FFT_DSP(); template FFT_DSP::FFT_DSP(); template FFT_DSP::~FFT_DSP(); -} \ No newline at end of file +} // namespace ModulePW \ No newline at end of file diff --git a/source/module_basis/module_pw/pw_basis_k.cpp b/source/module_basis/module_pw/pw_basis_k.cpp index d9eecf8ae1..08391242ea 100644 --- a/source/module_basis/module_pw/pw_basis_k.cpp +++ b/source/module_basis/module_pw/pw_basis_k.cpp @@ -1,18 +1,18 @@ #include "pw_basis_k.h" -#include - #include "module_base/constants.h" #include "module_base/memory.h" #include "module_base/timer.h" #include "module_parameter/parameter.h" + +#include namespace ModulePW { PW_Basis_K::PW_Basis_K() { - classname="PW_Basis_K"; - this->fft_bundle.setfft(this->device,this->precision); + classname = "PW_Basis_K"; + this->fft_bundle.setfft(this->device, this->precision); } PW_Basis_K::~PW_Basis_K() { @@ -23,13 +23,16 @@ PW_Basis_K::~PW_Basis_K() delete[] igl2ig_k; delete[] gk2; #if defined(__CUDA) || defined(__ROCM) - if (this->device == "gpu") { - if (this->precision == "single") { + if (this->device == "gpu") + { + if (this->precision == "single") + { delmem_sd_op()(this->s_kvec_c); delmem_sd_op()(this->s_gcar); delmem_sd_op()(this->s_gk2); } - else { + else + { delmem_dd_op()(this->d_gcar); delmem_dd_op()(this->d_gk2); } @@ -37,9 +40,11 @@ PW_Basis_K::~PW_Basis_K() delmem_int_op()(this->ig2ixyz_k); delmem_int_op()(this->d_igl2isz_k); } - else { + else + { #endif - if (this->precision == "single") { + if (this->precision == "single") + { delmem_sh_op()(this->s_kvec_c); delmem_sh_op()(this->s_gcar); delmem_sh_op()(this->s_gk2); @@ -50,68 +55,81 @@ PW_Basis_K::~PW_Basis_K() #endif } -void PW_Basis_K:: initparameters( - const bool gamma_only_in, - const double gk_ecut_in, - const int nks_in, //number of k points in this pool - const ModuleBase::Vector3 *kvec_d_in, // Direct coordinates of k points - const int distribution_type_in, - const bool xprime_in -) +void PW_Basis_K::initparameters(const bool gamma_only_in, + const double gk_ecut_in, + const int nks_in, // number of k points in this pool + const ModuleBase::Vector3* kvec_d_in, // Direct coordinates of k points + const int distribution_type_in, + const bool xprime_in) { this->nks = nks_in; - delete[] this->kvec_d; this->kvec_d = new ModuleBase::Vector3 [nks]; - delete[] this->kvec_c; this->kvec_c = new ModuleBase::Vector3 [nks]; + delete[] this->kvec_d; + this->kvec_d = new ModuleBase::Vector3[nks]; + delete[] this->kvec_c; + this->kvec_c = new ModuleBase::Vector3[nks]; double kmaxmod = 0; - for(int ik = 0 ; ik < this->nks ; ++ik) + for (int ik = 0; ik < this->nks; ++ik) { this->kvec_d[ik] = kvec_d_in[ik]; this->kvec_c[ik] = this->kvec_d[ik] * this->G; double kmod = sqrt(this->kvec_c[ik] * this->kvec_c[ik]); - if(kmod > kmaxmod) { kmaxmod = kmod; -} + if (kmod > kmaxmod) + { + kmaxmod = kmod; + } } - this->gk_ecut = gk_ecut_in/this->tpiba2; + this->gk_ecut = gk_ecut_in / this->tpiba2; this->ggecut = pow(sqrt(this->gk_ecut) + kmaxmod, 2); - if(this->ggecut > this->gridecut_lat) + if (this->ggecut > this->gridecut_lat) { this->ggecut = this->gridecut_lat; - this->gk_ecut = pow(sqrt(this->ggecut) - kmaxmod ,2); + this->gk_ecut = pow(sqrt(this->ggecut) - kmaxmod, 2); } this->gamma_only = gamma_only_in; - if(kmaxmod > 0) { this->gamma_only = false; //if it is not the gamma point, we do not use gamma_only -} + if (kmaxmod > 0) + { + this->gamma_only = false; // if it is not the gamma point, we do not use gamma_only + } this->xprime = xprime_in; this->fftny = this->ny; this->fftnx = this->nx; - if (this->gamma_only) + if (this->gamma_only) { - if(this->xprime) { this->fftnx = int(this->nx / 2) + 1; - } else { this->fftny = int(this->ny / 2) + 1; -} + if (this->xprime) + { + this->fftnx = int(this->nx / 2) + 1; + } + else + { + this->fftny = int(this->ny / 2) + 1; + } } this->fftnz = this->nz; this->fftnxy = this->fftnx * this->fftny; this->fftnxyz = this->fftnxy * this->fftnz; this->distribution_type = distribution_type_in; #if defined(__CUDA) || defined(__ROCM) - if (this->device == "gpu") { - if (this->precision == "single") { + if (this->device == "gpu") + { + if (this->precision == "single") + { resmem_sd_op()(this->s_kvec_c, this->nks * 3); - castmem_d2s_h2d_op()(this->s_kvec_c, reinterpret_cast(&this->kvec_c[0][0]), this->nks * 3); + castmem_d2s_h2d_op()(this->s_kvec_c, reinterpret_cast(&this->kvec_c[0][0]), this->nks * 3); } resmem_dd_op()(this->d_kvec_c, this->nks * 3); - syncmem_d2d_h2d_op()(this->d_kvec_c, reinterpret_cast(&this->kvec_c[0][0]), this->nks * 3); + syncmem_d2d_h2d_op()(this->d_kvec_c, reinterpret_cast(&this->kvec_c[0][0]), this->nks * 3); } - else { + else + { #endif - if (this->precision == "single") { + if (this->precision == "single") + { resmem_sh_op()(this->s_kvec_c, this->nks * 3); - castmem_d2s_h2h_op()(this->s_kvec_c, reinterpret_cast(&this->kvec_c[0][0]), this->nks * 3); + castmem_d2s_h2h_op()(this->s_kvec_c, reinterpret_cast(&this->kvec_c[0][0]), this->nks * 3); } - this->d_kvec_c = reinterpret_cast(&this->kvec_c[0][0]); + this->d_kvec_c = reinterpret_cast(&this->kvec_c[0][0]); // There's no need to allocate double pointers while in a CPU environment. #if defined(__CUDA) || defined(__ROCM) } @@ -120,50 +138,59 @@ void PW_Basis_K:: initparameters( void PW_Basis_K::setupIndGk() { - //count npwk + // count npwk this->npwk_max = 0; - delete[] this->npwk; this->npwk = new int [this->nks]; + delete[] this->npwk; + this->npwk = new int[this->nks]; for (int ik = 0; ik < this->nks; ik++) { int ng = 0; - for (int ig = 0; ig < this->npw ; ig++) + for (int ig = 0; ig < this->npw; ig++) { - const double gk2 = this->cal_GplusK_cartesian(ik, ig).norm2(); + const double gk2 = this->cal_GplusK_cartesian(ik, ig).norm2(); if (gk2 <= this->gk_ecut) { ++ng; } } this->npwk[ik] = ng; - ModuleBase::CHECK_WARNING_QUIT((ng == 0), "pw_basis_k.cpp", PARAM.inp.calculation,"Current core has no plane waves! Please reduce the cores."); - if ( this->npwk_max < ng) + ModuleBase::CHECK_WARNING_QUIT((ng == 0), + "pw_basis_k.cpp", + PARAM.inp.calculation, + "Current core has no plane waves! Please reduce the cores."); + if (this->npwk_max < ng) { this->npwk_max = ng; } } - - //get igl2isz_k and igl2ig_k - if(this->npwk_max <= 0) { return;} + // get igl2isz_k and igl2ig_k + if (this->npwk_max <= 0) + { + return; + } - delete[] igl2isz_k; this->igl2isz_k = new int [this->nks * this->npwk_max]; - delete[] igl2ig_k; this->igl2ig_k = new int [this->nks * this->npwk_max]; + delete[] igl2isz_k; + this->igl2isz_k = new int[this->nks * this->npwk_max]; + delete[] igl2ig_k; + this->igl2ig_k = new int[this->nks * this->npwk_max]; for (int ik = 0; ik < this->nks; ik++) { int igl = 0; - for (int ig = 0; ig < this->npw ; ig++) + for (int ig = 0; ig < this->npw; ig++) { - const double gk2 = this->cal_GplusK_cartesian(ik, ig).norm2(); + const double gk2 = this->cal_GplusK_cartesian(ik, ig).norm2(); if (gk2 <= this->gk_ecut) { - this->igl2isz_k[ik*npwk_max + igl] = this->ig2isz[ig]; - this->igl2ig_k[ik*npwk_max + igl] = ig; + this->igl2isz_k[ik * npwk_max + igl] = this->ig2isz[ig]; + this->igl2ig_k[ik * npwk_max + igl] = ig; ++igl; } } } #if defined(__CUDA) || defined(__ROCM) - if (this->device == "gpu") { + if (this->device == "gpu") + { resmem_int_op()(this->d_igl2isz_k, this->npwk_max * this->nks); syncmem_int_h2d_op()(this->d_igl2isz_k, this->igl2isz_k, this->npwk_max * this->nks); } @@ -172,7 +199,7 @@ void PW_Basis_K::setupIndGk() return; } -/// +/// /// distribute plane wave basis and real-space grids to different processors /// set up maps for fft and create arrays for MPI_Alltoall /// set up ffts @@ -185,15 +212,36 @@ void PW_Basis_K::setuptransform() this->getstartgr(); this->setupIndGk(); this->fft_bundle.clear(); - #if defined(__DSP) - this->fft_bundle.setfft("dsp",this->precision); - #else - this->fft_bundle.setfft(this->device,this->precision); - #endif - if(this->xprime){ - this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); - }else{ - this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); +#if defined(__DSP) + this->fft_bundle.setfft("dsp", this->precision); +#else + this->fft_bundle.setfft(this->device, this->precision); +#endif + if (this->xprime) + { + this->fft_bundle.initfft(this->nx, + this->ny, + this->nz, + this->lix, + this->rix, + this->nst, + this->nplane, + this->poolnproc, + this->gamma_only, + this->xprime); + } + else + { + this->fft_bundle.initfft(this->nx, + this->ny, + this->nz, + this->liy, + this->riy, + this->nst, + this->nplane, + this->poolnproc, + this->gamma_only, + this->xprime); } this->fft_bundle.setupFFT(); ModuleBase::timer::tick(this->classname, "setuptransform"); @@ -204,8 +252,10 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h this->erf_ecut = erf_ecut_in; this->erf_height = erf_height_in; this->erf_sigma = erf_sigma_in; - if(this->npwk_max <= 0) { return; -} + if (this->npwk_max <= 0) + { + return; + } delete[] gk2; delete[] gcar; this->gk2 = new double[this->npwk_max * this->nks]; @@ -214,10 +264,10 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h ModuleBase::Memory::record("PW_B_K::gcar", sizeof(ModuleBase::Vector3) * this->npwk_max * this->nks); ModuleBase::Vector3 f; - for(int ik = 0 ; ik < this->nks ; ++ik) + for (int ik = 0; ik < this->nks; ++ik) { ModuleBase::Vector3 kv = this->kvec_d[ik]; - for(int igl = 0 ; igl < this-> npwk[ik] ; ++igl) + for (int igl = 0; igl < this->npwk[ik]; ++igl) { int isz = this->igl2isz_k[ik * npwk_max + igl]; int iz = isz % this->nz; @@ -225,12 +275,18 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h int ixy = this->is2fftixy[is]; int ix = ixy / this->fftny; int iy = ixy % this->fftny; - if (ix >= int(this->nx/2) + 1) { ix -= this->nx; -} - if (iy >= int(this->ny/2) + 1) { iy -= this->ny; -} - if (iz >= int(this->nz/2) + 1) { iz -= this->nz; -} + if (ix >= int(this->nx / 2) + 1) + { + ix -= this->nx; + } + if (iy >= int(this->ny / 2) + 1) + { + iy -= this->ny; + } + if (iz >= int(this->nz / 2) + 1) + { + iz -= this->nz; + } f.x = ix; f.y = iy; f.z = iz; @@ -249,30 +305,42 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h } } #if defined(__CUDA) || defined(__ROCM) - if (this->device == "gpu") { - if (this->precision == "single") { + if (this->device == "gpu") + { + if (this->precision == "single") + { resmem_sd_op()(this->s_gk2, this->npwk_max * this->nks); resmem_sd_op()(this->s_gcar, this->npwk_max * this->nks * 3); castmem_d2s_h2d_op()(this->s_gk2, this->gk2, this->npwk_max * this->nks); - castmem_d2s_h2d_op()(this->s_gcar, reinterpret_cast(&this->gcar[0][0]), this->npwk_max * this->nks * 3); + castmem_d2s_h2d_op()(this->s_gcar, + reinterpret_cast(&this->gcar[0][0]), + this->npwk_max * this->nks * 3); } - else { + else + { resmem_dd_op()(this->d_gk2, this->npwk_max * this->nks); resmem_dd_op()(this->d_gcar, this->npwk_max * this->nks * 3); syncmem_d2d_h2d_op()(this->d_gk2, this->gk2, this->npwk_max * this->nks); - syncmem_d2d_h2d_op()(this->d_gcar, reinterpret_cast(&this->gcar[0][0]), this->npwk_max * this->nks * 3); + syncmem_d2d_h2d_op()(this->d_gcar, + reinterpret_cast(&this->gcar[0][0]), + this->npwk_max * this->nks * 3); } } - else { + else + { #endif - if (this->precision == "single") { + if (this->precision == "single") + { resmem_sh_op()(this->s_gk2, this->npwk_max * this->nks, "PW_B_K::s_gk2"); resmem_sh_op()(this->s_gcar, this->npwk_max * this->nks * 3, "PW_B_K::s_gcar"); castmem_d2s_h2h_op()(this->s_gk2, this->gk2, this->npwk_max * this->nks); - castmem_d2s_h2h_op()(this->s_gcar, reinterpret_cast(&this->gcar[0][0]), this->npwk_max * this->nks * 3); + castmem_d2s_h2h_op()(this->s_gcar, + reinterpret_cast(&this->gcar[0][0]), + this->npwk_max * this->nks * 3); } - else { - this->d_gcar = reinterpret_cast(&this->gcar[0][0]); + else + { + this->d_gcar = reinterpret_cast(&this->gcar[0][0]); this->d_gk2 = this->gk2; } // There's no need to allocate double pointers while in a CPU environment. @@ -281,18 +349,25 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h #endif } -ModuleBase::Vector3 PW_Basis_K:: cal_GplusK_cartesian(const int ik, const int ig) const { +ModuleBase::Vector3 PW_Basis_K::cal_GplusK_cartesian(const int ik, const int ig) const +{ int isz = this->ig2isz[ig]; int iz = isz % this->nz; int is = isz / this->nz; int ix = this->is2fftixy[is] / this->fftny; int iy = this->is2fftixy[is] % this->fftny; - if (ix >= int(this->nx/2) + 1) { ix -= this->nx; -} - if (iy >= int(this->ny/2) + 1) { iy -= this->ny; -} - if (iz >= int(this->nz/2) + 1) { iz -= this->nz; -} + if (ix >= int(this->nx / 2) + 1) + { + ix -= this->nx; + } + if (iy >= int(this->ny / 2) + 1) + { + iy -= this->ny; + } + if (iz >= int(this->nz / 2) + 1) + { + iz -= this->nz; + } ModuleBase::Vector3 f; f.x = ix; f.y = iy; @@ -321,36 +396,34 @@ ModuleBase::Vector3 PW_Basis_K::getgdirect(const int ik, const int igl) return f; } - ModuleBase::Vector3 PW_Basis_K::getgpluskcar(const int ik, const int igl) const { - return this->gcar[ik * this->npwk_max + igl]+this->kvec_c[ik]; + return this->gcar[ik * this->npwk_max + igl] + this->kvec_c[ik]; } int& PW_Basis_K::getigl2isz(const int ik, const int igl) const { - return this->igl2isz_k[ik*this->npwk_max + igl]; + return this->igl2isz_k[ik * this->npwk_max + igl]; } int& PW_Basis_K::getigl2ig(const int ik, const int igl) const { - return this->igl2ig_k[ik*this->npwk_max + igl]; + return this->igl2ig_k[ik * this->npwk_max + igl]; } - void PW_Basis_K::get_ig2ixyz_k() { - #if not defined(__DSP) +#if not defined(__DSP) if (this->device != "gpu") { - //only GPU need to get ig2ixyz_k + // only GPU need to get ig2ixyz_k return; } - #endif +#endif ig2ixyz_k_cpu.resize(this->npwk_max * this->nks); ModuleBase::Memory::record("PW_B_K::ig2ixyz", sizeof(int) * this->npwk_max * this->nks); - assert(gamma_only == false); //We only finish non-gamma_only fft on GPU temperarily. - for(int ik = 0; ik < this->nks; ++ik) + assert(gamma_only == false); // We only finish non-gamma_only fft on GPU temperarily. + for (int ik = 0; ik < this->nks; ++ik) { - for(int igl = 0; igl < this->npwk[ik]; ++igl) + for (int igl = 0; igl < this->npwk[ik]; ++igl) { int isz = this->igl2isz_k[igl + ik * npwk_max]; int iz = isz % this->nz; @@ -370,14 +443,16 @@ std::vector PW_Basis_K::get_ig2ix(const int ik) const std::vector ig_to_ix; ig_to_ix.resize(npwk[ik]); - for(int ig = 0; ig < npwk[ik]; ig++) + for (int ig = 0; ig < npwk[ik]; ig++) { int isz = this->igl2isz_k[ig + ik * npwk_max]; int is = isz / this->nz; int ixy = this->is2fftixy[is]; int ix = ixy / this->ny; - if (ix < (nx / 2) + 1) { ix += nx; -} + if (ix < (nx / 2) + 1) + { + ix += nx; + } ig_to_ix[ig] = ix; } return ig_to_ix; @@ -388,14 +463,16 @@ std::vector PW_Basis_K::get_ig2iy(const int ik) const std::vector ig_to_iy; ig_to_iy.resize(npwk[ik]); - for(int ig = 0; ig < npwk[ik]; ig++) + for (int ig = 0; ig < npwk[ik]; ig++) { int isz = this->igl2isz_k[ig + ik * npwk_max]; int is = isz / this->nz; int ixy = this->is2fftixy[is]; int iy = ixy % this->ny; - if (iy < (ny / 2) + 1) { iy += ny; -} + if (iy < (ny / 2) + 1) + { + iy += ny; + } ig_to_iy[ig] = iy; } return ig_to_iy; @@ -406,42 +483,50 @@ std::vector PW_Basis_K::get_ig2iz(const int ik) const std::vector ig_to_iz; ig_to_iz.resize(npwk[ik]); - for(int ig = 0; ig < npwk[ik]; ig++) + for (int ig = 0; ig < npwk[ik]; ig++) { int isz = this->igl2isz_k[ig + ik * npwk_max]; int iz = isz % this->nz; - if (iz < (nz / 2) + 1) { iz += nz; -} + if (iz < (nz / 2) + 1) + { + iz += nz; + } ig_to_iz[ig] = iz; } return ig_to_iz; } template <> -float * PW_Basis_K::get_kvec_c_data() const { +float* PW_Basis_K::get_kvec_c_data() const +{ return this->s_kvec_c; } template <> -double * PW_Basis_K::get_kvec_c_data() const { +double* PW_Basis_K::get_kvec_c_data() const +{ return this->d_kvec_c; } template <> -float * PW_Basis_K::get_gcar_data() const { +float* PW_Basis_K::get_gcar_data() const +{ return this->s_gcar; } template <> -double * PW_Basis_K::get_gcar_data() const { +double* PW_Basis_K::get_gcar_data() const +{ return this->d_gcar; } template <> -float * PW_Basis_K::get_gk2_data() const { +float* PW_Basis_K::get_gk2_data() const +{ return this->s_gk2; } template <> -double * PW_Basis_K::get_gk2_data() const { +double* PW_Basis_K::get_gk2_data() const +{ return this->d_gk2; } -} // namespace ModulePW \ No newline at end of file +} // namespace ModulePW \ No newline at end of file diff --git a/source/module_basis/module_pw/pw_transform_k_dsp.cpp b/source/module_basis/module_pw/pw_transform_k_dsp.cpp index 59485d599a..052d975b21 100644 --- a/source/module_basis/module_pw/pw_transform_k_dsp.cpp +++ b/source/module_basis/module_pw/pw_transform_k_dsp.cpp @@ -2,140 +2,130 @@ #include "module_basis/module_pw/kernels/pw_op.h" #include "pw_basis_k.h" #include "pw_gatherscatter.h" + #include #include - +#if defined (__DSP) namespace ModulePW { - template - void PW_Basis_K::real2recip_dsp(const std::complex* in, - std::complex* out, - const int ik, - const bool add , - const FPTYPE factor ) const - { - const base_device::DEVICE_CPU* ctx; - const base_device::DEVICE_GPU* gpux; - assert(this->gamma_only==false); - auto* auxr = this->fft_bundle.get_auxr_3d_data(); - - const int startig = ik * this->npwk_max; - const int npw_k = this->npwk[ik]; - // copy the in into the auxr with complex - memcpy(auxr,in,this->nrxx*2*8); +template +void PW_Basis_K::real2recip_dsp(const std::complex* in, + std::complex* out, + const int ik, + const bool add, + const FPTYPE factor) const +{ + const base_device::DEVICE_CPU* ctx; + const base_device::DEVICE_GPU* gpux; + assert(this->gamma_only == false); + auto* auxr = this->fft_bundle.get_auxr_3d_data(); - // 3d fft - this->fft_bundle.fft3D_forward(gpux, - auxr, - auxr); + const int startig = ik * this->npwk_max; + const int npw_k = this->npwk[ik]; + // copy the in into the auxr with complex + memcpy(auxr, in, this->nrxx * 2 * 8); - // copy the result from the auxr to the out ,while consider the add - set_real_to_recip_output_op()(ctx, - npw_k, - this->nxyz, - add, - factor, - this->ig2ixyz_k_cpu.data() + startig, - auxr, - out); - } - template - void PW_Basis_K::recip2real_dsp(const std::complex* in, - std::complex* out, - const int ik, - const bool add , - const FPTYPE factor ) const - { - assert(this->gamma_only == false); - const base_device::DEVICE_CPU* ctx; - const base_device::DEVICE_GPU* gpux; - // memset the auxr of 0 in the auxr,here the len of the auxr is nxyz - auto * auxr = this->fft_bundle.get_auxr_3d_data(); - memset(auxr,0,this->nxyz*2*8); + // 3d fft + this->fft_bundle.fft3D_forward(gpux, auxr, auxr); - const int startig = ik * this->npwk_max; - const int npw_k = this->npwk[ik]; - //copy the mapping form the type of stick to the 3dfft - set_3d_fft_box_op() - ( - ctx,npw_k,this->ig2ixyz_k_cpu.data()+startig,in,auxr - ); - // use 3d fft backward - this->fft_bundle.fft3D_backward(gpux,auxr,auxr); - if(add) - { - const int one =1; - const std::complex factor1=std::complex(factor,0); - zaxpy_(&nrxx,&factor1,auxr,&one,out,&one); - } - else - { - memcpy(out,auxr,nrxx*2*8); - } + // copy the result from the auxr to the out ,while consider the add + set_real_to_recip_output_op()(ctx, + npw_k, + this->nxyz, + add, + factor, + this->ig2ixyz_k_cpu.data() + startig, + auxr, + out); +} +template +void PW_Basis_K::recip2real_dsp(const std::complex* in, + std::complex* out, + const int ik, + const bool add, + const FPTYPE factor) const +{ + assert(this->gamma_only == false); + const base_device::DEVICE_CPU* ctx; + const base_device::DEVICE_GPU* gpux; + // memset the auxr of 0 in the auxr,here the len of the auxr is nxyz + auto* auxr = this->fft_bundle.get_auxr_3d_data(); + memset(auxr, 0, this->nxyz * 2 * 8); + + const int startig = ik * this->npwk_max; + const int npw_k = this->npwk[ik]; + // copy the mapping form the type of stick to the 3dfft + set_3d_fft_box_op()(ctx, npw_k, this->ig2ixyz_k_cpu.data() + startig, in, auxr); + // use 3d fft backward + this->fft_bundle.fft3D_backward(gpux, auxr, auxr); + if (add) + { + const int one = 1; + const std::complex factor1 = std::complex(factor, 0); + zaxpy_(&nrxx, &factor1, auxr, &one, out, &one); } - template <> - void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, - const int ik, - const int size, - const std::complex* input, - const float* input1, - std::complex* output, - const bool add , - const float factor ) const + else { - + memcpy(out, auxr, nrxx * 2 * 8); } +} +template <> +void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, + const int ik, + const int size, + const std::complex* input, + const float* input1, + std::complex* output, + const bool add, + const float factor) const +{ +} - template <> - void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, - const int ik, - const int size, - const std::complex* input, - const double* input1, - std::complex* output, - const bool add , - const double factor ) const - { - ModuleBase::timer::tick(this->classname,"convolution"); +template <> +void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, + const int ik, + const int size, + const std::complex* input, + const double* input1, + std::complex* output, + const bool add, + const double factor) const +{ + ModuleBase::timer::tick(this->classname, "convolution"); - assert(this->gamma_only == false); - const base_device::DEVICE_GPU* gpux; - // memset the auxr of 0 in the auxr,here the len of the auxr is nxyz - auto * auxr = this->fft_bundle.get_auxr_3d_data(); - memset(auxr,0,this->nxyz*2*8); - const int startig = ik * this->npwk_max; - const int npw_k = this->npwk[ik]; - - //copy the mapping form the type of stick to the 3dfft - set_3d_fft_box_op() - ( - ctx,npw_k,this->ig2ixyz_k_cpu.data()+startig,input,auxr - ); + assert(this->gamma_only == false); + const base_device::DEVICE_GPU* gpux; + // memset the auxr of 0 in the auxr,here the len of the auxr is nxyz + auto* auxr = this->fft_bundle.get_auxr_3d_data(); + memset(auxr, 0, this->nxyz * 2 * 8); + const int startig = ik * this->npwk_max; + const int npw_k = this->npwk[ik]; - // use 3d fft backward - this->fft_bundle.fft3D_backward(gpux,auxr,auxr); + // copy the mapping form the type of stick to the 3dfft + set_3d_fft_box_op()(ctx, npw_k, this->ig2ixyz_k_cpu.data() + startig, input, auxr); - for (int ir=0;irfft_bundle.fft3D_backward(gpux, auxr, auxr); - // 3d fft - this->fft_bundle.fft3D_forward(gpux, - auxr, - auxr); - // copy the result from the auxr to the out ,while consider the add - set_real_to_recip_output_op()(ctx, - npw_k, - this->nxyz, - add, - factor, - this->ig2ixyz_k_cpu.data() + startig, - auxr, - output); - ModuleBase::timer::tick(this->classname,"convolution"); + for (int ir = 0; ir < size; ir++) + { + auxr[ir] *= input1[ir]; } - + + // 3d fft + this->fft_bundle.fft3D_forward(gpux, auxr, auxr); + // copy the result from the auxr to the out ,while consider the add + set_real_to_recip_output_op()(ctx, + npw_k, + this->nxyz, + add, + factor, + this->ig2ixyz_k_cpu.data() + startig, + auxr, + output); + ModuleBase::timer::tick(this->classname, "convolution"); +} + // template void PW_Basis_K::real2recip_dsp(const std::complex* in, // std::complex* out, // const int ik, @@ -148,13 +138,14 @@ namespace ModulePW // const float factor) const; // in:(nz, ns) ; out(nplane,nx*ny) template void PW_Basis_K::real2recip_dsp(const std::complex* in, - std::complex* out, - const int ik, - const bool add, - const double factor) const; // in:(nplane,nx*ny) ; out(nz, ns) + std::complex* out, + const int ik, + const bool add, + const double factor) const; // in:(nplane,nx*ny) ; out(nz, ns) template void PW_Basis_K::recip2real_dsp(const std::complex* in, - std::complex* out, - const int ik, - const bool add, - const double factor) const; -} + std::complex* out, + const int ik, + const bool add, + const double factor) const; +} // namespace ModulePW +#endif diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index 0961675029..74d3904a65 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -81,7 +81,7 @@ ESolver_KS_PW::ESolver_KS_PW() #endif #ifdef __DSP std::cout << " ** Initializing DSP Hardware..." << std::endl; - dspInitHandle(GlobalV::MY_RANK); + mtfunc::dspInitHandle(GlobalV::MY_RANK); #endif } @@ -109,7 +109,7 @@ ESolver_KS_PW::~ESolver_KS_PW() } #ifdef __DSP std::cout << " ** Closing DSP Hardware..." << std::endl; - dspDestoryHandle(GlobalV::MY_RANK); + mtfunc::dspDestoryHandle(GlobalV::MY_RANK); #endif if(PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") { diff --git a/source/module_hsolver/diago_dav_subspace.cpp b/source/module_hsolver/diago_dav_subspace.cpp index 177e68847c..2d1b747de4 100644 --- a/source/module_hsolver/diago_dav_subspace.cpp +++ b/source/module_hsolver/diago_dav_subspace.cpp @@ -454,7 +454,7 @@ void Diago_DavSubspace::cal_elem(const int& dim, { #ifdef __DSP // Only on dsp hardware need an extra space to reduce data - dsp_dav_subspace_reduce(hcc, scc, nbase, this->nbase_x, this->notconv, this->diag_comm.comm); + mtfunc::dsp_dav_subspace_reduce(hcc, scc, nbase, this->nbase_x, this->notconv, this->diag_comm.comm); #else auto* swap = new T[notconv * this->nbase_x]; From 7cc54697c5ca4b18f6a498c583a98c2e69d29f93 Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Tue, 25 Feb 2025 10:28:28 +0800 Subject: [PATCH 12/14] remove mutable --- source/module_basis/module_pw/module_fft/fft_dsp.h | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/source/module_basis/module_pw/module_fft/fft_dsp.h b/source/module_basis/module_pw/module_fft/fft_dsp.h index 6cadef975f..0cdfe84fc6 100644 --- a/source/module_basis/module_pw/module_fft/fft_dsp.h +++ b/source/module_basis/module_pw/module_fft/fft_dsp.h @@ -65,17 +65,16 @@ class FFT_DSP : public FFT_BASE void fft3D_backward(std::complex* in, std::complex* out) const override; public: - int nxyz; + int nxyz=0; INT cluster_id=0; - mutable INT b_id; + mutable INT b_id=0; mutable INT thread_id_for=0; PLAN* ptr_plan_forward=nullptr; PLAN* ptr_plan_backward=nullptr; mutable unsigned long args_for[2]; mutable unsigned long args_back[2]; - mutable E * forward_in; - mutable E * convert2; - std::complex* c_auxr_3d = nullptr; // fft space + E * forward_in=nullptr; + std::complex* c_auxr_3d = nullptr; // fft space std::complex* z_auxr_3d = nullptr; // fft space }; From 8c181701e8c76f07f8809bbb3299ccbaf4b64b7a Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Tue, 25 Feb 2025 20:24:14 +0800 Subject: [PATCH 13/14] fix fft_dsp --- source/module_basis/module_pw/module_fft/fft_dsp.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/source/module_basis/module_pw/module_fft/fft_dsp.cpp b/source/module_basis/module_pw/module_fft/fft_dsp.cpp index 3428af345f..fac287fc49 100644 --- a/source/module_basis/module_pw/module_fft/fft_dsp.cpp +++ b/source/module_basis/module_pw/module_fft/fft_dsp.cpp @@ -24,14 +24,14 @@ void FFT_DSP::setupFFT() PLAN* ptr_plan_forward; PLAN* ptr_plan_backward; INT num_thread = 8; - INT size; + INT size=0; hthread_dat_load(cluster_id, FFT_DAT_DIR); // compute the size of and malloc thread size = nx * ny * nz * 2 * sizeof(E); forward_in = (E*)hthread_malloc((int)cluster_id, size, HT_MEM_RW); - // // //init 3d fft problem + //init 3d fft problem pbm_forward.num_dim = 3; pbm_forward.n[0] = nx; pbm_forward.n[1] = ny; @@ -40,7 +40,7 @@ void FFT_DSP::setupFFT() pbm_forward.in = forward_in; pbm_forward.out = forward_in; - // // //make ptr plan + //make ptr plan make_plan(&pbm_forward, &ptr_plan_forward, cluster_id, num_thread); ptr_plan_forward->in = forward_in; ptr_plan_forward->out = forward_in; From b89678062cc0b899c9116d88de33b9bb07876cbd Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Tue, 25 Feb 2025 21:48:54 +0800 Subject: [PATCH 14/14] add the convolution and allocate or destroy the b_id --- .../module_pw/module_fft/fft_base.h | 148 ++++++++---------- .../module_pw/module_fft/fft_bundle.cpp | 14 ++ .../module_pw/module_fft/fft_bundle.h | 1 + .../module_pw/module_fft/fft_dsp.cpp | 33 ++-- .../module_pw/module_fft/fft_dsp_float.cpp | 5 + .../module_pw/pw_transform_k_dsp.cpp | 9 +- .../hamilt_pwdft/operator_pw/veff_pw.cpp | 14 +- 7 files changed, 124 insertions(+), 100 deletions(-) diff --git a/source/module_basis/module_pw/module_fft/fft_base.h b/source/module_basis/module_pw/module_fft/fft_base.h index b64b6f4e00..b7c63fc9b1 100644 --- a/source/module_basis/module_pw/module_fft/fft_base.h +++ b/source/module_basis/module_pw/module_fft/fft_base.h @@ -7,166 +7,150 @@ namespace ModulePW template class FFT_BASE { -public: + public: + FFT_BASE() {}; + virtual ~FFT_BASE() {}; - FFT_BASE(){}; - virtual ~FFT_BASE(){}; - /** * @brief Initialize the fft parameters As virtual function. - * + * * The function is used to initialize the fft parameters. */ - virtual __attribute__((weak)) - void initfft(int nx_in, - int ny_in, - int nz_in, - int lixy_in, - int rixy_in, - int ns_in, - int nplane_in, - int nproc_in, - bool gamma_only_in, - bool xprime_in = true); - - virtual __attribute__((weak)) - void initfft(int nx_in, - int ny_in, - int nz_in); + virtual __attribute__((weak)) void initfft(int nx_in, + int ny_in, + int nz_in, + int lixy_in, + int rixy_in, + int ns_in, + int nplane_in, + int nproc_in, + bool gamma_only_in, + bool xprime_in = true); + + virtual __attribute__((weak)) void initfft(int nx_in, int ny_in, int nz_in); /** * @brief Setup the fft Plan and data As pure virtual function. - * + * * The function is set as pure virtual function.In order to * override the function in the derived class.In the derived * class, the function is used to setup the fft Plan and data. */ - virtual void setupFFT()=0; + virtual void setupFFT() = 0; /** * @brief Clean the fft Plan As pure virtual function. - * + * * The function is set as pure virtual function.In order to * override the function in the derived class.In the derived * class, the function is used to clean the fft Plan. */ - virtual void cleanFFT()=0; - + virtual void cleanFFT() = 0; + /** * @brief Clear the fft data As pure virtual function. - * + * * The function is set as pure virtual function.In order to * override the function in the derived class.In the derived * class, the function is used to clear the fft data. */ - virtual void clear()=0; - + virtual void clear() = 0; + + virtual void resource_handler(const int flag) const {}; /** * @brief Get the real space data in cpu-like fft - * + * * The function is used to get the real space data.While the * FFT_BASE is an abstract class,the function will be override, - * The attribute weak is used to avoid define the function. + * The attribute weak is used to avoid define the function. */ - virtual __attribute__((weak)) - FPTYPE* get_rspace_data() const; + virtual __attribute__((weak)) FPTYPE* get_rspace_data() const; - virtual __attribute__((weak)) - std::complex* get_auxr_data() const; + virtual __attribute__((weak)) std::complex* get_auxr_data() const; - virtual __attribute__((weak)) - std::complex* get_auxg_data() const; + virtual __attribute__((weak)) std::complex* get_auxg_data() const; /** * @brief Get the auxiliary real space data in 3D - * + * * The function is used to get the auxiliary real space data in 3D. * While the FFT_BASE is an abstract class,the function will be override, * The attribute weak is used to avoid define the function. */ - virtual __attribute__((weak)) - std::complex* get_auxr_3d_data() const; + virtual __attribute__((weak)) std::complex* get_auxr_3d_data() const; - //forward fft in x-y direction + // forward fft in x-y direction /** * @brief Forward FFT in x-y direction * @param in input data * @param out output data - * + * * This function performs the forward FFT in the x-y direction. * It involves two axes, x and y. The FFT is applied multiple times - * along the left and right boundaries in the primary direction(which is - * determined by the xprime flag).Notably, the Y axis operates in + * along the left and right boundaries in the primary direction(which is + * determined by the xprime flag).Notably, the Y axis operates in * "many-many-FFT" mode. */ - virtual __attribute__((weak)) - void fftxyfor(std::complex* in, - std::complex* out) const; + virtual __attribute__((weak)) void fftxyfor(std::complex* in, + std::complex* out) const; - virtual __attribute__((weak)) - void fftxybac(std::complex* in, - std::complex* out) const; + virtual __attribute__((weak)) void fftxybac(std::complex* in, + std::complex* out) const; /** * @brief Forward FFT in z direction * @param in input data * @param out output data - * + * * This function performs the forward FFT in the z direction. * It involves only one axis, z. The FFT is applied only once. * Notably, the Z axis operates in many FFT with nz*ns. */ - virtual __attribute__((weak)) - void fftzfor(std::complex* in, - std::complex* out) const; - - virtual __attribute__((weak)) - void fftzbac(std::complex* in, - std::complex* out) const; + virtual __attribute__((weak)) void fftzfor(std::complex* in, + std::complex* out) const; + + virtual __attribute__((weak)) void fftzbac(std::complex* in, + std::complex* out) const; /** * @brief Forward FFT in x-y direction with real to complex * @param in input data, real type * @param out output data, complex type - * - * This function performs the forward FFT in the x-y direction + * + * This function performs the forward FFT in the x-y direction * with real to complex.There is no difference between fftxyfor. */ - virtual __attribute__((weak)) - void fftxyr2c(FPTYPE* in, - std::complex* out) const; - - virtual __attribute__((weak)) - void fftxyc2r(std::complex* in, - FPTYPE* out) const; - + virtual __attribute__((weak)) void fftxyr2c(FPTYPE* in, + std::complex* out) const; + + virtual __attribute__((weak)) void fftxyc2r(std::complex* in, + FPTYPE* out) const; + /** * @brief Forward FFT in 3D * @param in input data * @param out output data - * + * * This function performs the forward FFT for gpu-like fft. * It involves three axes, x, y, and z. The FFT is applied multiple times * for fft3D_forward. */ - virtual __attribute__((weak)) - void fft3D_forward(std::complex* in, - std::complex* out) const; - - virtual __attribute__((weak)) - void fft3D_backward(std::complex* in, - std::complex* out) const; - -protected: - int nx=0; - int ny=0; - int nz=0; + virtual __attribute__((weak)) void fft3D_forward(std::complex* in, + std::complex* out) const; + + virtual __attribute__((weak)) void fft3D_backward(std::complex* in, + std::complex* out) const; + + protected: + int nx = 0; + int ny = 0; + int nz = 0; }; template FFT_BASE::FFT_BASE(); template FFT_BASE::FFT_BASE(); template FFT_BASE::~FFT_BASE(); template FFT_BASE::~FFT_BASE(); -} +} // namespace ModulePW #endif // FFT_BASE_H diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index 83c2d02466..7289e8ab02 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -146,6 +146,20 @@ void FFT_Bundle::clear() } } +void FFT_Bundle::resource_handler(const int flag) const +{ + if (this->device=="dsp") + { + if (double_flag) + { + fft_double->resource_handler(flag); + } + if (float_flag) + { + fft_float->resource_handler(flag); + } + } +} template <> void FFT_Bundle::fftxyfor(std::complex* in, std::complex* out) const { diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.h b/source/module_basis/module_pw/module_fft/fft_bundle.h index 58851e139d..1982a79a0c 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.h +++ b/source/module_basis/module_pw/module_fft/fft_bundle.h @@ -81,6 +81,7 @@ class FFT_Bundle void clear(); + void resource_handler(const int flag) const; /** * @brief Get the real space data. * @return FPTYPE* the real space data. diff --git a/source/module_basis/module_pw/module_fft/fft_dsp.cpp b/source/module_basis/module_pw/module_fft/fft_dsp.cpp index fac287fc49..0247ac84a7 100644 --- a/source/module_basis/module_pw/module_fft/fft_dsp.cpp +++ b/source/module_basis/module_pw/module_fft/fft_dsp.cpp @@ -60,33 +60,36 @@ void FFT_DSP::setupFFT() ptr_plan_backward->out = forward_in; args_back[1] = (unsigned long)ptr_plan_backward; } - +template <> +void FFT_DSP::resource_handler(const int flag) const +{ + if (flag==0) + { + hthread_barrier_destroy(b_id); + hthread_group_destroy(thread_id_for); + } + else if (flag==1) + { + INT num_thread = 8; + thread_id_for = hthread_group_create(cluster_id, num_thread, NULL, 0, 0, NULL); + // create b_id for the barrier + b_id = hthread_barrier_create(cluster_id); + args_for[0] = b_id; + args_back[0] = b_id; + } +} template <> void FFT_DSP::fft3D_forward(std::complex* in, std::complex* out) const { - INT num_thread = 8; - thread_id_for = hthread_group_create(cluster_id, num_thread, NULL, 0, 0, NULL); - // create b_id for the barrier - b_id = hthread_barrier_create(cluster_id); - args_for[0] = b_id; hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_for); hthread_group_wait(thread_id_for); - hthread_barrier_destroy(b_id); - hthread_group_destroy(thread_id_for); } template <> void FFT_DSP::fft3D_backward(std::complex* in, std::complex* out) const { - INT num_thread = 8; - thread_id_for = hthread_group_create(cluster_id, num_thread, NULL, 0, 0, NULL); - // create b_id for the barrier - b_id = hthread_barrier_create(cluster_id); - args_back[0] = b_id; hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_back); hthread_group_wait(thread_id_for); - hthread_barrier_destroy(b_id); - hthread_group_destroy(thread_id_for); } template <> void FFT_DSP::cleanFFT() diff --git a/source/module_basis/module_pw/module_fft/fft_dsp_float.cpp b/source/module_basis/module_pw/module_fft/fft_dsp_float.cpp index 2a17bacd02..3c11cfc81f 100644 --- a/source/module_basis/module_pw/module_fft/fft_dsp_float.cpp +++ b/source/module_basis/module_pw/module_fft/fft_dsp_float.cpp @@ -16,5 +16,10 @@ template<> void FFT_DSP::cleanFFT() { +} +template<> +void FFT_DSP::resource_handler(const int flag) const +{ + } } \ No newline at end of file diff --git a/source/module_basis/module_pw/pw_transform_k_dsp.cpp b/source/module_basis/module_pw/pw_transform_k_dsp.cpp index 052d975b21..b292e25f0a 100644 --- a/source/module_basis/module_pw/pw_transform_k_dsp.cpp +++ b/source/module_basis/module_pw/pw_transform_k_dsp.cpp @@ -26,8 +26,11 @@ void PW_Basis_K::real2recip_dsp(const std::complex* in, memcpy(auxr, in, this->nrxx * 2 * 8); // 3d fft - this->fft_bundle.fft3D_forward(gpux, auxr, auxr); - + this->fft_bundle.resource_handler(1); + this->fft_bundle.fft3D_forward(gpux, + auxr, + auxr); + this->fft_bundle.resource_handler(0); // copy the result from the auxr to the out ,while consider the add set_real_to_recip_output_op()(ctx, npw_k, @@ -57,7 +60,9 @@ void PW_Basis_K::recip2real_dsp(const std::complex* in, // copy the mapping form the type of stick to the 3dfft set_3d_fft_box_op()(ctx, npw_k, this->ig2ixyz_k_cpu.data() + startig, in, auxr); // use 3d fft backward + this->fft_bundle.resource_handler(1); this->fft_bundle.fft3D_backward(gpux, auxr, auxr); + this->fft_bundle.resource_handler(0); if (add) { const int one = 1; diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp index 6bff6b2dc0..54e1a052be 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp @@ -53,7 +53,9 @@ void Veff>::act( int max_npw = nbasis / npol; const int current_spin = this->isk[this->ik]; - +#ifdef __DSP + wfcpw->fft_bundle.resource_handler(1); +#endif // T *porter = new T[wfcpw->nmaxgr]; for (int ib = 0; ib < nbands; ib += npol) { @@ -75,6 +77,13 @@ void Veff>::act( } // wfcpw->real2recip(porter, tmhpsi, this->ik, true); wfcpw->real_to_recip(this->ctx, this->porter, tmhpsi, this->ik, true); + // wfcpw->convolution(this->ctx, + // this->ik, + // this->veff_col, + // tmpsi_in, + // this->veff+current_spin, + // tmhpsi, + // true); } else { @@ -111,6 +120,9 @@ void Veff>::act( tmhpsi += max_npw * npol; tmpsi_in += max_npw * npol; } +#ifdef __DSP + wfcpw->fft_bundle.resource_handler(0); +#endif ModuleBase::timer::tick("Operator", "VeffPW"); }