diff --git a/source/module_base/module_device/rocm/memory_op.hip.cu b/source/module_base/module_device/rocm/memory_op.hip.cu index 678acbc048..1909cfb771 100644 --- a/source/module_base/module_device/rocm/memory_op.hip.cu +++ b/source/module_base/module_device/rocm/memory_op.hip.cu @@ -219,6 +219,8 @@ template struct cast_memory_op, std::complex, base_device::DEVICE_GPU, base_device::DEVICE_GPU>; +template struct cast_memory_op, float, base_device::DEVICE_GPU, base_device::DEVICE_GPU>; +template struct cast_memory_op, double, base_device::DEVICE_GPU, base_device::DEVICE_GPU>; template struct cast_memory_op; template struct cast_memory_op; template struct cast_memory_op; diff --git a/source/module_basis/module_pw/CMakeLists.txt b/source/module_basis/module_pw/CMakeLists.txt index 78f9824d8b..a95eca4917 100644 --- a/source/module_basis/module_pw/CMakeLists.txt +++ b/source/module_basis/module_pw/CMakeLists.txt @@ -10,7 +10,7 @@ if (USE_CUDA) endif() if (USE_ROCM) list (APPEND FFT_SRC - module_fft/fft_rcom.cpp + module_fft/fft_rocm.cpp ) endif() diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index 31d29c32f9..a7be7d988d 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -6,7 +6,7 @@ #include "fft_cuda.h" #endif #if defined(__ROCM) -#include "fft_rcom.h" +#include "fft_rocm.h" #endif template @@ -89,9 +89,9 @@ void FFT_Bundle::initfft(int nx_in, if (device=="gpu") { #if defined(__ROCM) - fft_float = new FFT_RCOM(); + fft_float = make_unique>(); fft_float->initfft(nx_in,ny_in,nz_in); - fft_double = new FFT_RCOM(); + fft_double = make_unique>(); fft_double->initfft(nx_in,ny_in,nz_in); #elif defined(__CUDA) fft_float = make_unique>(); diff --git a/source/module_basis/module_pw/module_fft/fft_rcom.cpp b/source/module_basis/module_pw/module_fft/fft_rocm.cpp similarity index 76% rename from source/module_basis/module_pw/module_fft/fft_rcom.cpp rename to source/module_basis/module_pw/module_fft/fft_rocm.cpp index 225a6aae7a..9973c72901 100644 --- a/source/module_basis/module_pw/module_fft/fft_rcom.cpp +++ b/source/module_basis/module_pw/module_fft/fft_rocm.cpp @@ -1,10 +1,10 @@ -#include "fft_rcom.h" +#include "fft_rocm.h" #include "module_base/module_device/memory_op.h" #include "module_hamilt_pw/hamilt_pwdft/global.h" namespace ModulePW { template -void FFT_RCOM::initfft(int nx_in, +void FFT_ROCM::initfft(int nx_in, int ny_in, int nz_in) { @@ -13,20 +13,20 @@ void FFT_RCOM::initfft(int nx_in, this->nz = nz_in; } template <> -void FFT_RCOM::setupFFT() +void FFT_ROCM::setupFFT() { hipfftPlan3d(&c_handle, this->nx, this->ny, this->nz, HIPFFT_C2C); resmem_cd_op()(gpu_ctx, this->c_auxr_3d, this->nx * this->ny * this->nz); } template <> -void FFT_RCOM::setupFFT() +void FFT_ROCM::setupFFT() { hipfftPlan3d(&z_handle, this->nx, this->ny, this->nz, HIPFFT_Z2Z); resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz); } template <> -void FFT_RCOM::cleanFFT() +void FFT_ROCM::cleanFFT() { if (c_handle) { @@ -35,7 +35,7 @@ void FFT_RCOM::cleanFFT() } } template <> -void FFT_RCOM::cleanFFT() +void FFT_ROCM::cleanFFT() { if (z_handle) { @@ -44,7 +44,7 @@ void FFT_RCOM::cleanFFT() } } template <> -void FFT_RCOM::clear() +void FFT_ROCM::clear() { this->cleanFFT(); if (c_auxr_3d != nullptr) @@ -54,7 +54,7 @@ void FFT_RCOM::clear() } } template <> -void FFT_RCOM::clear() +void FFT_ROCM::clear() { this->cleanFFT(); if (z_auxr_3d != nullptr) @@ -64,7 +64,7 @@ void FFT_RCOM::clear() } } template <> -void FFT_RCOM::fft3D_forward(std::complex* in, +void FFT_ROCM::fft3D_forward(std::complex* in, std::complex* out) const { CHECK_CUFFT(hipfftExecC2C(this->c_handle, @@ -73,7 +73,7 @@ void FFT_RCOM::fft3D_forward(std::complex* in, HIPFFT_FORWARD)); } template <> -void FFT_RCOM::fft3D_forward(std::complex* in, +void FFT_ROCM::fft3D_forward(std::complex* in, std::complex* out) const { CHECK_CUFFT(hipfftExecZ2Z(this->z_handle, @@ -82,7 +82,7 @@ void FFT_RCOM::fft3D_forward(std::complex* in, HIPFFT_FORWARD)); } template <> -void FFT_RCOM::fft3D_backward(std::complex* in, +void FFT_ROCM::fft3D_backward(std::complex* in, std::complex* out) const { CHECK_CUFFT(hipfftExecC2C(this->c_handle, @@ -91,7 +91,7 @@ void FFT_RCOM::fft3D_backward(std::complex* in, HIPFFT_BACKWARD)); } template <> -void FFT_RCOM::fft3D_backward(std::complex* in, +void FFT_ROCM::fft3D_backward(std::complex* in, std::complex* out) const { CHECK_CUFFT(hipfftExecZ2Z(this->z_handle, @@ -100,7 +100,11 @@ void FFT_RCOM::fft3D_backward(std::complex* in, HIPFFT_BACKWARD)); } template <> std::complex* -FFT_RCOM::get_auxr_3d_data() const {return this->c_auxr_3d;} +FFT_ROCM::get_auxr_3d_data() const {return this->c_auxr_3d;} template <> std::complex* -FFT_RCOM::get_auxr_3d_data() const {return this->z_auxr_3d;} +FFT_ROCM::get_auxr_3d_data() const {return this->z_auxr_3d;} +template FFT_ROCM::FFT_ROCM(); +template FFT_ROCM::~FFT_ROCM(); +template FFT_ROCM::FFT_ROCM(); +template FFT_ROCM::~FFT_ROCM(); }// namespace ModulePW \ No newline at end of file diff --git a/source/module_basis/module_pw/module_fft/fft_rcom.h b/source/module_basis/module_pw/module_fft/fft_rocm.h similarity index 92% rename from source/module_basis/module_pw/module_fft/fft_rcom.h rename to source/module_basis/module_pw/module_fft/fft_rocm.h index 64ad13329e..2d2cbd0c21 100644 --- a/source/module_basis/module_pw/module_fft/fft_rcom.h +++ b/source/module_basis/module_pw/module_fft/fft_rocm.h @@ -57,9 +57,5 @@ class FFT_ROCM : public FFT_BASE mutable std::complex* z_auxr_3d = nullptr; // fft space }; -template FFT_RCOM::FFT_RCOM(); -template FFT_ROCM::~FFT_ROCM(); -template FFT_RCOM::FFT_RCOM(); -template FFT_ROCM::~FFT_ROCM(); }// namespace ModulePW #endif