From ccb2a1f92d29ba2d3cda94167df788229b0b5e22 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Thu, 21 Nov 2024 19:21:24 +0800 Subject: [PATCH 1/3] update the duc compile --- source/module_base/module_device/rocm/memory_op.hip.cu | 2 ++ .../module_basis/module_pw/module_fft/fft_bundle.cpp | 6 +++--- source/module_basis/module_pw/module_fft/fft_rcom.h | 10 +++++----- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/source/module_base/module_device/rocm/memory_op.hip.cu b/source/module_base/module_device/rocm/memory_op.hip.cu index 678acbc048..1909cfb771 100644 --- a/source/module_base/module_device/rocm/memory_op.hip.cu +++ b/source/module_base/module_device/rocm/memory_op.hip.cu @@ -219,6 +219,8 @@ template struct cast_memory_op, std::complex, base_device::DEVICE_GPU, base_device::DEVICE_GPU>; +template struct cast_memory_op, float, base_device::DEVICE_GPU, base_device::DEVICE_GPU>; +template struct cast_memory_op, double, base_device::DEVICE_GPU, base_device::DEVICE_GPU>; template struct cast_memory_op; template struct cast_memory_op; template struct cast_memory_op; diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index 31d29c32f9..98be60dff7 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -89,14 +89,14 @@ void FFT_Bundle::initfft(int nx_in, if (device=="gpu") { #if defined(__ROCM) - fft_float = new FFT_RCOM(); + fft_float = make_unique>(); fft_float->initfft(nx_in,ny_in,nz_in); - fft_double = new FFT_RCOM(); + fft_double = make_unique>(); fft_double->initfft(nx_in,ny_in,nz_in); #elif defined(__CUDA) fft_float = make_unique>(); fft_float->initfft(nx_in,ny_in,nz_in); - fft_double = make_unique>(); + fft_double = make_unique>(); fft_double->initfft(nx_in,ny_in,nz_in); #endif } diff --git a/source/module_basis/module_pw/module_fft/fft_rcom.h b/source/module_basis/module_pw/module_fft/fft_rcom.h index 64ad13329e..09035b50ba 100644 --- a/source/module_basis/module_pw/module_fft/fft_rcom.h +++ b/source/module_basis/module_pw/module_fft/fft_rcom.h @@ -6,11 +6,11 @@ namespace ModulePW { template -class FFT_ROCM : public FFT_BASE +class FFT_RCOM : public FFT_BASE { public: - FFT_ROCM(){}; - ~FFT_ROCM(){}; + FFT_RCOM(){}; + ~FFT_RCOM(){}; void setupFFT() override; @@ -58,8 +58,8 @@ class FFT_ROCM : public FFT_BASE }; template FFT_RCOM::FFT_RCOM(); -template FFT_ROCM::~FFT_ROCM(); +template FFT_RCOM::~FFT_RCOM(); template FFT_RCOM::FFT_RCOM(); -template FFT_ROCM::~FFT_ROCM(); +template FFT_RCOM::~FFT_RCOM(); }// namespace ModulePW #endif From eee8c7bfab81cceb80b28c4342c64e38e32c67ed Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Thu, 21 Nov 2024 19:52:06 +0800 Subject: [PATCH 2/3] update the FFT_CUDA in the fft_bundle.cpp --- source/module_basis/module_pw/module_fft/fft_bundle.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index 98be60dff7..b593f0999a 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -96,7 +96,7 @@ void FFT_Bundle::initfft(int nx_in, #elif defined(__CUDA) fft_float = make_unique>(); fft_float->initfft(nx_in,ny_in,nz_in); - fft_double = make_unique>(); + fft_double = make_unique>(); fft_double->initfft(nx_in,ny_in,nz_in); #endif } From 4e12e56de0a2dcb21cb22a37a76b908184e44019 Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Mon, 25 Nov 2024 10:28:12 +0800 Subject: [PATCH 3/3] update the rcom to rocm --- source/module_basis/module_pw/CMakeLists.txt | 2 +- .../module_pw/module_fft/fft_bundle.cpp | 6 ++-- .../module_fft/{fft_rcom.cpp => fft_rocm.cpp} | 32 +++++++++++-------- .../module_fft/{fft_rcom.h => fft_rocm.h} | 10 ++---- 4 files changed, 25 insertions(+), 25 deletions(-) rename source/module_basis/module_pw/module_fft/{fft_rcom.cpp => fft_rocm.cpp} (76%) rename source/module_basis/module_pw/module_fft/{fft_rcom.h => fft_rocm.h} (87%) diff --git a/source/module_basis/module_pw/CMakeLists.txt b/source/module_basis/module_pw/CMakeLists.txt index 78f9824d8b..a95eca4917 100644 --- a/source/module_basis/module_pw/CMakeLists.txt +++ b/source/module_basis/module_pw/CMakeLists.txt @@ -10,7 +10,7 @@ if (USE_CUDA) endif() if (USE_ROCM) list (APPEND FFT_SRC - module_fft/fft_rcom.cpp + module_fft/fft_rocm.cpp ) endif() diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index b593f0999a..a7be7d988d 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -6,7 +6,7 @@ #include "fft_cuda.h" #endif #if defined(__ROCM) -#include "fft_rcom.h" +#include "fft_rocm.h" #endif template @@ -89,9 +89,9 @@ void FFT_Bundle::initfft(int nx_in, if (device=="gpu") { #if defined(__ROCM) - fft_float = make_unique>(); + fft_float = make_unique>(); fft_float->initfft(nx_in,ny_in,nz_in); - fft_double = make_unique>(); + fft_double = make_unique>(); fft_double->initfft(nx_in,ny_in,nz_in); #elif defined(__CUDA) fft_float = make_unique>(); diff --git a/source/module_basis/module_pw/module_fft/fft_rcom.cpp b/source/module_basis/module_pw/module_fft/fft_rocm.cpp similarity index 76% rename from source/module_basis/module_pw/module_fft/fft_rcom.cpp rename to source/module_basis/module_pw/module_fft/fft_rocm.cpp index 225a6aae7a..9973c72901 100644 --- a/source/module_basis/module_pw/module_fft/fft_rcom.cpp +++ b/source/module_basis/module_pw/module_fft/fft_rocm.cpp @@ -1,10 +1,10 @@ -#include "fft_rcom.h" +#include "fft_rocm.h" #include "module_base/module_device/memory_op.h" #include "module_hamilt_pw/hamilt_pwdft/global.h" namespace ModulePW { template -void FFT_RCOM::initfft(int nx_in, +void FFT_ROCM::initfft(int nx_in, int ny_in, int nz_in) { @@ -13,20 +13,20 @@ void FFT_RCOM::initfft(int nx_in, this->nz = nz_in; } template <> -void FFT_RCOM::setupFFT() +void FFT_ROCM::setupFFT() { hipfftPlan3d(&c_handle, this->nx, this->ny, this->nz, HIPFFT_C2C); resmem_cd_op()(gpu_ctx, this->c_auxr_3d, this->nx * this->ny * this->nz); } template <> -void FFT_RCOM::setupFFT() +void FFT_ROCM::setupFFT() { hipfftPlan3d(&z_handle, this->nx, this->ny, this->nz, HIPFFT_Z2Z); resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz); } template <> -void FFT_RCOM::cleanFFT() +void FFT_ROCM::cleanFFT() { if (c_handle) { @@ -35,7 +35,7 @@ void FFT_RCOM::cleanFFT() } } template <> -void FFT_RCOM::cleanFFT() +void FFT_ROCM::cleanFFT() { if (z_handle) { @@ -44,7 +44,7 @@ void FFT_RCOM::cleanFFT() } } template <> -void FFT_RCOM::clear() +void FFT_ROCM::clear() { this->cleanFFT(); if (c_auxr_3d != nullptr) @@ -54,7 +54,7 @@ void FFT_RCOM::clear() } } template <> -void FFT_RCOM::clear() +void FFT_ROCM::clear() { this->cleanFFT(); if (z_auxr_3d != nullptr) @@ -64,7 +64,7 @@ void FFT_RCOM::clear() } } template <> -void FFT_RCOM::fft3D_forward(std::complex* in, +void FFT_ROCM::fft3D_forward(std::complex* in, std::complex* out) const { CHECK_CUFFT(hipfftExecC2C(this->c_handle, @@ -73,7 +73,7 @@ void FFT_RCOM::fft3D_forward(std::complex* in, HIPFFT_FORWARD)); } template <> -void FFT_RCOM::fft3D_forward(std::complex* in, +void FFT_ROCM::fft3D_forward(std::complex* in, std::complex* out) const { CHECK_CUFFT(hipfftExecZ2Z(this->z_handle, @@ -82,7 +82,7 @@ void FFT_RCOM::fft3D_forward(std::complex* in, HIPFFT_FORWARD)); } template <> -void FFT_RCOM::fft3D_backward(std::complex* in, +void FFT_ROCM::fft3D_backward(std::complex* in, std::complex* out) const { CHECK_CUFFT(hipfftExecC2C(this->c_handle, @@ -91,7 +91,7 @@ void FFT_RCOM::fft3D_backward(std::complex* in, HIPFFT_BACKWARD)); } template <> -void FFT_RCOM::fft3D_backward(std::complex* in, +void FFT_ROCM::fft3D_backward(std::complex* in, std::complex* out) const { CHECK_CUFFT(hipfftExecZ2Z(this->z_handle, @@ -100,7 +100,11 @@ void FFT_RCOM::fft3D_backward(std::complex* in, HIPFFT_BACKWARD)); } template <> std::complex* -FFT_RCOM::get_auxr_3d_data() const {return this->c_auxr_3d;} +FFT_ROCM::get_auxr_3d_data() const {return this->c_auxr_3d;} template <> std::complex* -FFT_RCOM::get_auxr_3d_data() const {return this->z_auxr_3d;} +FFT_ROCM::get_auxr_3d_data() const {return this->z_auxr_3d;} +template FFT_ROCM::FFT_ROCM(); +template FFT_ROCM::~FFT_ROCM(); +template FFT_ROCM::FFT_ROCM(); +template FFT_ROCM::~FFT_ROCM(); }// namespace ModulePW \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_rcom.h b/source/module_basis/module_pw/module_fft/fft_rocm.h similarity index 87% rename from source/module_basis/module_pw/module_fft/fft_rcom.h rename to source/module_basis/module_pw/module_fft/fft_rocm.h index 09035b50ba..2d2cbd0c21 100644 --- a/source/module_basis/module_pw/module_fft/fft_rcom.h +++ b/source/module_basis/module_pw/module_fft/fft_rocm.h @@ -6,11 +6,11 @@ namespace ModulePW { template -class FFT_RCOM : public FFT_BASE +class FFT_ROCM : public FFT_BASE { public: - FFT_RCOM(){}; - ~FFT_RCOM(){}; + FFT_ROCM(){}; + ~FFT_ROCM(){}; void setupFFT() override; @@ -57,9 +57,5 @@ class FFT_RCOM : public FFT_BASE mutable std::complex* z_auxr_3d = nullptr; // fft space }; -template FFT_RCOM::FFT_RCOM(); -template FFT_RCOM::~FFT_RCOM(); -template FFT_RCOM::FFT_RCOM(); -template FFT_RCOM::~FFT_RCOM(); }// namespace ModulePW #endif