Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
c73bd57
add the basic func of the file
A-006 Nov 5, 2024
e838161
modify the Makefile
A-006 Nov 5, 2024
75e42ba
delete file
A-006 Nov 5, 2024
8d73d7e
modify the position of the new fft
A-006 Nov 5, 2024
5304827
modify the Makefile
A-006 Nov 5, 2024
4049c76
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Nov 5, 2024
5cfd6bc
add the cpu float in the fft floder
A-006 Nov 5, 2024
9b8fb19
change the test file
A-006 Nov 5, 2024
d7f50df
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Nov 5, 2024
9ad3c19
add the func in test
A-006 Nov 5, 2024
dfaad66
add the float fft
A-006 Nov 5, 2024
b40629d
change ft into ft1
A-006 Nov 5, 2024
965d627
add the file of the float_define and the device set
A-006 Nov 5, 2024
39e0d0c
delete the memory allocate in the ft
A-006 Nov 6, 2024
b183967
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Nov 6, 2024
2a45b32
add the Smart Pointer and the logic gate
A-006 Nov 7, 2024
28419a6
modify the position of the FFT
A-006 Nov 7, 2024
d29b355
change fft_bundle name
A-006 Nov 7, 2024
6cc4bac
save version of the pw_test and single version
A-006 Nov 7, 2024
2d0b5f3
fix complie bug and change the fftwf logic
A-006 Nov 8, 2024
0610758
Merge branch 'develop' into fft1
A-006 Nov 8, 2024
da26acc
add comments for the fft class
A-006 Nov 8, 2024
9b200a2
modify the fft name and add comments
A-006 Nov 8, 2024
f07a97d
modify the Makefile
A-006 Nov 8, 2024
2623735
Merge branch 'develop' into fft1
A-006 Nov 9, 2024
61bb766
Merge branch 'develop' into fft1
mohanchen Nov 10, 2024
9c1a22d
update the file
A-006 Nov 11, 2024
9029453
update the format
A-006 Nov 11, 2024
479b178
update the shared_ptr
A-006 Nov 11, 2024
76f0d85
Merge branch 'develop' into fft1
A-006 Nov 11, 2024
00dfee5
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Nov 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ VPATH=./src_global:\
./module_base/module_mixing:\
./module_md:\
./module_basis/module_pw:\
./module_basis/module_pw/module_fft:\
./module_esolver:\
./module_hsolver:\
./module_hsolver/kernels:\
Expand Down Expand Up @@ -168,7 +169,6 @@ OBJS_BASE=abfs-vector3_order.o\
memory_op.o\
device.o\


OBJS_CELL=atom_pseudo.o\
atom_spec.o\
pseudo.o\
Expand Down Expand Up @@ -414,6 +414,9 @@ OBJS_PSI_INITIALIZER=psi_initializer.o\
psi_initializer_nao_random.o\

OBJS_PW=fft.o\
fft_bundle.o\
fft_base.o\
fft_cpu.o\
pw_basis.o\
pw_basis_k.o\
pw_basis_sup.o\
Expand Down
1 change: 0 additions & 1 deletion source/module_base/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ list (APPEND LIBM_SRC
libm/sincos.cpp
)
endif()

add_library(
base
OBJECT
Expand Down
9 changes: 9 additions & 0 deletions source/module_basis/module_pw/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
if (ENABLE_FLOAT_FFTW)
list (APPEND FFT_SRC
module_fft/fft_cpu_float.cpp
)
endif()
list(APPEND objects
fft.cpp
pw_basis.cpp
Expand All @@ -10,6 +15,10 @@ list(APPEND objects
pw_init.cpp
pw_transform.cpp
pw_transform_k.cpp
module_fft/fft_base.cpp
module_fft/fft_bundle.cpp
module_fft/fft_cpu.cpp
${FFT_SRC}
)

add_library(
Expand Down
55 changes: 28 additions & 27 deletions source/module_basis/module_pw/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,11 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int
this->fftny = this->ny = ny_in;
if (this->gamma_only)
{
if (xprime)
if (xprime) {
this->fftnx = int(nx / 2) + 1;
else
} else {
this->fftny = int(ny / 2) + 1;
}
}
this->nz = nz_in;
this->ns = ns_in;
Expand All @@ -92,10 +93,10 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int
int maxgrids = (nsz > nrxx) ? nsz : nrxx;
if (!this->mpifft)
{
z_auxg = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * maxgrids);
z_auxr = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * maxgrids);
ModuleBase::Memory::record("FFT::grid", 2 * sizeof(fftw_complex) * maxgrids);
d_rspace = (double*)z_auxg;
// z_auxg = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * maxgrids);
// z_auxr = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * maxgrids);
// ModuleBase::Memory::record("FFT::grid", 2 * sizeof(fftw_complex) * maxgrids);
// d_rspace = (double*)z_auxg;
// auxr_3d = static_cast<std::complex<double> *>(
// fftw_malloc(sizeof(fftw_complex) * (this->nx * this->ny * this->nz)));
#if defined(__CUDA) || defined(__ROCM)
Expand All @@ -105,15 +106,15 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int
resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz);
}
#endif // defined(__CUDA) || defined(__ROCM)
#if defined(__ENABLE_FLOAT_FFTW)
if (this->precision == "single")
{
c_auxg = (std::complex<float>*)fftw_malloc(sizeof(fftwf_complex) * maxgrids);
c_auxr = (std::complex<float>*)fftw_malloc(sizeof(fftwf_complex) * maxgrids);
ModuleBase::Memory::record("FFT::grid_s", 2 * sizeof(fftwf_complex) * maxgrids);
s_rspace = (float*)c_auxg;
}
#endif // defined(__ENABLE_FLOAT_FFTW)
// #if defined(__ENABLE_FLOAT_FFTW)
// if (this->precision == "single")
// {
// c_auxg = (std::complex<float>*)fftw_malloc(sizeof(fftwf_complex) * maxgrids);
// c_auxr = (std::complex<float>*)fftw_malloc(sizeof(fftwf_complex) * maxgrids);
// ModuleBase::Memory::record("FFT::grid_s", 2 * sizeof(fftwf_complex) * maxgrids);
// s_rspace = (float*)c_auxg;
// }
// #endif // defined(__ENABLE_FLOAT_FFTW)
}
else
{
Expand Down Expand Up @@ -353,62 +354,62 @@ void FFT::cleanFFT()
if (planzfor)
{
fftw_destroy_plan(planzfor);
planzfor = NULL;
planzfor = nullptr;
}
if (planzbac)
{
fftw_destroy_plan(planzbac);
planzbac = NULL;
planzbac = nullptr;
}
if (planxfor1)
{
fftw_destroy_plan(planxfor1);
planxfor1 = NULL;
planxfor1 = nullptr;
}
if (planxbac1)
{
fftw_destroy_plan(planxbac1);
planxbac1 = NULL;
planxbac1 = nullptr;
}
if (planxfor2)
{
fftw_destroy_plan(planxfor2);
planxfor2 = NULL;
planxfor2 = nullptr;
}
if (planxbac2)
{
fftw_destroy_plan(planxbac2);
planxbac2 = NULL;
planxbac2 = nullptr;
}
if (planyfor)
{
fftw_destroy_plan(planyfor);
planyfor = NULL;
planyfor = nullptr;
}
if (planybac)
{
fftw_destroy_plan(planybac);
planybac = NULL;
planybac = nullptr;
}
if (planxr2c)
{
fftw_destroy_plan(planxr2c);
planxr2c = NULL;
planxr2c = nullptr;
}
if (planxc2r)
{
fftw_destroy_plan(planxc2r);
planxc2r = NULL;
planxc2r = nullptr;
}
if (planyr2c)
{
fftw_destroy_plan(planyr2c);
planyr2c = NULL;
planyr2c = nullptr;
}
if (planyc2r)
{
fftw_destroy_plan(planyc2r);
planyc2r = NULL;
planyc2r = nullptr;
}

// fftw_destroy_plan(this->plan3dforward);
Expand Down
8 changes: 8 additions & 0 deletions source/module_basis/module_pw/module_fft/fft_base.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#include "fft_base.h"
namespace ModulePW
{
template FFT_BASE<float>::FFT_BASE();
template FFT_BASE<double>::FFT_BASE();
template FFT_BASE<float>::~FFT_BASE();
template FFT_BASE<double>::~FFT_BASE();
}
163 changes: 163 additions & 0 deletions source/module_basis/module_pw/module_fft/fft_base.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
#include <complex>
#include <string>
#include "fftw3.h"
#ifndef FFT_BASE_H
#define FFT_BASE_H
namespace ModulePW
{
template <typename FPTYPE>
class FFT_BASE
{
public:

FFT_BASE(){};
virtual ~FFT_BASE(){};

/**
* @brief Initialize the fft parameters As virtual function.
*
* The function is used to initialize the fft parameters.
*/
virtual __attribute__((weak))
void initfft(int nx_in,
int ny_in,
int nz_in,
int lixy_in,
int rixy_in,
int ns_in,
int nplane_in,
int nproc_in,
bool gamma_only_in,
bool xprime_in = true);

/**
* @brief Setup the fft Plan and data As pure virtual function.
*
* The function is set as pure virtual function.In order to
* override the function in the derived class.In the derived
* class, the function is used to setup the fft Plan and data.
*/
virtual void setupFFT()=0;

/**
* @brief Clean the fft Plan As pure virtual function.
*
* The function is set as pure virtual function.In order to
* override the function in the derived class.In the derived
* class, the function is used to clean the fft Plan.
*/
virtual void cleanFFT()=0;

/**
* @brief Clear the fft data As pure virtual function.
*
* The function is set as pure virtual function.In order to
* override the function in the derived class.In the derived
* class, the function is used to clear the fft data.
*/
virtual void clear()=0;

/**
* @brief Get the real space data in cpu-like fft
*
* The function is used to get the real space data.While the
* FFT_BASE is an abstract class,the function will be override,
* The attribute weak is used to avoid define the function.
*/
virtual __attribute__((weak))
FPTYPE* get_rspace_data() const;

virtual __attribute__((weak))
std::complex<FPTYPE>* get_auxr_data() const;

virtual __attribute__((weak))
std::complex<FPTYPE>* get_auxg_data() const;

/**
* @brief Get the auxiliary real space data in 3D
*
* The function is used to get the auxiliary real space data in 3D.
* While the FFT_BASE is an abstract class,the function will be override,
* The attribute weak is used to avoid define the function.
*/
virtual __attribute__((weak))
std::complex<FPTYPE>* get_auxr_3d_data() const;

//forward fft in x-y direction

/**
* @brief Forward FFT in x-y direction
* @param in input data
* @param out output data
*
* This function performs the forward FFT in the x-y direction.
* It involves two axes, x and y. The FFT is applied multiple times
* along the left and right boundaries in the primary direction(which is
* determined by the xprime flag).Notably, the Y axis operates in
* "many-many-FFT" mode.
*/
virtual __attribute__((weak))
void fftxyfor(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const;

virtual __attribute__((weak))
void fftxybac(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const;

/**
* @brief Forward FFT in z direction
* @param in input data
* @param out output data
*
* This function performs the forward FFT in the z direction.
* It involves only one axis, z. The FFT is applied only once.
* Notably, the Z axis operates in many FFT with nz*ns.
*/
virtual __attribute__((weak))
void fftzfor(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const;

virtual __attribute__((weak))
void fftzbac(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const;

/**
* @brief Forward FFT in x-y direction with real to complex
* @param in input data, real type
* @param out output data, complex type
*
* This function performs the forward FFT in the x-y direction
* with real to complex.There is no difference between fftxyfor.
*/
virtual __attribute__((weak))
void fftxyr2c(FPTYPE* in,
std::complex<FPTYPE>* out) const;

virtual __attribute__((weak))
void fftxyc2r(std::complex<FPTYPE>* in,
FPTYPE* out) const;

/**
* @brief Forward FFT in 3D
* @param in input data
* @param out output data
*
* This function performs the forward FFT for gpu-like fft.
* It involves three axes, x, y, and z. The FFT is applied multiple times
* for fft3D_forward.
*/
virtual __attribute__((weak))
void fft3D_forward(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const;

virtual __attribute__((weak))
void fft3D_backward(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const;

protected:
int nx=0;
int ny=0;
int nz=0;
};
}
#endif // FFT_BASE_H
Loading
Loading