Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions source/module_base/module_device/rocm/memory_op.hip.cu
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ template struct cast_memory_op<std::complex<double>,
std::complex<float>,
base_device::DEVICE_GPU,
base_device::DEVICE_GPU>;
template struct cast_memory_op<std::complex<float>, float, base_device::DEVICE_GPU, base_device::DEVICE_GPU>;
template struct cast_memory_op<std::complex<double>, double, base_device::DEVICE_GPU, base_device::DEVICE_GPU>;
template struct cast_memory_op<float, float, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
template struct cast_memory_op<double, double, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
template struct cast_memory_op<float, double, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
Expand Down
4 changes: 2 additions & 2 deletions source/module_basis/module_pw/module_fft/fft_bundle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ void FFT_Bundle::initfft(int nx_in,
if (device=="gpu")
{
#if defined(__ROCM)
fft_float = new FFT_RCOM<float>();
fft_float = make_unique<FFT_RCOM<float>>();
fft_float->initfft(nx_in,ny_in,nz_in);
fft_double = new FFT_RCOM<double>();
fft_double = make_unique<FFT_RCOM<double>>();
fft_double->initfft(nx_in,ny_in,nz_in);
#elif defined(__CUDA)
fft_float = make_unique<FFT_CUDA<float>>();
Expand Down
10 changes: 5 additions & 5 deletions source/module_basis/module_pw/module_fft/fft_rcom.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
namespace ModulePW
{
template <typename FPTYPE>
class FFT_ROCM : public FFT_BASE<FPTYPE>
class FFT_RCOM : public FFT_BASE<FPTYPE>
{
public:
FFT_ROCM(){};
~FFT_ROCM(){};
FFT_RCOM(){};
~FFT_RCOM(){};

void setupFFT() override;

Expand Down Expand Up @@ -58,8 +58,8 @@ class FFT_ROCM : public FFT_BASE<FPTYPE>

};
template FFT_RCOM<float>::FFT_RCOM();
template FFT_ROCM<float>::~FFT_ROCM();
template FFT_RCOM<float>::~FFT_RCOM();
template FFT_RCOM<double>::FFT_RCOM();
template FFT_ROCM<double>::~FFT_ROCM();
template FFT_RCOM<double>::~FFT_RCOM();
}// namespace ModulePW
#endif
Loading