Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions source/module_basis/module_pw/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@ if (ENABLE_FLOAT_FFTW)
module_fft/fft_cpu_float.cpp
)
endif()
# Add the cuFFT-based FFT implementation to FFT_SRC only when CUDA is enabled.
if (USE_CUDA)
list (APPEND FFT_SRC
module_fft/fft_cuda.cpp
)
endif()
# Add the hipFFT-based FFT implementation to FFT_SRC only when ROCm is enabled.
# NOTE(review): the file name is spelled "fft_rcom" (not "rocm"); it matches the
# "fft_rcom.h" include used by the sources, so keep both in sync if it is renamed.
if (USE_ROCM)
list (APPEND FFT_SRC
module_fft/fft_rcom.cpp
)
endif()

list(APPEND objects
fft.cpp
pw_basis.cpp
Expand Down
5 changes: 5 additions & 0 deletions source/module_basis/module_pw/module_fft/fft_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ class FFT_BASE
bool gamma_only_in,
bool xprime_in = true);

/**
 * @brief Three-argument initfft overload taking only the grid dimensions.
 * @param nx_in number of grid points in x direction
 * @param ny_in number of grid points in y direction
 * @param nz_in number of grid points in z direction
 *
 * FFT_CUDA and FFT_RCOM override this overload and store the three values
 * in nx, ny and nz.
 *
 * NOTE(review): the declaration carries the GCC/Clang-specific
 * __attribute__((weak)) and no definition is visible here; presumably the
 * attribute is what lets the program link without a base-class definition.
 * Confirm that this overload is never invoked on an object that does not
 * override it, and consider a portable empty default body instead.
 */
virtual __attribute__((weak))
void initfft(int nx_in,
int ny_in,
int nz_in);

/**
* @brief Setup the fft Plan and data As pure virtual function.
*
Expand Down
35 changes: 22 additions & 13 deletions source/module_basis/module_pw/module_fft/fft_bundle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
#include "fft_bundle.h"
#include "fft_cpu.h"
#include "module_base/module_device/device.h"
// #if defined(__CUDA)
// #include "fft_cuda.h"
// #endif
// #if defined(__ROCM)
// #include "fft_rcom.h"
// #endif
#if defined(__CUDA)
#include "fft_cuda.h"
#endif
#if defined(__ROCM)
#include "fft_rcom.h"
#endif

template<typename FFT_BASE, typename... Args>
std::unique_ptr<FFT_BASE> make_unique(Args &&... args)
Expand All @@ -16,6 +16,11 @@ std::unique_ptr<FFT_BASE> make_unique(Args &&... args)
}
namespace ModulePW
{
/**
 * @brief Destructor; all cleanup is delegated to clear().
 */
FFT_Bundle::~FFT_Bundle()
{
    clear();
}

void FFT_Bundle::setfft(std::string device_in,std::string precision_in)
{
this->device = device_in;
Expand Down Expand Up @@ -83,13 +88,17 @@ void FFT_Bundle::initfft(int nx_in,
}
// GPU path: the backend is chosen at compile time. Both precisions are created
// and handed the grid dimensions through the three-argument initfft overload.
if (device == "gpu")
{
#if defined(__ROCM)
    // Build the objects with make_unique exactly like the CUDA branch below:
    // fft_float / fft_double receive make_unique results there, so assigning a
    // raw `new` expression to them is not valid for the same member type.
    fft_float = make_unique<FFT_RCOM<float>>();
    fft_float->initfft(nx_in, ny_in, nz_in);
    fft_double = make_unique<FFT_RCOM<double>>();
    fft_double->initfft(nx_in, ny_in, nz_in);
#elif defined(__CUDA)
    fft_float = make_unique<FFT_CUDA<float>>();
    fft_float->initfft(nx_in, ny_in, nz_in);
    fft_double = make_unique<FFT_CUDA<double>>();
    fft_double->initfft(nx_in, ny_in, nz_in);
#endif
}

}
Expand Down
2 changes: 1 addition & 1 deletion source/module_basis/module_pw/module_fft/fft_bundle.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class FFT_Bundle
{
public:
FFT_Bundle(){};
~FFT_Bundle(){};
~FFT_Bundle();
/**
* @brief Constructor with device and precision.
* @param device_in device type, cpu or gpu.
Expand Down
2 changes: 1 addition & 1 deletion source/module_basis/module_pw/module_fft/fft_cpu_float.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ void FFT_CPU<float>::clear()
fftw_free(c_auxg);
c_auxg = nullptr;
}
if (z_auxr != nullptr)
if (c_auxr != nullptr)
{
fftw_free(c_auxr);
c_auxr = nullptr;
Expand Down
108 changes: 108 additions & 0 deletions source/module_basis/module_pw/module_fft/fft_cuda.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#include "fft_cuda.h"
#include "module_base/module_device/memory_op.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
namespace ModulePW
{
/**
 * @brief Record the 3D grid dimensions on this object.
 * @param nx_in number of grid points in x direction
 * @param ny_in number of grid points in y direction
 * @param nz_in number of grid points in z direction
 *
 * Only stores the values; no plan is created and no memory is allocated here.
 */
template <typename FPTYPE>
void FFT_CUDA<FPTYPE>::initfft(int nx_in, int ny_in, int nz_in)
{
    // `this->` is required: nx/ny/nz are not declared in FFT_CUDA itself.
    this->nx = nx_in;
    this->ny = ny_in;
    this->nz = nz_in;
}
/**
 * @brief Create the single-precision 3D complex-to-complex cuFFT plan and
 *        size the auxiliary buffer.
 *
 * The plan is built for an nx * ny * nz grid, and c_auxr_3d is resized
 * through resmem_cd_op to hold nx * ny * nz elements.
 */
template <>
void FFT_CUDA<float>::setupFFT()
{
    // Check the plan-creation status like every exec call in this file does,
    // so a failed plan is reported here instead of surfacing at the first transform.
    CHECK_CUFFT(cufftPlan3d(&c_handle, this->nx, this->ny, this->nz, CUFFT_C2C));
    resmem_cd_op()(gpu_ctx, this->c_auxr_3d, this->nx * this->ny * this->nz);
}
/**
 * @brief Create the double-precision 3D complex-to-complex cuFFT plan and
 *        size the auxiliary buffer.
 *
 * The plan is built for an nx * ny * nz grid, and z_auxr_3d is resized
 * through resmem_zd_op to hold nx * ny * nz elements.
 */
template <>
void FFT_CUDA<double>::setupFFT()
{
    // Check the plan-creation status like every exec call in this file does,
    // so a failed plan is reported here instead of surfacing at the first transform.
    CHECK_CUFFT(cufftPlan3d(&z_handle, this->nx, this->ny, this->nz, CUFFT_Z2Z));
    resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz);
}
/**
 * @brief Destroy the single-precision cuFFT plan, if the handle is non-zero,
 *        and reset the handle to its value-initialized state.
 */
template <>
void FFT_CUDA<float>::cleanFFT()
{
    // A zero handle is treated as "no plan": nothing to destroy.
    if (!c_handle)
    {
        return;
    }
    cufftDestroy(c_handle);
    c_handle = {};
}
/**
 * @brief Destroy the double-precision cuFFT plan, if the handle is non-zero,
 *        and reset the handle to its value-initialized state.
 */
template <>
void FFT_CUDA<double>::cleanFFT()
{
    // A zero handle is treated as "no plan": nothing to destroy.
    if (!z_handle)
    {
        return;
    }
    cufftDestroy(z_handle);
    z_handle = {};
}
/**
 * @brief Release the single-precision plan and the auxiliary buffer.
 *
 * Calls cleanFFT() first, then hands c_auxr_3d to delmem_cd_op when it is
 * non-null and resets the pointer so a repeated call skips the buffer.
 */
template <>
void FFT_CUDA<float>::clear()
{
    this->cleanFFT();
    if (c_auxr_3d == nullptr)
    {
        return;
    }
    delmem_cd_op()(gpu_ctx, c_auxr_3d);
    c_auxr_3d = nullptr;
}
/**
 * @brief Release the double-precision plan and the auxiliary buffer.
 *
 * Calls cleanFFT() first, then hands z_auxr_3d to delmem_zd_op when it is
 * non-null and resets the pointer so a repeated call skips the buffer.
 */
template <>
void FFT_CUDA<double>::clear()
{
    this->cleanFFT();
    if (z_auxr_3d == nullptr)
    {
        return;
    }
    delmem_zd_op()(gpu_ctx, z_auxr_3d);
    z_auxr_3d = nullptr;
}

/**
 * @brief Execute the single-precision forward 3D transform with the stored plan.
 * @param in  input data, complex float
 * @param out output data, complex float
 *
 * The cuFFT status is passed to CHECK_CUFFT.
 */
template <>
void FFT_CUDA<float>::fft3D_forward(std::complex<float>* in, std::complex<float>* out) const
{
    // cuFFT takes its own complex type; the buffers are reinterpreted, not copied.
    auto* const src = reinterpret_cast<cufftComplex*>(in);
    auto* const dst = reinterpret_cast<cufftComplex*>(out);
    CHECK_CUFFT(cufftExecC2C(this->c_handle, src, dst, CUFFT_FORWARD));
}
/**
 * @brief Execute the double-precision forward 3D transform with the stored plan.
 * @param in  input data, complex double
 * @param out output data, complex double
 *
 * The cuFFT status is passed to CHECK_CUFFT.
 */
template <>
void FFT_CUDA<double>::fft3D_forward(std::complex<double>* in, std::complex<double>* out) const
{
    // cuFFT takes its own complex type; the buffers are reinterpreted, not copied.
    auto* const src = reinterpret_cast<cufftDoubleComplex*>(in);
    auto* const dst = reinterpret_cast<cufftDoubleComplex*>(out);
    CHECK_CUFFT(cufftExecZ2Z(this->z_handle, src, dst, CUFFT_FORWARD));
}
/**
 * @brief Execute the single-precision inverse 3D transform with the stored plan.
 * @param in  input data, complex float
 * @param out output data, complex float
 *
 * The cuFFT status is passed to CHECK_CUFFT.
 */
template <>
void FFT_CUDA<float>::fft3D_backward(std::complex<float>* in, std::complex<float>* out) const
{
    // cuFFT takes its own complex type; the buffers are reinterpreted, not copied.
    auto* const src = reinterpret_cast<cufftComplex*>(in);
    auto* const dst = reinterpret_cast<cufftComplex*>(out);
    CHECK_CUFFT(cufftExecC2C(this->c_handle, src, dst, CUFFT_INVERSE));
}

/**
 * @brief Execute the double-precision inverse 3D transform with the stored plan.
 * @param in  input data, complex double
 * @param out output data, complex double
 *
 * The cuFFT status is passed to CHECK_CUFFT.
 */
template <>
void FFT_CUDA<double>::fft3D_backward(std::complex<double>* in, std::complex<double>* out) const
{
    // cuFFT takes its own complex type; the buffers are reinterpreted, not copied.
    auto* const src = reinterpret_cast<cufftDoubleComplex*>(in);
    auto* const dst = reinterpret_cast<cufftDoubleComplex*>(out);
    CHECK_CUFFT(cufftExecZ2Z(this->z_handle, src, dst, CUFFT_INVERSE));
}
/**
 * @brief Return the single-precision auxiliary buffer pointer (c_auxr_3d).
 * @return the stored pointer; nullptr until setupFFT() has sized the buffer.
 */
template <>
std::complex<float>* FFT_CUDA<float>::get_auxr_3d_data() const
{
    return c_auxr_3d;
}
/**
 * @brief Return the double-precision auxiliary buffer pointer (z_auxr_3d).
 * @return the stored pointer; nullptr until setupFFT() has sized the buffer.
 */
template <>
std::complex<double>* FFT_CUDA<double>::get_auxr_3d_data() const
{
    return z_auxr_3d;
}
}// namespace ModulePW
70 changes: 70 additions & 0 deletions source/module_basis/module_pw/module_fft/fft_cuda.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#include "fft_base.h"
#include "cufft.h"
#include "cuda_runtime.h"

#ifndef FFT_CUDA_H
#define FFT_CUDA_H
namespace ModulePW
{
/**
 * @brief cuFFT-backed implementation of the FFT_BASE interface.
 *
 * The member functions are provided as explicit specializations for float
 * and double in fft_cuda.cpp; the float versions use c_handle / c_auxr_3d and
 * the double versions use z_handle / z_auxr_3d.
 */
template <typename FPTYPE>
class FFT_CUDA : public FFT_BASE<FPTYPE>
{
public:
FFT_CUDA(){};
// NOTE(review): the destructor body is empty, so the plan and the buffer are
// released only when clear() is called explicitly; confirm the owner does so.
~FFT_CUDA(){};

/**
* @brief Create the 3D complex-to-complex plan for the stored nx, ny, nz and
*        size the auxiliary buffer to nx * ny * nz elements.
*/
void setupFFT() override;

/**
* @brief Call cleanFFT() and then release the auxiliary buffer.
*/
void clear() override;

/**
* @brief Destroy the plan if its handle is non-zero and reset the handle.
*/
void cleanFFT() override;

/**
* @brief Initialize the fft parameters
* @param nx_in number of grid points in x direction
* @param ny_in number of grid points in y direction
* @param nz_in number of grid points in z direction
*
* Only stores the three values; the plan is created later by setupFFT().
*/
void initfft(int nx_in,
int ny_in,
int nz_in) override;

/**
* @brief Get the real space data
* @return pointer to the auxiliary buffer of the matching precision
*         (c_auxr_3d for float, z_auxr_3d for double)
*/
std::complex<FPTYPE>* get_auxr_3d_data() const override;

/**
* @brief Forward FFT in 3D
* @param in input data, complex FPTYPE
* @param out output data, complex FPTYPE
*
* This function performs the forward FFT in 3D.
*/
void fft3D_forward(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const override;
/**
* @brief Backward FFT in 3D
* @param in input data, complex FPTYPE
* @param out output data, complex FPTYPE
*
* This function performs the backward FFT in 3D.
*/
void fft3D_backward(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const override;
private:
// Plan handles; a value-initialized handle is treated as "no plan" by cleanFFT().
// Both are declared in every instantiation, but each precision uses only its own.
cufftHandle c_handle = {};
cufftHandle z_handle = {};

// Auxiliary buffers sized by setupFFT() and released by clear().
std::complex<float>* c_auxr_3d = nullptr; // fft space
std::complex<double>* z_auxr_3d = nullptr; // fft space

};
template FFT_CUDA<float>::FFT_CUDA();
template FFT_CUDA<float>::~FFT_CUDA();
template FFT_CUDA<double>::FFT_CUDA();
template FFT_CUDA<double>::~FFT_CUDA();
} // namespace ModulePW
#endif
106 changes: 106 additions & 0 deletions source/module_basis/module_pw/module_fft/fft_rcom.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#include "fft_rcom.h"
#include "module_base/module_device/memory_op.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
namespace ModulePW
{
/**
 * @brief Record the 3D grid dimensions on this object.
 * @param nx_in number of grid points in x direction
 * @param ny_in number of grid points in y direction
 * @param nz_in number of grid points in z direction
 *
 * Only stores the values; no plan is created and no memory is allocated here.
 */
template <typename FPTYPE>
void FFT_RCOM<FPTYPE>::initfft(int nx_in, int ny_in, int nz_in)
{
    // `this->` is kept on every access, as in the original template definition.
    this->nx = nx_in;
    this->ny = ny_in;
    this->nz = nz_in;
}
/**
 * @brief Create the single-precision 3D complex-to-complex hipFFT plan and
 *        size the auxiliary buffer.
 *
 * The plan is built for an nx * ny * nz grid, and c_auxr_3d is resized
 * through resmem_cd_op to hold nx * ny * nz elements.
 */
template <>
void FFT_RCOM<float>::setupFFT()
{
    // Check the plan-creation status with the same macro this file already
    // applies to the hipfftExec* calls, so a failed plan is reported here.
    CHECK_CUFFT(hipfftPlan3d(&c_handle, this->nx, this->ny, this->nz, HIPFFT_C2C));
    resmem_cd_op()(gpu_ctx, this->c_auxr_3d, this->nx * this->ny * this->nz);
}
/**
 * @brief Create the double-precision 3D complex-to-complex hipFFT plan and
 *        size the auxiliary buffer.
 *
 * The plan is built for an nx * ny * nz grid, and z_auxr_3d is resized
 * through resmem_zd_op to hold nx * ny * nz elements.
 */
template <>
void FFT_RCOM<double>::setupFFT()
{
    // Check the plan-creation status with the same macro this file already
    // applies to the hipfftExec* calls, so a failed plan is reported here.
    CHECK_CUFFT(hipfftPlan3d(&z_handle, this->nx, this->ny, this->nz, HIPFFT_Z2Z));
    resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz);
}
/**
 * @brief Destroy the single-precision hipFFT plan, if the handle is set,
 *        and reset the handle to its value-initialized state.
 */
template <>
void FFT_RCOM<float>::cleanFFT()
{
    // A value-initialized handle is treated as "no plan": nothing to destroy.
    if (!c_handle)
    {
        return;
    }
    hipfftDestroy(c_handle);
    c_handle = {};
}
/**
 * @brief Destroy the double-precision hipFFT plan, if the handle is set,
 *        and reset the handle to its value-initialized state.
 */
template <>
void FFT_RCOM<double>::cleanFFT()
{
    // A value-initialized handle is treated as "no plan": nothing to destroy.
    if (!z_handle)
    {
        return;
    }
    hipfftDestroy(z_handle);
    z_handle = {};
}
/**
 * @brief Release the single-precision plan and the auxiliary buffer.
 *
 * Calls cleanFFT() first, then hands c_auxr_3d to delmem_cd_op when it is
 * non-null and resets the pointer so a repeated call skips the buffer.
 */
template <>
void FFT_RCOM<float>::clear()
{
    this->cleanFFT();
    if (c_auxr_3d == nullptr)
    {
        return;
    }
    delmem_cd_op()(gpu_ctx, c_auxr_3d);
    c_auxr_3d = nullptr;
}
/**
 * @brief Release the double-precision plan and the auxiliary buffer.
 *
 * Calls cleanFFT() first, then hands z_auxr_3d to delmem_zd_op when it is
 * non-null and resets the pointer so a repeated call skips the buffer.
 */
template <>
void FFT_RCOM<double>::clear()
{
    this->cleanFFT();
    if (z_auxr_3d == nullptr)
    {
        return;
    }
    delmem_zd_op()(gpu_ctx, z_auxr_3d);
    z_auxr_3d = nullptr;
}
/**
 * @brief Execute the single-precision forward 3D transform with the stored plan.
 * @param in  input data, complex float
 * @param out output data, complex float
 *
 * The hipFFT status is passed to CHECK_CUFFT.
 */
template <>
void FFT_RCOM<float>::fft3D_forward(std::complex<float>* in, std::complex<float>* out) const
{
    // hipFFT takes its own complex type; the buffers are reinterpreted, not copied.
    auto* const src = reinterpret_cast<hipfftComplex*>(in);
    auto* const dst = reinterpret_cast<hipfftComplex*>(out);
    CHECK_CUFFT(hipfftExecC2C(this->c_handle, src, dst, HIPFFT_FORWARD));
}
/**
 * @brief Execute the double-precision forward 3D transform with the stored plan.
 * @param in  input data, complex double
 * @param out output data, complex double
 *
 * The hipFFT status is passed to CHECK_CUFFT.
 */
template <>
void FFT_RCOM<double>::fft3D_forward(std::complex<double>* in, std::complex<double>* out) const
{
    // hipFFT takes its own complex type; the buffers are reinterpreted, not copied.
    auto* const src = reinterpret_cast<hipfftDoubleComplex*>(in);
    auto* const dst = reinterpret_cast<hipfftDoubleComplex*>(out);
    CHECK_CUFFT(hipfftExecZ2Z(this->z_handle, src, dst, HIPFFT_FORWARD));
}
/**
 * @brief Execute the single-precision backward 3D transform with the stored plan.
 * @param in  input data, complex float
 * @param out output data, complex float
 *
 * The hipFFT status is passed to CHECK_CUFFT.
 */
template <>
void FFT_RCOM<float>::fft3D_backward(std::complex<float>* in, std::complex<float>* out) const
{
    // hipFFT takes its own complex type; the buffers are reinterpreted, not copied.
    auto* const src = reinterpret_cast<hipfftComplex*>(in);
    auto* const dst = reinterpret_cast<hipfftComplex*>(out);
    CHECK_CUFFT(hipfftExecC2C(this->c_handle, src, dst, HIPFFT_BACKWARD));
}
/**
 * @brief Execute the double-precision backward 3D transform with the stored plan.
 * @param in  input data, complex double
 * @param out output data, complex double
 *
 * The hipFFT status is passed to CHECK_CUFFT.
 */
template <>
void FFT_RCOM<double>::fft3D_backward(std::complex<double>* in, std::complex<double>* out) const
{
    // hipFFT takes its own complex type; the buffers are reinterpreted, not copied.
    auto* const src = reinterpret_cast<hipfftDoubleComplex*>(in);
    auto* const dst = reinterpret_cast<hipfftDoubleComplex*>(out);
    CHECK_CUFFT(hipfftExecZ2Z(this->z_handle, src, dst, HIPFFT_BACKWARD));
}
/**
 * @brief Return the single-precision auxiliary buffer pointer (c_auxr_3d).
 * @return the stored pointer, exactly as currently held by the object.
 */
template <>
std::complex<float>* FFT_RCOM<float>::get_auxr_3d_data() const
{
    return c_auxr_3d;
}
/**
 * @brief Return the double-precision auxiliary buffer pointer (z_auxr_3d).
 * @return the stored pointer, exactly as currently held by the object.
 */
template <>
std::complex<double>* FFT_RCOM<double>::get_auxr_3d_data() const
{
    return z_auxr_3d;
}
}// namespace ModulePW
Loading
Loading