Skip to content

Commit 7c46674

Browse files
A-006pre-commit-ci-lite[bot]mohanchen
authored
Refactor:Replace the current fft with templates and polymorphism (#5410)
* add the basic func of the file * modify the Makefile * delete file * modify the position of the new fft * modify the Makefile * [pre-commit.ci lite] apply automatic fixes * add the cpu float in the fft folder * change the test file * [pre-commit.ci lite] apply automatic fixes * add the func in test * add the float fft * change ft into ft1 * add the file of the float_define and the device set * delete the memory allocation in the ft * [pre-commit.ci lite] apply automatic fixes * add the Smart Pointer and the logic gate * modify the position of the FFT * change fft_bundle name * save version of the pw_test and single version * fix compile bug and change the fftwf logic * add comments for the fft class * modify the fft name and add comments * modify the Makefile * update the file * update the format * update the shared_ptr * [pre-commit.ci lite] apply automatic fixes --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Co-authored-by: Mohan Chen <[email protected]>
1 parent 8d6e593 commit 7c46674

File tree

27 files changed

+1875
-114
lines changed

27 files changed

+1875
-114
lines changed

source/Makefile.Objects

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ VPATH=./src_global:\
2727
./module_base/module_mixing:\
2828
./module_md:\
2929
./module_basis/module_pw:\
30+
./module_basis/module_pw/module_fft:\
3031
./module_esolver:\
3132
./module_hsolver:\
3233
./module_hsolver/kernels:\
@@ -168,7 +169,6 @@ OBJS_BASE=abfs-vector3_order.o\
168169
memory_op.o\
169170
device.o\
170171

171-
172172
OBJS_CELL=atom_pseudo.o\
173173
atom_spec.o\
174174
pseudo.o\
@@ -414,6 +414,9 @@ OBJS_PSI_INITIALIZER=psi_initializer.o\
414414
psi_initializer_nao_random.o\
415415

416416
OBJS_PW=fft.o\
417+
fft_bundle.o\
418+
fft_base.o\
419+
fft_cpu.o\
417420
pw_basis.o\
418421
pw_basis_k.o\
419422
pw_basis_sup.o\

source/module_base/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ list (APPEND LIBM_SRC
66
libm/sincos.cpp
77
)
88
endif()
9-
109
add_library(
1110
base
1211
OBJECT

source/module_basis/module_pw/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
if (ENABLE_FLOAT_FFTW)
2+
list (APPEND FFT_SRC
3+
module_fft/fft_cpu_float.cpp
4+
)
5+
endif()
16
list(APPEND objects
27
fft.cpp
38
pw_basis.cpp
@@ -10,6 +15,10 @@ list(APPEND objects
1015
pw_init.cpp
1116
pw_transform.cpp
1217
pw_transform_k.cpp
18+
module_fft/fft_base.cpp
19+
module_fft/fft_bundle.cpp
20+
module_fft/fft_cpu.cpp
21+
${FFT_SRC}
1322
)
1423

1524
add_library(

source/module_basis/module_pw/fft.cpp

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,11 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int
7272
this->fftny = this->ny = ny_in;
7373
if (this->gamma_only)
7474
{
75-
if (xprime)
75+
if (xprime) {
7676
this->fftnx = int(nx / 2) + 1;
77-
else
77+
} else {
7878
this->fftny = int(ny / 2) + 1;
79+
}
7980
}
8081
this->nz = nz_in;
8182
this->ns = ns_in;
@@ -92,10 +93,10 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int
9293
int maxgrids = (nsz > nrxx) ? nsz : nrxx;
9394
if (!this->mpifft)
9495
{
95-
z_auxg = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * maxgrids);
96-
z_auxr = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * maxgrids);
97-
ModuleBase::Memory::record("FFT::grid", 2 * sizeof(fftw_complex) * maxgrids);
98-
d_rspace = (double*)z_auxg;
96+
// z_auxg = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * maxgrids);
97+
// z_auxr = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * maxgrids);
98+
// ModuleBase::Memory::record("FFT::grid", 2 * sizeof(fftw_complex) * maxgrids);
99+
// d_rspace = (double*)z_auxg;
99100
// auxr_3d = static_cast<std::complex<double> *>(
100101
// fftw_malloc(sizeof(fftw_complex) * (this->nx * this->ny * this->nz)));
101102
#if defined(__CUDA) || defined(__ROCM)
@@ -105,15 +106,15 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int
105106
resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz);
106107
}
107108
#endif // defined(__CUDA) || defined(__ROCM)
108-
#if defined(__ENABLE_FLOAT_FFTW)
109-
if (this->precision == "single")
110-
{
111-
c_auxg = (std::complex<float>*)fftw_malloc(sizeof(fftwf_complex) * maxgrids);
112-
c_auxr = (std::complex<float>*)fftw_malloc(sizeof(fftwf_complex) * maxgrids);
113-
ModuleBase::Memory::record("FFT::grid_s", 2 * sizeof(fftwf_complex) * maxgrids);
114-
s_rspace = (float*)c_auxg;
115-
}
116-
#endif // defined(__ENABLE_FLOAT_FFTW)
109+
// #if defined(__ENABLE_FLOAT_FFTW)
110+
// if (this->precision == "single")
111+
// {
112+
// c_auxg = (std::complex<float>*)fftw_malloc(sizeof(fftwf_complex) * maxgrids);
113+
// c_auxr = (std::complex<float>*)fftw_malloc(sizeof(fftwf_complex) * maxgrids);
114+
// ModuleBase::Memory::record("FFT::grid_s", 2 * sizeof(fftwf_complex) * maxgrids);
115+
// s_rspace = (float*)c_auxg;
116+
// }
117+
// #endif // defined(__ENABLE_FLOAT_FFTW)
117118
}
118119
else
119120
{
@@ -353,62 +354,62 @@ void FFT::cleanFFT()
353354
if (planzfor)
354355
{
355356
fftw_destroy_plan(planzfor);
356-
planzfor = NULL;
357+
planzfor = nullptr;
357358
}
358359
if (planzbac)
359360
{
360361
fftw_destroy_plan(planzbac);
361-
planzbac = NULL;
362+
planzbac = nullptr;
362363
}
363364
if (planxfor1)
364365
{
365366
fftw_destroy_plan(planxfor1);
366-
planxfor1 = NULL;
367+
planxfor1 = nullptr;
367368
}
368369
if (planxbac1)
369370
{
370371
fftw_destroy_plan(planxbac1);
371-
planxbac1 = NULL;
372+
planxbac1 = nullptr;
372373
}
373374
if (planxfor2)
374375
{
375376
fftw_destroy_plan(planxfor2);
376-
planxfor2 = NULL;
377+
planxfor2 = nullptr;
377378
}
378379
if (planxbac2)
379380
{
380381
fftw_destroy_plan(planxbac2);
381-
planxbac2 = NULL;
382+
planxbac2 = nullptr;
382383
}
383384
if (planyfor)
384385
{
385386
fftw_destroy_plan(planyfor);
386-
planyfor = NULL;
387+
planyfor = nullptr;
387388
}
388389
if (planybac)
389390
{
390391
fftw_destroy_plan(planybac);
391-
planybac = NULL;
392+
planybac = nullptr;
392393
}
393394
if (planxr2c)
394395
{
395396
fftw_destroy_plan(planxr2c);
396-
planxr2c = NULL;
397+
planxr2c = nullptr;
397398
}
398399
if (planxc2r)
399400
{
400401
fftw_destroy_plan(planxc2r);
401-
planxc2r = NULL;
402+
planxc2r = nullptr;
402403
}
403404
if (planyr2c)
404405
{
405406
fftw_destroy_plan(planyr2c);
406-
planyr2c = NULL;
407+
planyr2c = nullptr;
407408
}
408409
if (planyc2r)
409410
{
410411
fftw_destroy_plan(planyc2r);
411-
planyc2r = NULL;
412+
planyc2r = nullptr;
412413
}
413414

414415
// fftw_destroy_plan(this->plan3dforward);
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#include "fft_base.h"
2+
namespace ModulePW
3+
{ // Explicit instantiations of the FFT_BASE constructor and destructor for float and double.
4+
template FFT_BASE<float>::FFT_BASE();
5+
template FFT_BASE<double>::FFT_BASE();
6+
template FFT_BASE<float>::~FFT_BASE();
7+
template FFT_BASE<double>::~FFT_BASE();
8+
} // namespace ModulePW
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
#include <complex>
2+
#include <string>
3+
#include "fftw3.h"
4+
#ifndef FFT_BASE_H
5+
#define FFT_BASE_H
6+
namespace ModulePW
7+
{
8+
template <typename FPTYPE>
9+
class FFT_BASE // Abstract FFT interface templated on the floating-point type; setupFFT/cleanFFT/clear are pure virtual.
10+
{
11+
public:
12+
13+
FFT_BASE(){};
14+
virtual ~FFT_BASE(){};
15+
16+
/**
17+
* @brief Initialize the FFT parameters (virtual, declared weak, not defined in this header).
18+
*
19+
* The function is used to initialize the FFT parameters; xprime_in defaults to true.
20+
*/
21+
virtual __attribute__((weak))
22+
void initfft(int nx_in,
23+
int ny_in,
24+
int nz_in,
25+
int lixy_in,
26+
int rixy_in,
27+
int ns_in,
28+
int nplane_in,
29+
int nproc_in,
30+
bool gamma_only_in,
31+
bool xprime_in = true);
32+
33+
/**
34+
* @brief Set up the FFT plan and data (pure virtual function).
35+
*
36+
* The function is declared pure virtual so that it must be
37+
* overridden in the derived class. In the derived class,
38+
* the function is used to set up the FFT plan and data.
39+
*/
40+
virtual void setupFFT()=0;
41+
42+
/**
43+
* @brief Clean the FFT plan (pure virtual function).
44+
*
45+
* The function is declared pure virtual so that it must be
46+
* overridden in the derived class. In the derived class,
47+
* the function is used to clean the FFT plan.
48+
*/
49+
virtual void cleanFFT()=0;
50+
51+
/**
52+
* @brief Clear the FFT data (pure virtual function).
53+
*
54+
* The function is declared pure virtual so that it must be
55+
* overridden in the derived class. In the derived class,
56+
* the function is used to clear the FFT data.
57+
*/
58+
virtual void clear()=0;
59+
60+
/**
61+
* @brief Get the real-space data in a CPU-like FFT
62+
*
63+
* The function is used to get the real-space data. Since
64+
* FFT_BASE is an abstract class, the function is meant to be overridden.
65+
* The weak attribute is used to avoid defining the function here. NOTE(review): presumably it must not be called without an override - confirm.
66+
*/
67+
virtual __attribute__((weak))
68+
FPTYPE* get_rspace_data() const;
69+
70+
virtual __attribute__((weak))
71+
std::complex<FPTYPE>* get_auxr_data() const; // undocumented in the original; presumably an auxiliary real-space buffer - TODO confirm
72+
73+
virtual __attribute__((weak))
74+
std::complex<FPTYPE>* get_auxg_data() const; // undocumented in the original; presumably an auxiliary reciprocal-space buffer - TODO confirm
75+
76+
/**
77+
* @brief Get the auxiliary real-space data in 3D
78+
*
79+
* The function is used to get the auxiliary real-space data in 3D.
80+
* Since FFT_BASE is an abstract class, the function is meant to be overridden.
81+
* The weak attribute is used to avoid defining the function here.
82+
*/
83+
virtual __attribute__((weak))
84+
std::complex<FPTYPE>* get_auxr_3d_data() const;
85+
86+
// forward FFT in x-y direction
87+
88+
/**
89+
* @brief Forward FFT in x-y direction
90+
* @param in input data
91+
* @param out output data
92+
*
93+
* This function performs the forward FFT in the x-y direction.
94+
* It involves two axes, x and y. The FFT is applied multiple times
95+
* along the left and right boundaries in the primary direction (which is
96+
* determined by the xprime flag). Notably, the y axis operates in
97+
* "many-many-FFT" mode.
98+
*/
99+
virtual __attribute__((weak))
100+
void fftxyfor(std::complex<FPTYPE>* in,
101+
std::complex<FPTYPE>* out) const;
102+
103+
virtual __attribute__((weak))
104+
void fftxybac(std::complex<FPTYPE>* in,
105+
std::complex<FPTYPE>* out) const; // presumably the backward counterpart of fftxyfor - TODO confirm
106+
107+
/**
108+
* @brief Forward FFT in z direction
109+
* @param in input data
110+
* @param out output data
111+
*
112+
* This function performs the forward FFT in the z direction.
113+
* It involves only one axis, z. The FFT is applied only once.
114+
* Notably, the z axis operates in many-FFT mode with nz*ns.
115+
*/
116+
virtual __attribute__((weak))
117+
void fftzfor(std::complex<FPTYPE>* in,
118+
std::complex<FPTYPE>* out) const;
119+
120+
virtual __attribute__((weak))
121+
void fftzbac(std::complex<FPTYPE>* in,
122+
std::complex<FPTYPE>* out) const; // presumably the backward counterpart of fftzfor - TODO confirm
123+
124+
/**
125+
* @brief Forward FFT in x-y direction, real to complex
126+
* @param in input data, real type
127+
* @param out output data, complex type
128+
*
129+
* This function performs the forward FFT in the x-y direction
130+
* from real to complex data. Otherwise there is no difference from fftxyfor.
131+
*/
132+
virtual __attribute__((weak))
133+
void fftxyr2c(FPTYPE* in,
134+
std::complex<FPTYPE>* out) const;
135+
136+
virtual __attribute__((weak))
137+
void fftxyc2r(std::complex<FPTYPE>* in,
138+
FPTYPE* out) const; // complex input, real output; presumably the backward counterpart of fftxyr2c - TODO confirm
139+
140+
/**
141+
* @brief Forward FFT in 3D
142+
* @param in input data
143+
* @param out output data
144+
*
145+
* This function performs the forward FFT for a GPU-like FFT.
146+
* It involves three axes, x, y, and z. The FFT is applied multiple times
147+
* for fft3D_forward.
148+
*/
149+
virtual __attribute__((weak))
150+
void fft3D_forward(std::complex<FPTYPE>* in,
151+
std::complex<FPTYPE>* out) const;
152+
153+
virtual __attribute__((weak))
154+
void fft3D_backward(std::complex<FPTYPE>* in,
155+
std::complex<FPTYPE>* out) const; // presumably the backward counterpart of fft3D_forward - TODO confirm
156+
157+
protected:
158+
int nx=0; // zero-initialized; presumably the grid size along x - TODO confirm
159+
int ny=0; // zero-initialized; presumably the grid size along y - TODO confirm
160+
int nz=0; // zero-initialized; presumably the grid size along z - TODO confirm
161+
};
}
163+
#endif // FFT_BASE_H

0 commit comments

Comments
 (0)