Skip to content

Commit 0fd832e

Browse files
A-006Qianruipku
andauthored
update the DCU compile (#5563)
* update the duc compile * update the FFT_CUDA in the fft_bundle.cpp * update the rcom to rocm --------- Co-authored-by: Qianrui Liu <[email protected]>
1 parent 5e94cb2 commit 0fd832e

File tree

5 files changed

+24
-22
lines changed

5 files changed

+24
-22
lines changed

source/module_base/module_device/rocm/memory_op.hip.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,8 @@ template struct cast_memory_op<std::complex<double>,
219219
std::complex<float>,
220220
base_device::DEVICE_GPU,
221221
base_device::DEVICE_GPU>;
222+
template struct cast_memory_op<std::complex<float>, float, base_device::DEVICE_GPU, base_device::DEVICE_GPU>;
223+
template struct cast_memory_op<std::complex<double>, double, base_device::DEVICE_GPU, base_device::DEVICE_GPU>;
222224
template struct cast_memory_op<float, float, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
223225
template struct cast_memory_op<double, double, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
224226
template struct cast_memory_op<float, double, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;

source/module_basis/module_pw/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ if (USE_CUDA)
1010
endif()
1111
if (USE_ROCM)
1212
list (APPEND FFT_SRC
13-
module_fft/fft_rcom.cpp
13+
module_fft/fft_rocm.cpp
1414
)
1515
endif()
1616

source/module_basis/module_pw/module_fft/fft_bundle.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include "fft_cuda.h"
77
#endif
88
#if defined(__ROCM)
9-
#include "fft_rcom.h"
9+
#include "fft_rocm.h"
1010
#endif
1111

1212
template<typename FFT_BASE, typename... Args>
@@ -89,9 +89,9 @@ void FFT_Bundle::initfft(int nx_in,
8989
if (device=="gpu")
9090
{
9191
#if defined(__ROCM)
92-
fft_float = new FFT_RCOM<float>();
92+
fft_float = make_unique<FFT_ROCM<float>>();
9393
fft_float->initfft(nx_in,ny_in,nz_in);
94-
fft_double = new FFT_RCOM<double>();
94+
fft_double = make_unique<FFT_ROCM<double>>();
9595
fft_double->initfft(nx_in,ny_in,nz_in);
9696
#elif defined(__CUDA)
9797
fft_float = make_unique<FFT_CUDA<float>>();
Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
#include "fft_rcom.h"
1+
#include "fft_rocm.h"
22
#include "module_base/module_device/memory_op.h"
33
#include "module_hamilt_pw/hamilt_pwdft/global.h"
44
namespace ModulePW
55
{
66
template <typename FPTYPE>
7-
void FFT_RCOM<FPTYPE>::initfft(int nx_in,
7+
void FFT_ROCM<FPTYPE>::initfft(int nx_in,
88
int ny_in,
99
int nz_in)
1010
{
@@ -13,20 +13,20 @@ void FFT_RCOM<FPTYPE>::initfft(int nx_in,
1313
this->nz = nz_in;
1414
}
1515
template <>
16-
void FFT_RCOM<float>::setupFFT()
16+
void FFT_ROCM<float>::setupFFT()
1717
{
1818
hipfftPlan3d(&c_handle, this->nx, this->ny, this->nz, HIPFFT_C2C);
1919
resmem_cd_op()(gpu_ctx, this->c_auxr_3d, this->nx * this->ny * this->nz);
2020

2121
}
2222
template <>
23-
void FFT_RCOM<double>::setupFFT()
23+
void FFT_ROCM<double>::setupFFT()
2424
{
2525
hipfftPlan3d(&z_handle, this->nx, this->ny, this->nz, HIPFFT_Z2Z);
2626
resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz);
2727
}
2828
template <>
29-
void FFT_RCOM<float>::cleanFFT()
29+
void FFT_ROCM<float>::cleanFFT()
3030
{
3131
if (c_handle)
3232
{
@@ -35,7 +35,7 @@ void FFT_RCOM<float>::cleanFFT()
3535
}
3636
}
3737
template <>
38-
void FFT_RCOM<double>::cleanFFT()
38+
void FFT_ROCM<double>::cleanFFT()
3939
{
4040
if (z_handle)
4141
{
@@ -44,7 +44,7 @@ void FFT_RCOM<double>::cleanFFT()
4444
}
4545
}
4646
template <>
47-
void FFT_RCOM<float>::clear()
47+
void FFT_ROCM<float>::clear()
4848
{
4949
this->cleanFFT();
5050
if (c_auxr_3d != nullptr)
@@ -54,7 +54,7 @@ void FFT_RCOM<float>::clear()
5454
}
5555
}
5656
template <>
57-
void FFT_RCOM<double>::clear()
57+
void FFT_ROCM<double>::clear()
5858
{
5959
this->cleanFFT();
6060
if (z_auxr_3d != nullptr)
@@ -64,7 +64,7 @@ void FFT_RCOM<double>::clear()
6464
}
6565
}
6666
template <>
67-
void FFT_RCOM<float>::fft3D_forward(std::complex<float>* in,
67+
void FFT_ROCM<float>::fft3D_forward(std::complex<float>* in,
6868
std::complex<float>* out) const
6969
{
7070
CHECK_CUFFT(hipfftExecC2C(this->c_handle,
@@ -73,7 +73,7 @@ void FFT_RCOM<float>::fft3D_forward(std::complex<float>* in,
7373
HIPFFT_FORWARD));
7474
}
7575
template <>
76-
void FFT_RCOM<double>::fft3D_forward(std::complex<double>* in,
76+
void FFT_ROCM<double>::fft3D_forward(std::complex<double>* in,
7777
std::complex<double>* out) const
7878
{
7979
CHECK_CUFFT(hipfftExecZ2Z(this->z_handle,
@@ -82,7 +82,7 @@ void FFT_RCOM<double>::fft3D_forward(std::complex<double>* in,
8282
HIPFFT_FORWARD));
8383
}
8484
template <>
85-
void FFT_RCOM<float>::fft3D_backward(std::complex<float>* in,
85+
void FFT_ROCM<float>::fft3D_backward(std::complex<float>* in,
8686
std::complex<float>* out) const
8787
{
8888
CHECK_CUFFT(hipfftExecC2C(this->c_handle,
@@ -91,7 +91,7 @@ void FFT_RCOM<float>::fft3D_backward(std::complex<float>* in,
9191
HIPFFT_BACKWARD));
9292
}
9393
template <>
94-
void FFT_RCOM<double>::fft3D_backward(std::complex<double>* in,
94+
void FFT_ROCM<double>::fft3D_backward(std::complex<double>* in,
9595
std::complex<double>* out) const
9696
{
9797
CHECK_CUFFT(hipfftExecZ2Z(this->z_handle,
@@ -100,7 +100,11 @@ void FFT_RCOM<double>::fft3D_backward(std::complex<double>* in,
100100
HIPFFT_BACKWARD));
101101
}
102102
template <> std::complex<float>*
103-
FFT_RCOM<float>::get_auxr_3d_data() const {return this->c_auxr_3d;}
103+
FFT_ROCM<float>::get_auxr_3d_data() const {return this->c_auxr_3d;}
104104
template <> std::complex<double>*
105-
FFT_RCOM<double>::get_auxr_3d_data() const {return this->z_auxr_3d;}
105+
FFT_ROCM<double>::get_auxr_3d_data() const {return this->z_auxr_3d;}
106+
template FFT_ROCM<float>::FFT_ROCM();
107+
template FFT_ROCM<float>::~FFT_ROCM();
108+
template FFT_ROCM<double>::FFT_ROCM();
109+
template FFT_ROCM<double>::~FFT_ROCM();
106110
}// namespace ModulePW

source/module_basis/module_pw/module_fft/fft_rcom.h renamed to source/module_basis/module_pw/module_fft/fft_rocm.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,5 @@ class FFT_ROCM : public FFT_BASE<FPTYPE>
5757
mutable std::complex<double>* z_auxr_3d = nullptr; // fft space
5858

5959
};
60-
template FFT_RCOM<float>::FFT_RCOM();
61-
template FFT_ROCM<float>::~FFT_ROCM();
62-
template FFT_RCOM<double>::FFT_RCOM();
63-
template FFT_ROCM<double>::~FFT_ROCM();
6460
}// namespace ModulePW
6561
#endif

0 commit comments

Comments
 (0)