Skip to content

Commit 0b88bc8

Browse files
committed
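Fix the swapped forward/backward 3D FFT specializations for double precision in the CUDA (cuFFT) and RCOM (hipFFT) backends, so that fft3D_forward uses the FORWARD direction and fft3D_backward uses the INVERSE/BACKWARD direction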
1 parent d4492a5 commit 0b88bc8

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

source/module_base/module_fft/fft_cuda.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,22 +77,25 @@ void FFT_CUDA<float>::fft3D_forward(std::complex<float>* in, std::complex<float>
7777
CUFFT_FORWARD));
7878
}
7979
template <>
80-
void FFT_CUDA<double>::fft3D_backward(std::complex<double>* in, std::complex<double>* out) const
80+
void FFT_CUDA<double>::fft3D_forward(std::complex<double>* in, std::complex<double>* out) const
8181
{
8282
CHECK_CUFFT(cufftExecZ2Z(this->z_handle, reinterpret_cast<cufftDoubleComplex*>(in),
8383
reinterpret_cast<cufftDoubleComplex*>(out), CUFFT_FORWARD));
8484
}
8585
template <>
8686
void FFT_CUDA<float>::fft3D_backward(std::complex<float>* in, std::complex<float>* out) const
8787
{
88-
CHECK_CUFFT(cufftExecC2C(this->c_handle, reinterpret_cast<cufftComplex*>(in), reinterpret_cast<cufftComplex*>(out),
89-
CUFFT_INVERSE));
88+
CHECK_CUFFT(cufftExecC2C(this->c_handle, reinterpret_cast<cufftComplex*>(in),
89+
reinterpret_cast<cufftComplex*>(out),CUFFT_INVERSE));
9090
}
91+
9192
template <>
92-
void FFT_CUDA<double>::fft3D_forward(std::complex<double>* in, std::complex<double>* out) const
93+
void FFT_CUDA<double>::fft3D_backward(std::complex<double>* in, std::complex<double>* out) const
9394
{
9495
CHECK_CUFFT(cufftExecZ2Z(this->z_handle, reinterpret_cast<cufftDoubleComplex*>(in),
9596
reinterpret_cast<cufftDoubleComplex*>(out), CUFFT_INVERSE));
9697
}
98+
99+
97100
template FFT_CUDA<float>::FFT_CUDA();
98101
template FFT_CUDA<double>::FFT_CUDA();

source/module_base/module_fft/fft_rcom.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ void FFT_RCOM<float>::fft3D_forward(std::complex<float>* in, std::complex<float>
7777
reinterpret_cast<hipfftComplex*>(out), HIPFFT_FORWARD));
7878
}
7979
template <>
80-
void FFT_RCOM<double>::fft3D_backward(std::complex<double>* in, std::complex<double>* out) const
80+
void FFT_RCOM<double>::fft3D_forward(std::complex<double>* in, std::complex<double>* out) const
8181
{
8282
CHECK_CUFFT(hipfftExecZ2Z(this->z_handle, reinterpret_cast<hipfftDoubleComplex*>(in),
8383
reinterpret_cast<hipfftDoubleComplex*>(out), HIPFFT_FORWARD));
@@ -89,10 +89,12 @@ void FFT_RCOM<float>::fft3D_backward(std::complex<float>* in, std::complex<float
8989
reinterpret_cast<hipfftComplex*>(out), HIPFFT_BACKWARD));
9090
}
9191
template <>
92-
void FFT_RCOM<double>::fft3D_forward(std::complex<double>* in, std::complex<double>* out) const
92+
void FFT_RCOM<double>::fft3D_backward(std::complex<double>* in, std::complex<double>* out) const
9393
{
9494
CHECK_CUFFT(hipfftExecZ2Z(this->z_handle, reinterpret_cast<hipfftDoubleComplex*>(in),
9595
reinterpret_cast<hipfftDoubleComplex*>(out), HIPFFT_BACKWARD));
9696
}
97+
98+
9799
template FFT_RCOM<float>::FFT_RCOM();
98100
template FFT_RCOM<double>::FFT_RCOM();

0 commit comments

Comments
 (0)