Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions source/module_base/module_device/rocm/memory_op.hip.cu
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ template struct cast_memory_op<std::complex<double>,
std::complex<float>,
base_device::DEVICE_GPU,
base_device::DEVICE_GPU>;
template struct cast_memory_op<std::complex<float>, float, base_device::DEVICE_GPU, base_device::DEVICE_GPU>;
template struct cast_memory_op<std::complex<double>, double, base_device::DEVICE_GPU, base_device::DEVICE_GPU>;
template struct cast_memory_op<float, float, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
template struct cast_memory_op<double, double, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
template struct cast_memory_op<float, double, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
Expand Down
2 changes: 1 addition & 1 deletion source/module_basis/module_pw/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ if (USE_CUDA)
endif()
if (USE_ROCM)
list (APPEND FFT_SRC
module_fft/fft_rcom.cpp
module_fft/fft_rocm.cpp
)
endif()

Expand Down
6 changes: 3 additions & 3 deletions source/module_basis/module_pw/module_fft/fft_bundle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "fft_cuda.h"
#endif
#if defined(__ROCM)
#include "fft_rcom.h"
#include "fft_rocm.h"
#endif

template<typename FFT_BASE, typename... Args>
Expand Down Expand Up @@ -89,9 +89,9 @@ void FFT_Bundle::initfft(int nx_in,
if (device=="gpu")
{
#if defined(__ROCM)
fft_float = new FFT_RCOM<float>();
fft_float = make_unique<FFT_ROCM<float>>();
fft_float->initfft(nx_in,ny_in,nz_in);
fft_double = new FFT_RCOM<double>();
fft_double = make_unique<FFT_ROCM<double>>();
fft_double->initfft(nx_in,ny_in,nz_in);
#elif defined(__CUDA)
fft_float = make_unique<FFT_CUDA<float>>();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#include "fft_rcom.h"
#include "fft_rocm.h"
#include "module_base/module_device/memory_op.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
namespace ModulePW
{
template <typename FPTYPE>
void FFT_RCOM<FPTYPE>::initfft(int nx_in,
void FFT_ROCM<FPTYPE>::initfft(int nx_in,
int ny_in,
int nz_in)
{
Expand All @@ -13,20 +13,20 @@ void FFT_RCOM<FPTYPE>::initfft(int nx_in,
this->nz = nz_in;
}
template <>
void FFT_RCOM<float>::setupFFT()
void FFT_ROCM<float>::setupFFT()
{
hipfftPlan3d(&c_handle, this->nx, this->ny, this->nz, HIPFFT_C2C);
resmem_cd_op()(gpu_ctx, this->c_auxr_3d, this->nx * this->ny * this->nz);

}
template <>
void FFT_RCOM<double>::setupFFT()
void FFT_ROCM<double>::setupFFT()
{
hipfftPlan3d(&z_handle, this->nx, this->ny, this->nz, HIPFFT_Z2Z);
resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz);
}
template <>
void FFT_RCOM<float>::cleanFFT()
void FFT_ROCM<float>::cleanFFT()
{
if (c_handle)
{
Expand All @@ -35,7 +35,7 @@ void FFT_RCOM<float>::cleanFFT()
}
}
template <>
void FFT_RCOM<double>::cleanFFT()
void FFT_ROCM<double>::cleanFFT()
{
if (z_handle)
{
Expand All @@ -44,7 +44,7 @@ void FFT_RCOM<double>::cleanFFT()
}
}
template <>
void FFT_RCOM<float>::clear()
void FFT_ROCM<float>::clear()
{
this->cleanFFT();
if (c_auxr_3d != nullptr)
Expand All @@ -54,7 +54,7 @@ void FFT_RCOM<float>::clear()
}
}
template <>
void FFT_RCOM<double>::clear()
void FFT_ROCM<double>::clear()
{
this->cleanFFT();
if (z_auxr_3d != nullptr)
Expand All @@ -64,7 +64,7 @@ void FFT_RCOM<double>::clear()
}
}
template <>
void FFT_RCOM<float>::fft3D_forward(std::complex<float>* in,
void FFT_ROCM<float>::fft3D_forward(std::complex<float>* in,
std::complex<float>* out) const
{
CHECK_CUFFT(hipfftExecC2C(this->c_handle,
Expand All @@ -73,7 +73,7 @@ void FFT_RCOM<float>::fft3D_forward(std::complex<float>* in,
HIPFFT_FORWARD));
}
template <>
void FFT_RCOM<double>::fft3D_forward(std::complex<double>* in,
void FFT_ROCM<double>::fft3D_forward(std::complex<double>* in,
std::complex<double>* out) const
{
CHECK_CUFFT(hipfftExecZ2Z(this->z_handle,
Expand All @@ -82,7 +82,7 @@ void FFT_RCOM<double>::fft3D_forward(std::complex<double>* in,
HIPFFT_FORWARD));
}
template <>
void FFT_RCOM<float>::fft3D_backward(std::complex<float>* in,
void FFT_ROCM<float>::fft3D_backward(std::complex<float>* in,
std::complex<float>* out) const
{
CHECK_CUFFT(hipfftExecC2C(this->c_handle,
Expand All @@ -91,7 +91,7 @@ void FFT_RCOM<float>::fft3D_backward(std::complex<float>* in,
HIPFFT_BACKWARD));
}
template <>
void FFT_RCOM<double>::fft3D_backward(std::complex<double>* in,
void FFT_ROCM<double>::fft3D_backward(std::complex<double>* in,
std::complex<double>* out) const
{
CHECK_CUFFT(hipfftExecZ2Z(this->z_handle,
Expand All @@ -100,7 +100,11 @@ void FFT_RCOM<double>::fft3D_backward(std::complex<double>* in,
HIPFFT_BACKWARD));
}
template <> std::complex<float>*
FFT_RCOM<float>::get_auxr_3d_data() const {return this->c_auxr_3d;}
FFT_ROCM<float>::get_auxr_3d_data() const {return this->c_auxr_3d;}
template <> std::complex<double>*
FFT_RCOM<double>::get_auxr_3d_data() const {return this->z_auxr_3d;}
FFT_ROCM<double>::get_auxr_3d_data() const {return this->z_auxr_3d;}
template FFT_ROCM<float>::FFT_ROCM();
template FFT_ROCM<float>::~FFT_ROCM();
template FFT_ROCM<double>::FFT_ROCM();
template FFT_ROCM<double>::~FFT_ROCM();
}// namespace ModulePW
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,5 @@ class FFT_ROCM : public FFT_BASE<FPTYPE>
mutable std::complex<double>* z_auxr_3d = nullptr; // fft space

};
template FFT_RCOM<float>::FFT_RCOM();
template FFT_ROCM<float>::~FFT_ROCM();
template FFT_RCOM<double>::FFT_RCOM();
template FFT_ROCM<double>::~FFT_ROCM();
}// namespace ModulePW
#endif
Loading