Skip to content

Commit 593c30d

Browse files
authored
Refactor:Add cuda support for fft_bundle (#5508)
* add cuda support for fft * update the FFT * update the clear func
1 parent 4116e55 commit 593c30d

File tree

22 files changed

+424
-75
lines changed

22 files changed

+424
-75
lines changed

source/module_basis/module_pw/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,17 @@ if (ENABLE_FLOAT_FFTW)
33
module_fft/fft_cpu_float.cpp
44
)
55
endif()
6+
if (USE_CUDA)
7+
list (APPEND FFT_SRC
8+
module_fft/fft_cuda.cpp
9+
)
10+
endif()
11+
if (USE_ROCM)
12+
list (APPEND FFT_SRC
13+
module_fft/fft_rcom.cpp
14+
)
15+
endif()
16+
617
list(APPEND objects
718
fft.cpp
819
pw_basis.cpp

source/module_basis/module_pw/module_fft/fft_base.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ class FFT_BASE
3030
bool gamma_only_in,
3131
bool xprime_in = true);
3232

33+
virtual __attribute__((weak))
34+
void initfft(int nx_in,
35+
int ny_in,
36+
int nz_in);
37+
3338
/**
3439
* @brief Setup the fft Plan and data As pure virtual function.
3540
*

source/module_basis/module_pw/module_fft/fft_bundle.cpp

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
#include "fft_bundle.h"
33
#include "fft_cpu.h"
44
#include "module_base/module_device/device.h"
5-
// #if defined(__CUDA)
6-
// #include "fft_cuda.h"
7-
// #endif
8-
// #if defined(__ROCM)
9-
// #include "fft_rcom.h"
10-
// #endif
5+
#if defined(__CUDA)
6+
#include "fft_cuda.h"
7+
#endif
8+
#if defined(__ROCM)
9+
#include "fft_rcom.h"
10+
#endif
1111

1212
template<typename FFT_BASE, typename... Args>
1313
std::unique_ptr<FFT_BASE> make_unique(Args &&... args)
@@ -16,6 +16,11 @@ std::unique_ptr<FFT_BASE> make_unique(Args &&... args)
1616
}
1717
namespace ModulePW
1818
{
19+
FFT_Bundle::~FFT_Bundle()
20+
{
21+
this->clear();
22+
}
23+
1924
void FFT_Bundle::setfft(std::string device_in,std::string precision_in)
2025
{
2126
this->device = device_in;
@@ -83,13 +88,17 @@ void FFT_Bundle::initfft(int nx_in,
8388
}
8489
if (device=="gpu")
8590
{
86-
// #if defined(__ROCM)
87-
// fft_float = new FFT_RCOM<float>();
88-
// fft_double = new FFT_RCOM<double>();
89-
// #elif defined(__CUDA)
90-
// fft_float = make_unique<FFT_CUDA<float>>();
91-
// fft_double = make_unique<FFT_CUDA<double>>();
92-
// #endif
91+
#if defined(__ROCM)
92+
fft_float = new FFT_RCOM<float>();
93+
fft_float->initfft(nx_in,ny_in,nz_in);
94+
fft_double = new FFT_RCOM<double>();
95+
fft_double->initfft(nx_in,ny_in,nz_in);
96+
#elif defined(__CUDA)
97+
fft_float = make_unique<FFT_CUDA<float>>();
98+
fft_float->initfft(nx_in,ny_in,nz_in);
99+
fft_double = make_unique<FFT_CUDA<double>>();
100+
fft_double->initfft(nx_in,ny_in,nz_in);
101+
#endif
93102
}
94103

95104
}

source/module_basis/module_pw/module_fft/fft_bundle.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class FFT_Bundle
99
{
1010
public:
1111
FFT_Bundle(){};
12-
~FFT_Bundle(){};
12+
~FFT_Bundle();
1313
/**
1414
* @brief Constructor with device and precision.
1515
* @param device_in device type, cpu or gpu.

source/module_basis/module_pw/module_fft/fft_cpu_float.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ void FFT_CPU<float>::clear()
303303
fftw_free(c_auxg);
304304
c_auxg = nullptr;
305305
}
306-
if (z_auxr != nullptr)
306+
if (c_auxr != nullptr)
307307
{
308308
fftw_free(c_auxr);
309309
c_auxr = nullptr;
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
#include "fft_cuda.h"
2+
#include "module_base/module_device/memory_op.h"
3+
#include "module_hamilt_pw/hamilt_pwdft/global.h"
4+
namespace ModulePW
5+
{
6+
template <typename FPTYPE>
7+
void FFT_CUDA<FPTYPE>::initfft(int nx_in,
8+
int ny_in,
9+
int nz_in)
10+
{
11+
this->nx = nx_in;
12+
this->ny = ny_in;
13+
this->nz = nz_in;
14+
}
15+
template <>
16+
void FFT_CUDA<float>::setupFFT()
17+
{
18+
cufftPlan3d(&c_handle, this->nx, this->ny, this->nz, CUFFT_C2C);
19+
resmem_cd_op()(gpu_ctx, this->c_auxr_3d, this->nx * this->ny * this->nz);
20+
21+
}
22+
template <>
23+
void FFT_CUDA<double>::setupFFT()
24+
{
25+
cufftPlan3d(&z_handle, this->nx, this->ny, this->nz, CUFFT_Z2Z);
26+
resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz);
27+
}
28+
template <>
29+
void FFT_CUDA<float>::cleanFFT()
30+
{
31+
if (c_handle)
32+
{
33+
cufftDestroy(c_handle);
34+
c_handle = {};
35+
}
36+
}
37+
template <>
38+
void FFT_CUDA<double>::cleanFFT()
39+
{
40+
if (z_handle)
41+
{
42+
cufftDestroy(z_handle);
43+
z_handle = {};
44+
}
45+
}
46+
template <>
47+
void FFT_CUDA<float>::clear()
48+
{
49+
this->cleanFFT();
50+
if (c_auxr_3d != nullptr)
51+
{
52+
delmem_cd_op()(gpu_ctx, c_auxr_3d);
53+
c_auxr_3d = nullptr;
54+
}
55+
}
56+
template <>
57+
void FFT_CUDA<double>::clear()
58+
{
59+
this->cleanFFT();
60+
if (z_auxr_3d != nullptr)
61+
{
62+
delmem_zd_op()(gpu_ctx, z_auxr_3d);
63+
z_auxr_3d = nullptr;
64+
}
65+
}
66+
67+
template <>
68+
void FFT_CUDA<float>::fft3D_forward(std::complex<float>* in,
69+
std::complex<float>* out) const
70+
{
71+
CHECK_CUFFT(cufftExecC2C(this->c_handle,
72+
reinterpret_cast<cufftComplex*>(in),
73+
reinterpret_cast<cufftComplex*>(out),
74+
CUFFT_FORWARD));
75+
}
76+
template <>
77+
void FFT_CUDA<double>::fft3D_forward(std::complex<double>* in,
78+
std::complex<double>* out) const
79+
{
80+
CHECK_CUFFT(cufftExecZ2Z(this->z_handle,
81+
reinterpret_cast<cufftDoubleComplex*>(in),
82+
reinterpret_cast<cufftDoubleComplex*>(out),
83+
CUFFT_FORWARD));
84+
}
85+
template <>
86+
void FFT_CUDA<float>::fft3D_backward(std::complex<float>* in,
87+
std::complex<float>* out) const
88+
{
89+
CHECK_CUFFT(cufftExecC2C(this->c_handle,
90+
reinterpret_cast<cufftComplex*>(in),
91+
reinterpret_cast<cufftComplex*>(out),
92+
CUFFT_INVERSE));
93+
}
94+
95+
template <>
96+
void FFT_CUDA<double>::fft3D_backward(std::complex<double>* in,
97+
std::complex<double>* out) const
98+
{
99+
CHECK_CUFFT(cufftExecZ2Z(this->z_handle,
100+
reinterpret_cast<cufftDoubleComplex*>(in),
101+
reinterpret_cast<cufftDoubleComplex*>(out),
102+
CUFFT_INVERSE));
103+
}
104+
template <> std::complex<float>*
105+
FFT_CUDA<float>::get_auxr_3d_data() const {return this->c_auxr_3d;}
106+
template <> std::complex<double>*
107+
FFT_CUDA<double>::get_auxr_3d_data() const {return this->z_auxr_3d;}
108+
}// namespace ModulePW
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#include "fft_base.h"
2+
#include "cufft.h"
3+
#include "cuda_runtime.h"
4+
5+
#ifndef FFT_CUDA_H
6+
#define FFT_CUDA_H
7+
namespace ModulePW
8+
{
9+
template <typename FPTYPE>
10+
class FFT_CUDA : public FFT_BASE<FPTYPE>
11+
{
12+
public:
13+
FFT_CUDA(){};
14+
~FFT_CUDA(){};
15+
16+
void setupFFT() override;
17+
18+
void clear() override;
19+
20+
void cleanFFT() override;
21+
22+
/**
23+
* @brief Initialize the fft parameters
24+
* @param nx_in number of grid points in x direction
25+
* @param ny_in number of grid points in y direction
26+
* @param nz_in number of grid points in z direction
27+
*
28+
*/
29+
void initfft(int nx_in,
30+
int ny_in,
31+
int nz_in) override;
32+
33+
/**
34+
* @brief Get the real space data
35+
* @return real space data
36+
*/
37+
std::complex<FPTYPE>* get_auxr_3d_data() const override;
38+
39+
/**
40+
* @brief Forward FFT in 3D
41+
* @param in input data, complex FPTYPE
42+
* @param out output data, complex FPTYPE
43+
*
44+
* This function performs the forward FFT in 3D.
45+
*/
46+
void fft3D_forward(std::complex<FPTYPE>* in,
47+
std::complex<FPTYPE>* out) const override;
48+
/**
49+
* @brief Backward FFT in 3D
50+
* @param in input data, complex FPTYPE
51+
* @param out output data, complex FPTYPE
52+
*
53+
* This function performs the backward FFT in 3D.
54+
*/
55+
void fft3D_backward(std::complex<FPTYPE>* in,
56+
std::complex<FPTYPE>* out) const override;
57+
private:
58+
cufftHandle c_handle = {};
59+
cufftHandle z_handle = {};
60+
61+
std::complex<float>* c_auxr_3d = nullptr; // fft space
62+
std::complex<double>* z_auxr_3d = nullptr; // fft space
63+
64+
};
65+
template FFT_CUDA<float>::FFT_CUDA();
66+
template FFT_CUDA<float>::~FFT_CUDA();
67+
template FFT_CUDA<double>::FFT_CUDA();
68+
template FFT_CUDA<double>::~FFT_CUDA();
69+
} // namespace ModulePW
70+
#endif
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
#include "fft_rcom.h"
2+
#include "module_base/module_device/memory_op.h"
3+
#include "module_hamilt_pw/hamilt_pwdft/global.h"
4+
namespace ModulePW
5+
{
6+
template <typename FPTYPE>
7+
void FFT_RCOM<FPTYPE>::initfft(int nx_in,
8+
int ny_in,
9+
int nz_in)
10+
{
11+
this->nx = nx_in;
12+
this->ny = ny_in;
13+
this->nz = nz_in;
14+
}
15+
template <>
16+
void FFT_RCOM<float>::setupFFT()
17+
{
18+
hipfftPlan3d(&c_handle, this->nx, this->ny, this->nz, HIPFFT_C2C);
19+
resmem_cd_op()(gpu_ctx, this->c_auxr_3d, this->nx * this->ny * this->nz);
20+
21+
}
22+
template <>
23+
void FFT_RCOM<double>::setupFFT()
24+
{
25+
hipfftPlan3d(&z_handle, this->nx, this->ny, this->nz, HIPFFT_Z2Z);
26+
resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz);
27+
}
28+
template <>
29+
void FFT_RCOM<float>::cleanFFT()
30+
{
31+
if (c_handle)
32+
{
33+
hipfftDestroy(c_handle);
34+
c_handle = {};
35+
}
36+
}
37+
template <>
38+
void FFT_RCOM<double>::cleanFFT()
39+
{
40+
if (z_handle)
41+
{
42+
hipfftDestroy(z_handle);
43+
z_handle = {};
44+
}
45+
}
46+
template <>
47+
void FFT_RCOM<float>::clear()
48+
{
49+
this->cleanFFT();
50+
if (c_auxr_3d != nullptr)
51+
{
52+
delmem_cd_op()(gpu_ctx, c_auxr_3d);
53+
c_auxr_3d = nullptr;
54+
}
55+
}
56+
template <>
57+
void FFT_RCOM<double>::clear()
58+
{
59+
this->cleanFFT();
60+
if (z_auxr_3d != nullptr)
61+
{
62+
delmem_zd_op()(gpu_ctx, z_auxr_3d);
63+
z_auxr_3d = nullptr;
64+
}
65+
}
66+
template <>
67+
void FFT_RCOM<float>::fft3D_forward(std::complex<float>* in,
68+
std::complex<float>* out) const
69+
{
70+
CHECK_CUFFT(hipfftExecC2C(this->c_handle,
71+
reinterpret_cast<hipfftComplex*>(in),
72+
reinterpret_cast<hipfftComplex*>(out),
73+
HIPFFT_FORWARD));
74+
}
75+
template <>
76+
void FFT_RCOM<double>::fft3D_forward(std::complex<double>* in,
77+
std::complex<double>* out) const
78+
{
79+
CHECK_CUFFT(hipfftExecZ2Z(this->z_handle,
80+
reinterpret_cast<hipfftDoubleComplex*>(in),
81+
reinterpret_cast<hipfftDoubleComplex*>(out),
82+
HIPFFT_FORWARD));
83+
}
84+
template <>
85+
void FFT_RCOM<float>::fft3D_backward(std::complex<float>* in,
86+
std::complex<float>* out) const
87+
{
88+
CHECK_CUFFT(hipfftExecC2C(this->c_handle,
89+
reinterpret_cast<hipfftComplex*>(in),
90+
reinterpret_cast<hipfftComplex*>(out),
91+
HIPFFT_BACKWARD));
92+
}
93+
template <>
94+
void FFT_RCOM<double>::fft3D_backward(std::complex<double>* in,
95+
std::complex<double>* out) const
96+
{
97+
CHECK_CUFFT(hipfftExecZ2Z(this->z_handle,
98+
reinterpret_cast<hipfftDoubleComplex*>(in),
99+
reinterpret_cast<hipfftDoubleComplex*>(out),
100+
HIPFFT_BACKWARD));
101+
}
102+
template <> std::complex<float>*
103+
FFT_RCOM<float>::get_auxr_3d_data() const {return this->c_auxr_3d;}
104+
template <> std::complex<double>*
105+
FFT_RCOM<double>::get_auxr_3d_data() const {return this->z_auxr_3d;}
106+
}// namespace ModulePW

0 commit comments

Comments
 (0)