1- #include " fft_rcom .h"
1+ #include " fft_rocm .h"
22#include " module_base/module_device/memory_op.h"
33#include " module_hamilt_pw/hamilt_pwdft/global.h"
44namespace ModulePW
55{
66template <typename FPTYPE>
7- void FFT_RCOM <FPTYPE>::initfft(int nx_in,
7+ void FFT_ROCM <FPTYPE>::initfft(int nx_in,
88 int ny_in,
99 int nz_in)
1010{
@@ -13,20 +13,20 @@ void FFT_RCOM<FPTYPE>::initfft(int nx_in,
1313 this ->nz = nz_in;
1414}
1515template <>
16- void FFT_RCOM <float >::setupFFT()
16+ void FFT_ROCM <float >::setupFFT()
1717{
1818 hipfftPlan3d (&c_handle, this ->nx , this ->ny , this ->nz , HIPFFT_C2C);
1919 resmem_cd_op ()(gpu_ctx, this ->c_auxr_3d , this ->nx * this ->ny * this ->nz );
2020
2121}
2222template <>
23- void FFT_RCOM <double >::setupFFT()
23+ void FFT_ROCM <double >::setupFFT()
2424{
2525 hipfftPlan3d (&z_handle, this ->nx , this ->ny , this ->nz , HIPFFT_Z2Z);
2626 resmem_zd_op ()(gpu_ctx, this ->z_auxr_3d , this ->nx * this ->ny * this ->nz );
2727}
2828template <>
29- void FFT_RCOM <float >::cleanFFT()
29+ void FFT_ROCM <float >::cleanFFT()
3030{
3131 if (c_handle)
3232 {
@@ -35,7 +35,7 @@ void FFT_RCOM<float>::cleanFFT()
3535 }
3636}
3737template <>
38- void FFT_RCOM <double >::cleanFFT()
38+ void FFT_ROCM <double >::cleanFFT()
3939{
4040 if (z_handle)
4141 {
@@ -44,7 +44,7 @@ void FFT_RCOM<double>::cleanFFT()
4444 }
4545}
4646template <>
47- void FFT_RCOM <float >::clear()
47+ void FFT_ROCM <float >::clear()
4848{
4949 this ->cleanFFT ();
5050 if (c_auxr_3d != nullptr )
@@ -54,7 +54,7 @@ void FFT_RCOM<float>::clear()
5454 }
5555}
5656template <>
57- void FFT_RCOM <double >::clear()
57+ void FFT_ROCM <double >::clear()
5858{
5959 this ->cleanFFT ();
6060 if (z_auxr_3d != nullptr )
@@ -64,7 +64,7 @@ void FFT_RCOM<double>::clear()
6464 }
6565}
6666template <>
67- void FFT_RCOM <float >::fft3D_forward(std::complex <float >* in,
67+ void FFT_ROCM <float >::fft3D_forward(std::complex <float >* in,
6868 std::complex <float >* out) const
6969{
7070 CHECK_CUFFT (hipfftExecC2C (this ->c_handle ,
@@ -73,7 +73,7 @@ void FFT_RCOM<float>::fft3D_forward(std::complex<float>* in,
7373 HIPFFT_FORWARD));
7474}
7575template <>
76- void FFT_RCOM <double >::fft3D_forward(std::complex <double >* in,
76+ void FFT_ROCM <double >::fft3D_forward(std::complex <double >* in,
7777 std::complex <double >* out) const
7878{
7979 CHECK_CUFFT (hipfftExecZ2Z (this ->z_handle ,
@@ -82,7 +82,7 @@ void FFT_RCOM<double>::fft3D_forward(std::complex<double>* in,
8282 HIPFFT_FORWARD));
8383}
8484template <>
85- void FFT_RCOM <float >::fft3D_backward(std::complex <float >* in,
85+ void FFT_ROCM <float >::fft3D_backward(std::complex <float >* in,
8686 std::complex <float >* out) const
8787{
8888 CHECK_CUFFT (hipfftExecC2C (this ->c_handle ,
@@ -91,7 +91,7 @@ void FFT_RCOM<float>::fft3D_backward(std::complex<float>* in,
9191 HIPFFT_BACKWARD));
9292}
9393template <>
94- void FFT_RCOM <double >::fft3D_backward(std::complex <double >* in,
94+ void FFT_ROCM <double >::fft3D_backward(std::complex <double >* in,
9595 std::complex <double >* out) const
9696{
9797 CHECK_CUFFT (hipfftExecZ2Z (this ->z_handle ,
@@ -100,7 +100,11 @@ void FFT_RCOM<double>::fft3D_backward(std::complex<double>* in,
100100 HIPFFT_BACKWARD));
101101}
102102template <> std::complex <float >*
103- FFT_RCOM <float >::get_auxr_3d_data() const {return this ->c_auxr_3d ;}
103+ FFT_ROCM <float >::get_auxr_3d_data() const {return this ->c_auxr_3d ;}
104104template <> std::complex <double >*
105- FFT_RCOM<double >::get_auxr_3d_data() const {return this ->z_auxr_3d ;}
105+ FFT_ROCM<double >::get_auxr_3d_data() const {return this ->z_auxr_3d ;}
106+ template FFT_ROCM<float >::FFT_ROCM();
107+ template FFT_ROCM<float >::~FFT_ROCM ();
108+ template FFT_ROCM<double >::FFT_ROCM();
109+ template FFT_ROCM<double >::~FFT_ROCM ();
106110}// namespace ModulePW
0 commit comments