1+ #include " fft_rcom.h"
2+ #include " module_base/module_device/memory_op.h"
3+ #include " module_hamilt_pw/hamilt_pwdft/global.h"
4+ namespace ModulePW
5+ {
6+ template <typename FPTYPE>
7+ void FFT_RCOM<FPTYPE>::initfft(int nx_in,
8+ int ny_in,
9+ int nz_in)
10+ {
11+ this ->nx = nx_in;
12+ this ->ny = ny_in;
13+ this ->nz = nz_in;
14+ }
15+ template <>
16+ void FFT_RCOM<float >::setupFFT()
17+ {
18+ hipfftPlan3d (&c_handle, this ->nx , this ->ny , this ->nz , HIPFFT_C2C);
19+ resmem_cd_op ()(gpu_ctx, this ->c_auxr_3d , this ->nx * this ->ny * this ->nz );
20+
21+ }
22+ template <>
23+ void FFT_RCOM<double >::setupFFT()
24+ {
25+ hipfftPlan3d (&z_handle, this ->nx , this ->ny , this ->nz , HIPFFT_Z2Z);
26+ resmem_zd_op ()(gpu_ctx, this ->z_auxr_3d , this ->nx * this ->ny * this ->nz );
27+ }
28+ template <>
29+ void FFT_RCOM<float >::cleanFFT()
30+ {
31+ if (c_handle)
32+ {
33+ hipfftDestroy (c_handle);
34+ c_handle = {};
35+ }
36+ }
37+ template <>
38+ void FFT_RCOM<double >::cleanFFT()
39+ {
40+ if (z_handle)
41+ {
42+ hipfftDestroy (z_handle);
43+ z_handle = {};
44+ }
45+ }
46+ template <>
47+ void FFT_RCOM<float >::clear()
48+ {
49+ this ->cleanFFT ();
50+ if (c_auxr_3d != nullptr )
51+ {
52+ delmem_cd_op ()(gpu_ctx, c_auxr_3d);
53+ c_auxr_3d = nullptr ;
54+ }
55+ }
56+ template <>
57+ void FFT_RCOM<double >::clear()
58+ {
59+ this ->cleanFFT ();
60+ if (z_auxr_3d != nullptr )
61+ {
62+ delmem_zd_op ()(gpu_ctx, z_auxr_3d);
63+ z_auxr_3d = nullptr ;
64+ }
65+ }
66+ template <>
67+ void FFT_RCOM<float >::fft3D_forward(std::complex <float >* in,
68+ std::complex <float >* out) const
69+ {
70+ CHECK_CUFFT (hipfftExecC2C (this ->c_handle ,
71+ reinterpret_cast <hipfftComplex*>(in),
72+ reinterpret_cast <hipfftComplex*>(out),
73+ HIPFFT_FORWARD));
74+ }
75+ template <>
76+ void FFT_RCOM<double >::fft3D_forward(std::complex <double >* in,
77+ std::complex <double >* out) const
78+ {
79+ CHECK_CUFFT (hipfftExecZ2Z (this ->z_handle ,
80+ reinterpret_cast <hipfftDoubleComplex*>(in),
81+ reinterpret_cast <hipfftDoubleComplex*>(out),
82+ HIPFFT_FORWARD));
83+ }
84+ template <>
85+ void FFT_RCOM<float >::fft3D_backward(std::complex <float >* in,
86+ std::complex <float >* out) const
87+ {
88+ CHECK_CUFFT (hipfftExecC2C (this ->c_handle ,
89+ reinterpret_cast <hipfftComplex*>(in),
90+ reinterpret_cast <hipfftComplex*>(out),
91+ HIPFFT_BACKWARD));
92+ }
93+ template <>
94+ void FFT_RCOM<double >::fft3D_backward(std::complex <double >* in,
95+ std::complex <double >* out) const
96+ {
97+ CHECK_CUFFT (hipfftExecZ2Z (this ->z_handle ,
98+ reinterpret_cast <hipfftDoubleComplex*>(in),
99+ reinterpret_cast <hipfftDoubleComplex*>(out),
100+ HIPFFT_BACKWARD));
101+ }
102+ template <> std::complex <float >*
103+ FFT_RCOM<float >::get_auxr_3d_data() const {return this ->c_auxr_3d ;}
104+ template <> std::complex <double >*
105+ FFT_RCOM<double >::get_auxr_3d_data() const {return this ->z_auxr_3d ;}
106+ }// namespace ModulePW
0 commit comments