Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
c73bd57
add the basic func of the file
A-006 Nov 5, 2024
e838161
modify the Makefile
A-006 Nov 5, 2024
75e42ba
delete file
A-006 Nov 5, 2024
8d73d7e
modify the position of the new fft
A-006 Nov 5, 2024
5304827
modify the Makefile
A-006 Nov 5, 2024
4049c76
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Nov 5, 2024
5cfd6bc
add the cpu float in the fft floder
A-006 Nov 5, 2024
9b8fb19
change the test file
A-006 Nov 5, 2024
d7f50df
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Nov 5, 2024
9ad3c19
add the func in test
A-006 Nov 5, 2024
dfaad66
add the float fft
A-006 Nov 5, 2024
b40629d
change ft into ft1
A-006 Nov 5, 2024
965d627
add the file of the float_define and the device set
A-006 Nov 5, 2024
39e0d0c
delete the memory allocate in the ft
A-006 Nov 6, 2024
9a3cf14
add ft cuda in the file
A-006 Nov 6, 2024
7051e38
add cuda support
A-006 Nov 6, 2024
b183967
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Nov 6, 2024
cdc3e96
change the logic of the fft_float allocate
A-006 Nov 6, 2024
f678433
add the cuda support for the ft
A-006 Nov 6, 2024
cffcbfa
Merge branch 'fft1' into fft2
A-006 Nov 6, 2024
4b9ac38
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Nov 6, 2024
d4492a5
add the CMake
A-006 Nov 6, 2024
0b88bc8
change the logic of cuda forward and backward
A-006 Nov 6, 2024
2a45b32
add the Smart Pointer and the logic gate
A-006 Nov 7, 2024
28419a6
modify the position of the FFT
A-006 Nov 7, 2024
d29b355
change fft_bundle name
A-006 Nov 7, 2024
6cc4bac
save version of the pw_test and single version
A-006 Nov 7, 2024
2d0b5f3
fix complie bug and change the fftwf logic
A-006 Nov 8, 2024
0610758
Merge branch 'develop' into fft1
A-006 Nov 8, 2024
da26acc
add comments for the fft class
A-006 Nov 8, 2024
9b200a2
modify the fft name and add comments
A-006 Nov 8, 2024
f07a97d
modify the Makefile
A-006 Nov 8, 2024
2623735
Merge branch 'develop' into fft1
A-006 Nov 9, 2024
61bb766
Merge branch 'develop' into fft1
mohanchen Nov 10, 2024
a25be43
Merge branch 'fft1' into fft2
A-006 Nov 10, 2024
1076c82
add namespacepw
A-006 Nov 10, 2024
888c146
modify the basic line
A-006 Nov 10, 2024
5cc6650
update the initfft in the CUDA
A-006 Nov 10, 2024
9c1a22d
update the file
A-006 Nov 11, 2024
9029453
update the format
A-006 Nov 11, 2024
479b178
update the shared_ptr
A-006 Nov 11, 2024
76f0d85
Merge branch 'develop' into fft1
A-006 Nov 11, 2024
00dfee5
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Nov 11, 2024
de67f24
add the comment and the func for cuda
A-006 Nov 13, 2024
927babf
Merge branch 'fft1' into fft2
A-006 Nov 13, 2024
731bf87
update the format
A-006 Nov 13, 2024
83a370b
Merge branch 'develop' into fft2
A-006 Nov 14, 2024
dc1c879
update the CMakefile
A-006 Nov 14, 2024
4ca95d4
modify the ft
A-006 Nov 14, 2024
cabaeee
add the clear()
A-006 Nov 14, 2024
66a479b
update the Makefile in test
A-006 Nov 15, 2024
fe1ea60
Merge branch 'develop' into fft2
A-006 Nov 15, 2024
de67db2
Merge branch 'fft2' into fft3
A-006 Nov 15, 2024
19badb0
update the template file
A-006 Nov 15, 2024
77b801a
Merge branch 'develop' into fft2
A-006 Nov 15, 2024
5a04391
update the template file
A-006 Nov 16, 2024
15d1f89
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Nov 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions source/module_basis/module_pw/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@ if (ENABLE_FLOAT_FFTW)
module_fft/fft_cpu_float.cpp
)
endif()
if (USE_CUDA)
list (APPEND FFT_SRC
module_fft/fft_cuda.cpp
)
endif()
if (USE_ROCM)
list (APPEND FFT_SRC
module_fft/fft_rcom.cpp
)
endif()

list(APPEND objects
fft.cpp
pw_basis.cpp
Expand Down
28 changes: 0 additions & 28 deletions source/module_basis/module_pw/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,34 +91,6 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int
const int nrxx = this->nxy * this->nplane;
const int nsz = this->nz * this->ns;
int maxgrids = (nsz > nrxx) ? nsz : nrxx;
if (!this->mpifft)
{
// z_auxg = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * maxgrids);
// z_auxr = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * maxgrids);
// ModuleBase::Memory::record("FFT::grid", 2 * sizeof(fftw_complex) * maxgrids);
// d_rspace = (double*)z_auxg;
// auxr_3d = static_cast<std::complex<double> *>(
// fftw_malloc(sizeof(fftw_complex) * (this->nx * this->ny * this->nz)));
#if defined(__CUDA) || defined(__ROCM)
if (this->device == "gpu")
{
resmem_cd_op()(gpu_ctx, this->c_auxr_3d, this->nx * this->ny * this->nz);
resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz);
}
#endif // defined(__CUDA) || defined(__ROCM)
// #if defined(__ENABLE_FLOAT_FFTW)
// if (this->precision == "single")
// {
// c_auxg = (std::complex<float>*)fftw_malloc(sizeof(fftwf_complex) * maxgrids);
// c_auxr = (std::complex<float>*)fftw_malloc(sizeof(fftwf_complex) * maxgrids);
// ModuleBase::Memory::record("FFT::grid_s", 2 * sizeof(fftwf_complex) * maxgrids);
// s_rspace = (float*)c_auxg;
// }
// #endif // defined(__ENABLE_FLOAT_FFTW)
}
else
{
}
}

void FFT::setupFFT()
Expand Down
4 changes: 0 additions & 4 deletions source/module_basis/module_pw/module_fft/fft_base.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
#include "fft_base.h"
namespace ModulePW
{
template FFT_BASE<float>::FFT_BASE();
template FFT_BASE<double>::FFT_BASE();
template FFT_BASE<float>::~FFT_BASE();
template FFT_BASE<double>::~FFT_BASE();
}
8 changes: 8 additions & 0 deletions source/module_basis/module_pw/module_fft/fft_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class FFT_BASE
bool gamma_only_in,
bool xprime_in = true);

virtual __attribute__((weak))
void initfft(int nx_in,
int ny_in,
int nz_in);
/**
* @brief Setup the fft Plan and data As pure virtual function.
*
Expand Down Expand Up @@ -159,5 +163,9 @@ class FFT_BASE
int ny=0;
int nz=0;
};
template FFT_BASE<float>::FFT_BASE();
template FFT_BASE<double>::FFT_BASE();
template FFT_BASE<float>::~FFT_BASE();
template FFT_BASE<double>::~FFT_BASE();
}
#endif // FFT_BASE_H
40 changes: 25 additions & 15 deletions source/module_basis/module_pw/module_fft/fft_bundle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
#include "fft_bundle.h"
#include "fft_cpu.h"
#include "module_base/module_device/device.h"
// #if defined(__CUDA)
// #include "fft_cuda.h"
// #endif
// #if defined(__ROCM)
// #include "fft_rcom.h"
// #endif

#if defined(__CUDA)
#include "fft_cuda.h"
#endif
#if defined(__ROCM)
#include "fft_rcom.h"
#endif

template<typename FFT_BASE, typename... Args>
std::unique_ptr<FFT_BASE> make_unique(Args &&... args)
Expand All @@ -21,7 +22,10 @@ void FFT_Bundle::setfft(std::string device_in,std::string precision_in)
this->device = device_in;
this->precision = precision_in;
}

FFT_Bundle::~FFT_Bundle()
{
this->clear();
}
void FFT_Bundle::initfft(int nx_in,
int ny_in,
int nz_in,
Expand Down Expand Up @@ -81,15 +85,21 @@ void FFT_Bundle::initfft(int nx_in,
xprime_in);
}
}
if (device=="gpu")
else if (device=="gpu")
{
// #if defined(__ROCM)
// fft_float = new FFT_RCOM<float>();
// fft_double = new FFT_RCOM<double>();
// #elif defined(__CUDA)
// fft_float = make_unique<FFT_CUDA<float>>();
// fft_double = make_unique<FFT_CUDA<double>>();
// #endif
float_flag=true;
double_flag=true;
#if defined(__ROCM)
fft_float = make_unique<FFT_RCOM<float>>;
fft_float->initfft(nx_in,ny_in,nz_in);
fft_double = make_unique<FFT_RCOM<double>>;
fft_double->initfft(nx_in,ny_in,nz_in);
#elif defined(__CUDA)
fft_float = make_unique<FFT_CUDA<float>>();
fft_float->initfft(nx_in,ny_in,nz_in);
fft_double = make_unique<FFT_CUDA<double>>();
fft_double->initfft(nx_in,ny_in,nz_in);
#endif
}

}
Expand Down
2 changes: 1 addition & 1 deletion source/module_basis/module_pw/module_fft/fft_bundle.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class FFT_Bundle
{
public:
FFT_Bundle(){};
~FFT_Bundle(){};
~FFT_Bundle();
/**
* @brief Constructor with device and precision.
* @param device_in device type, cpu or gpu.
Expand Down
5 changes: 0 additions & 5 deletions source/module_basis/module_pw/module_fft/fft_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,4 @@ template <> std::complex<double>*
FFT_CPU<double>::get_auxr_data() const {return z_auxr;}
template <> std::complex<double>*
FFT_CPU<double>::get_auxg_data() const {return z_auxg;}

template FFT_CPU<float>::FFT_CPU();
template FFT_CPU<float>::~FFT_CPU();
template FFT_CPU<double>::FFT_CPU();
template FFT_CPU<double>::~FFT_CPU();
}
58 changes: 31 additions & 27 deletions source/module_basis/module_pw/module_fft/fft_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class FFT_CPU : public FFT_BASE<FPTYPE>
* @param gamma_only_in whether only gamma point is used.
* @param xprime_in whether xprime is used.
*/
__attribute__((weak))

void initfft(int nx_in,
int ny_in,
int nz_in,
Expand All @@ -44,10 +44,10 @@ class FFT_CPU : public FFT_BASE<FPTYPE>
int nproc_in,
bool gamma_only_in,
bool xprime_in = true) override;

__attribute__((weak))
void setupFFT() override;

// void initplan(const unsigned int& flag = 0);
__attribute__((weak))
void cleanFFT() override;

Expand Down Expand Up @@ -106,31 +106,31 @@ class FFT_CPU : public FFT_BASE<FPTYPE>
void clearfft(fftw_plan& plan);
void clearfft(fftwf_plan& plan);

fftw_plan planzfor = NULL;
fftw_plan planzbac = NULL;
fftw_plan planxfor1 = NULL;
fftw_plan planxbac1 = NULL;
fftw_plan planxfor2 = NULL;
fftw_plan planxbac2 = NULL;
fftw_plan planyfor = NULL;
fftw_plan planybac = NULL;
fftw_plan planxr2c = NULL;
fftw_plan planxc2r = NULL;
fftw_plan planyr2c = NULL;
fftw_plan planyc2r = NULL;

fftwf_plan planfzfor = NULL;
fftwf_plan planfzbac = NULL;
fftwf_plan planfxfor1= NULL;
fftwf_plan planfxbac1= NULL;
fftwf_plan planfxfor2= NULL;
fftwf_plan planfxbac2= NULL;
fftwf_plan planfyfor = NULL;
fftwf_plan planfybac = NULL;
fftwf_plan planfxr2c = NULL;
fftwf_plan planfxc2r = NULL;
fftwf_plan planfyr2c = NULL;
fftwf_plan planfyc2r = NULL;
fftw_plan planzfor = nullptr;
fftw_plan planzbac = nullptr;
fftw_plan planxfor1 = nullptr;
fftw_plan planxbac1 = nullptr;
fftw_plan planxfor2 = nullptr;
fftw_plan planxbac2 = nullptr;
fftw_plan planyfor = nullptr;
fftw_plan planybac = nullptr;
fftw_plan planxr2c = nullptr;
fftw_plan planxc2r = nullptr;
fftw_plan planyr2c = nullptr;
fftw_plan planyc2r = nullptr;

fftwf_plan planfzfor = nullptr;
fftwf_plan planfzbac = nullptr;
fftwf_plan planfxfor1= nullptr;
fftwf_plan planfxbac1= nullptr;
fftwf_plan planfxfor2= nullptr;
fftwf_plan planfxbac2= nullptr;
fftwf_plan planfyfor = nullptr;
fftwf_plan planfybac = nullptr;
fftwf_plan planfxr2c = nullptr;
fftwf_plan planfxc2r = nullptr;
fftwf_plan planfyr2c = nullptr;
fftwf_plan planfyc2r = nullptr;

std::complex<float>*c_auxg = nullptr;
std::complex<float>*c_auxr = nullptr; // fft space
Expand Down Expand Up @@ -169,5 +169,9 @@ class FFT_CPU : public FFT_BASE<FPTYPE>
*/
int fft_mode = 0;
};
template FFT_CPU<float>::FFT_CPU();
template FFT_CPU<float>::~FFT_CPU();
template FFT_CPU<double>::FFT_CPU();
template FFT_CPU<double>::~FFT_CPU();
}
#endif // FFT_CPU_H
30 changes: 15 additions & 15 deletions source/module_basis/module_pw/module_fft/fft_cpu_float.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,30 +267,30 @@ void FFT_CPU<float>::setupFFT()
}

template <>
void FFT_CPU<float>::clearfft(fftw_plan& plan)
void FFT_CPU<float>::clearfft(fftwf_plan& plan)
{
if (plan)
{
fftw_destroy_plan(plan);
fftwf_destroy_plan(plan);
plan = nullptr;
}
}

template <>
void FFT_CPU<float>::cleanFFT()
{
clearfft(planzfor);
clearfft(planzbac);
clearfft(planxfor1);
clearfft(planxbac1);
clearfft(planxfor2);
clearfft(planxbac2);
clearfft(planyfor);
clearfft(planybac);
clearfft(planxr2c);
clearfft(planxc2r);
clearfft(planyr2c);
clearfft(planyc2r);
clearfft(planfzfor);
clearfft(planfzbac);
clearfft(planfxfor1);
clearfft(planfxbac1);
clearfft(planfxfor2);
clearfft(planfxbac2);
clearfft(planfyfor);
clearfft(planfybac);
clearfft(planfxr2c);
clearfft(planfxc2r);
clearfft(planfyr2c);
clearfft(planfyc2r);
}


Expand All @@ -303,7 +303,7 @@ void FFT_CPU<float>::clear()
fftw_free(c_auxg);
c_auxg = nullptr;
}
if (z_auxr != nullptr)
if (c_auxr != nullptr)
{
fftw_free(c_auxr);
c_auxr = nullptr;
Expand Down
Loading
Loading