Skip to content

Commit 5cfd6bc

Browse files
committed
add the cpu float in the fft floder
1 parent 4049c76 commit 5cfd6bc

File tree

6 files changed

+928
-224
lines changed

6 files changed

+928
-224
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ OBJS_BASE=abfs-vector3_order.o\
169169
device.o\
170170
fft_temp.o\
171171
fft_base.o\
172+
fft_cpu.o\
172173

173174
OBJS_CELL=atom_pseudo.o\
174175
atom_spec.o\

source/module_base/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@ list (APPEND LIBM_SRC
66
libm/sincos.cpp
77
)
88
endif()
9-
9+
if (ENABLE_FLOAT_FFTW)
10+
list (APPEND FFT_SRC
11+
module_fft/fftw_float.cpp
12+
)
13+
endif()
1014
add_library(
1115
base
1216
OBJECT
@@ -60,7 +64,9 @@ add_library(
6064
module_mixing/broyden_mixing.cpp
6165
module_fft/fft_base.cpp
6266
module_fft/fft_temp.cpp
67+
module_fft/fft_cpu.cpp
6368
${LIBM_SRC}
69+
${FFT_SRC}
6470
)
6571

6672
add_subdirectory(module_container)
Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
#include "fft_cpu.h"
2+
#include "fftw3.h"
3+
#if defined(__FFTW3_MPI) && defined(__MPI)
4+
#include <fftw3-mpi.h>
5+
//#include "fftw3-mpi_mkl.h"
6+
#endif
7+
8+
template <>
9+
FFT_CPU<double>::FFT_CPU()
10+
{
11+
12+
}
13+
template <>
14+
FFT_CPU<double>::~FFT_CPU()
15+
{
16+
17+
}
18+
19+
template <>
20+
void FFT_CPU<double>::setupFFT()
21+
{
22+
23+
unsigned int flag = FFTW_ESTIMATE;
24+
switch (this->fft_mode)
25+
{
26+
case 0:
27+
flag = FFTW_ESTIMATE;
28+
break;
29+
case 1:
30+
flag = FFTW_MEASURE;
31+
break;
32+
case 2:
33+
flag = FFTW_PATIENT;
34+
break;
35+
case 3:
36+
flag = FFTW_EXHAUSTIVE;
37+
break;
38+
default:
39+
break;
40+
}
41+
if (!this->mpifft)
42+
{
43+
z_auxg = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids);
44+
z_auxr = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * this->maxgrids);
45+
d_rspace = (double*)z_auxg;
46+
this->planzfor = fftw_plan_many_dft(1, &this->nz, this->ns, (fftw_complex*)z_auxg, &this->nz, 1, this->nz,
47+
(fftw_complex*)z_auxg, &this->nz, 1, this->nz, FFTW_FORWARD, flag);
48+
49+
this->planzbac = fftw_plan_many_dft(1, &this->nz, this->ns, (fftw_complex*)z_auxg, &this->nz, 1, this->nz,
50+
(fftw_complex*)z_auxg, &this->nz, 1, this->nz, FFTW_BACKWARD, flag);
51+
52+
//---------------------------------------------------------
53+
// 2 D - XY
54+
//---------------------------------------------------------
55+
// 1D+1D is much faster than 2D FFT!
56+
// in-place fft is better for c2c and out-of-place fft is better for c2r
57+
int* embed = nullptr;
58+
int npy = this->nplane * this->ny;
59+
if (this->xprime)
60+
{
61+
this->planyfor = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed,this->nplane, 1,
62+
(fftw_complex*)z_auxr, embed,this->nplane, 1, FFTW_FORWARD, flag);
63+
this->planybac = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed,this->nplane, 1,
64+
(fftw_complex*)z_auxr, embed,this->nplane, 1, FFTW_BACKWARD, flag);
65+
if (this->gamma_only)
66+
{
67+
this->planxr2c = fftw_plan_many_dft_r2c(1, &this->nx, npy, d_rspace, embed, npy, 1, (fftw_complex*)z_auxr,
68+
embed, npy, 1, flag);
69+
this->planxc2r = fftw_plan_many_dft_c2r(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1, d_rspace,
70+
embed, npy, 1, flag);
71+
}
72+
else
73+
{
74+
this->planxfor1 = fftw_plan_many_dft(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1,
75+
(fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag);
76+
this->planxbac1 = fftw_plan_many_dft(1, &this->nx, npy, (fftw_complex*)z_auxr, embed, npy, 1,
77+
(fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag);
78+
}
79+
}
80+
else
81+
{
82+
this->planxfor1 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->lixy + 1), (fftw_complex*)z_auxr, embed, npy,
83+
1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag);
84+
this->planxbac1 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->lixy + 1), (fftw_complex*)z_auxr, embed, npy,
85+
1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag);
86+
if (this->gamma_only)
87+
{
88+
this->planyr2c = fftw_plan_many_dft_r2c(1, &this->ny, this->nplane, d_rspace, embed, this->nplane, 1,
89+
(fftw_complex*)z_auxr, embed, this->nplane, 1, flag);
90+
this->planyc2r = fftw_plan_many_dft_c2r(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed,
91+
this->nplane, 1, d_rspace, embed, this->nplane, 1, flag);
92+
}
93+
else
94+
{
95+
96+
this->planxfor2 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->ny - this->rixy), (fftw_complex*)z_auxr, embed,
97+
npy, 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_FORWARD, flag);
98+
this->planxbac2 = fftw_plan_many_dft(1, &this->nx, this->nplane * (this->ny - this->rixy), (fftw_complex*)z_auxr, embed,
99+
npy, 1, (fftw_complex*)z_auxr, embed, npy, 1, FFTW_BACKWARD, flag);
100+
this->planyfor = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, this->nplane,
101+
1, (fftw_complex*)z_auxr, embed, this->nplane, 1, FFTW_FORWARD, flag);
102+
this->planybac = fftw_plan_many_dft(1, &this->ny, this->nplane, (fftw_complex*)z_auxr, embed, this->nplane,
103+
1, (fftw_complex*)z_auxr, embed, this->nplane, 1, FFTW_BACKWARD, flag);
104+
}
105+
}
106+
}
107+
#if defined(__FFTW3_MPI) && defined(__MPI)
108+
else
109+
{
110+
// this->initplan_mpi();
111+
// if (this->precision == "single") {
112+
// this->initplanf_mpi();
113+
// }
114+
}
115+
#endif
116+
return;
117+
}
118+
template <>
119+
void FFT_CPU<double>::initfftmode(int fft_mode_in)
120+
{
121+
this->fft_mode = fft_mode_in;
122+
}
123+
124+
template <>
125+
void FFT_CPU<double>::clearfft(fftw_plan& plan)
126+
{
127+
if (plan)
128+
{
129+
fftw_destroy_plan(plan);
130+
plan = NULL;
131+
}
132+
}
133+
134+
template <>
135+
void FFT_CPU<double>::cleanFFT()
136+
{
137+
printf("in the double cleanFFT\n");
138+
clearfft(planzfor);
139+
clearfft(planzbac);
140+
clearfft(planxfor1);
141+
clearfft(planxbac1);
142+
clearfft(planxfor2);
143+
clearfft(planxbac2);
144+
clearfft(planyfor);
145+
clearfft(planybac);
146+
clearfft(planxr2c);
147+
clearfft(planxc2r);
148+
clearfft(planyr2c);
149+
clearfft(planyc2r);
150+
}
151+
152+
153+
template <>
154+
void FFT_CPU<double>::clear()
155+
{
156+
this->cleanFFT();
157+
if (z_auxg != nullptr)
158+
{
159+
fftw_free(z_auxg);
160+
z_auxg = nullptr;
161+
}
162+
if (z_auxr != nullptr)
163+
{
164+
fftw_free(z_auxr);
165+
z_auxr = nullptr;
166+
}
167+
d_rspace = nullptr;
168+
}
169+
170+
template <>
171+
double* FFT_CPU<double>::get_rspace_data() const
172+
{
173+
return d_rspace;
174+
}
175+
template <>
176+
std::complex<double>* FFT_CPU<double>::get_auxr_data() const
177+
{
178+
return z_auxr;
179+
}
180+
template <>
181+
std::complex<double>* FFT_CPU<double>::get_auxg_data() const
182+
{
183+
return z_auxg;
184+
}
185+
template <>
186+
void FFT_CPU<double>::fftxyfor(std::complex<double>* in, std::complex<double>* out) const
187+
{
188+
int npy = this->nplane * this->ny;
189+
if (this->xprime)
190+
{
191+
fftw_execute_dft(this->planxfor1, (fftw_complex*)in, (fftw_complex*)out);
192+
for (int i = 0; i < this->lixy + 1; ++i)
193+
{
194+
fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]);
195+
}
196+
for (int i = rixy; i < this->nx; ++i)
197+
{
198+
fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]);
199+
}
200+
}
201+
else
202+
{
203+
for (int i = 0; i < this->nx; ++i)
204+
{
205+
fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]);
206+
}
207+
208+
fftw_execute_dft(this->planxfor1, (fftw_complex*)in, (fftw_complex*)out);
209+
fftw_execute_dft(this->planxfor2, (fftw_complex*)&in[rixy * nplane], (fftw_complex*)&out[rixy * nplane]);
210+
}
211+
}
212+
template <>
213+
void FFT_CPU<double>::fftxybac(std::complex<double>* in,std::complex<double>* out) const
214+
{
215+
int npy = this->nplane * this->ny;
216+
if (this->xprime)
217+
{
218+
for (int i = 0; i < this->lixy + 1; ++i)
219+
{
220+
fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]);
221+
}
222+
for (int i = rixy; i < this->nx; ++i)
223+
{
224+
fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]);
225+
}
226+
fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)out);
227+
}
228+
else
229+
{
230+
for (int i = 0; i < this->nx; ++i)
231+
{
232+
fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]);
233+
}
234+
235+
fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)out);
236+
fftw_execute_dft(this->planxbac2, (fftw_complex*)&in[rixy * nplane], (fftw_complex*)&out[rixy * nplane]);
237+
}
238+
}
239+
template <>
240+
void FFT_CPU<double>::fftzfor(std::complex<double>* in, std::complex<double>* out) const
241+
{
242+
fftw_execute_dft(this->planzfor, (fftw_complex*)in, (fftw_complex*)out);
243+
}
244+
template <>
245+
void FFT_CPU<double>::fftzbac(std::complex<double>* in, std::complex<double>* out) const
246+
{
247+
fftw_execute_dft(this->planzbac, (fftw_complex*)in, (fftw_complex*)out);
248+
}
249+
template <>
250+
void FFT_CPU<double>::fftxyr2c(double* in, std::complex<double>* out) const
251+
{
252+
int npy = this->nplane * this->ny;
253+
if (this->xprime)
254+
{
255+
fftw_execute_dft_r2c(this->planxr2c, in, (fftw_complex*)out);
256+
257+
for (int i = 0; i < this->lixy + 1; ++i)
258+
{
259+
fftw_execute_dft(this->planyfor, (fftw_complex*)&out[i * npy], (fftw_complex*)&out[i * npy]);
260+
}
261+
}
262+
else
263+
{
264+
for (int i = 0; i < this->nx; ++i)
265+
{
266+
fftw_execute_dft_r2c(this->planyr2c, &in[i * npy], (fftw_complex*)&out[i * npy]);
267+
}
268+
269+
fftw_execute_dft(this->planxfor1, (fftw_complex*)out, (fftw_complex*)out);
270+
}
271+
}
272+
273+
template <>
274+
void FFT_CPU<double>::fftxyc2r(std::complex<double> *in,double *out) const
275+
{
276+
int npy = this->nplane * this->ny;
277+
if (this->xprime)
278+
{
279+
for (int i = 0; i < this->lixy + 1; ++i)
280+
{
281+
fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&in[i * npy]);
282+
}
283+
284+
fftw_execute_dft_c2r(this->planxc2r, (fftw_complex*)in, out);
285+
}
286+
else
287+
{
288+
fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)in);
289+
290+
for (int i = 0; i < this->nx; ++i)
291+
{
292+
fftw_execute_dft_c2r(this->planyc2r, (fftw_complex*)&in[i * npy], &out[i * npy]);
293+
}
294+
}
295+
}
296+
template <>
297+
FFT_CPU<float>::FFT_CPU()
298+
{
299+
300+
}
301+
template <>
302+
FFT_CPU<float>::~FFT_CPU()
303+
{
304+
305+
}
306+
template FFT_CPU<float>::FFT_CPU();
307+
template FFT_CPU<double>::FFT_CPU();

0 commit comments

Comments
 (0)