Skip to content

Commit c73bd57

Browse files
committed
add the basic func of the file
1 parent edfa4c1 commit c73bd57

File tree

8 files changed

+753
-0
lines changed

8 files changed

+753
-0
lines changed
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
#include <cassert>
2+
#include "fft_temp.h"
3+
// #include "fft_cpu.h"
4+
#if defined(__CUDA)
5+
#include "fft_cuda.h"
6+
#endif
7+
#if defined(__ROCM)
8+
#include "fft_rcom.h"
9+
#endif
10+
#include "module_base/module_device/device.h"
11+
// #include "fft_gpu.h"
12+
// Default constructor: no backend is allocated; both backend pointers start
// out null so the destructor has nothing to release.
FFT1::FFT1()
{
    this->fft_double = nullptr;
    this->fft_float = nullptr;
}
17+
// Constructs an FFT1 for the given device and precision.
//
// device_in    : "cpu" or "gpu" (checked by assert only).
// precision_in : "single", "double" or "mixing" (checked by assert only).
//
// A float backend and a double backend are always allocated as a pair:
// FFT_CPU for "cpu"; FFT_RCOM (if __ROCM is defined) or FFT_CUDA (if __CUDA is
// defined) for "gpu".
FFT1::FFT1(std::string device_in,std::string precision_in)
{
    assert(device_in=="cpu" || device_in=="gpu");
    assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing");

    // Null the backend pointers before any branch. Without this they stay
    // indeterminate when no backend is created below ("gpu" requested in a
    // build with neither __ROCM nor __CUDA, or an unrecognised device when
    // asserts are compiled out), and the destructor would delete garbage.
    fft_float = nullptr;
    fft_double = nullptr;

    this->device = device_in;
    this->precision = precision_in;
    if (device=="cpu")
    {
        fft_float = new FFT_CPU<float>();
        fft_double = new FFT_CPU<double>();
    }
    else if (device=="gpu")
    {
#if defined(__ROCM)
        fft_float = new FFT_RCOM<float>();
        fft_double = new FFT_RCOM<double>();
#elif defined(__CUDA)
        fft_float = new FFT_CUDA<float>();
        fft_double = new FFT_CUDA<double>();
#endif
    }
}
39+
40+
// Releases both backends. `delete` on a null pointer is a no-op, so no
// explicit null checks are needed; the pointers are reset afterwards.
FFT1::~FFT1()
{
    delete this->fft_float;
    this->fft_float = nullptr;

    delete this->fft_double;
    this->fft_double = nullptr;
}
53+
54+
void FFT1::set_device(std::string device_in)
55+
{
56+
this->device = device_in;
57+
}
58+
59+
void FFT1::set_precision(std::string precision_in)
60+
{
61+
this->precision = precision_in;
62+
}
63+
void FFT1::setfft(std::string device_in,std::string precision_in)
64+
{
65+
assert(device_in=="cpu" || device_in=="gpu");
66+
assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing");
67+
this->device = device_in;
68+
this->precision = precision_in;
69+
if (device=="cpu")
70+
{
71+
fft_float = new FFT_CPU<float>();
72+
fft_double = new FFT_CPU<double>();
73+
}
74+
else if (device=="gpu")
75+
{
76+
#if defined(__ROCM)
77+
fft_float = new FFT_RCOM<float>();
78+
fft_double = new FFT_RCOM<double>();
79+
#elif defined(__CUDA)
80+
fft_float = new FFT_CUDA<float>();
81+
fft_double = new FFT_CUDA<double>();
82+
#endif
83+
}
84+
}
85+
void FFT1::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in,
86+
int nproc_in, bool gamma_only_in, bool xprime_in , bool mpifft_in)
87+
{
88+
if (this->precision=="single")
89+
{
90+
float_flag = 1;
91+
}
92+
else if (this->precision=="double")
93+
{
94+
double_flag = 1;
95+
}
96+
else if (this->precision=="mixing")
97+
{
98+
float_flag = 1;
99+
double_flag = 1;
100+
}
101+
if (float_flag)
102+
{
103+
fft_float->initfftmode(this->fft_mode);
104+
fft_float->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in);
105+
}
106+
if (double_flag)
107+
{
108+
fft_double->initfftmode(this->fft_mode);
109+
fft_double->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in);
110+
}
111+
}
112+
void FFT1::initfftmode(int fft_mode_in)
113+
{
114+
this->fft_mode = fft_mode_in;
115+
}
116+
117+
void FFT1::setupFFT()
118+
{
119+
if (double_flag)
120+
{
121+
fft_double->setupFFT();
122+
}
123+
if (float_flag)
124+
{
125+
fft_float->setupFFT();
126+
}
127+
}
128+
129+
void FFT1::clearFFT()
130+
{
131+
if (double_flag)
132+
{
133+
fft_double->cleanFFT();
134+
}
135+
if (float_flag)
136+
{
137+
fft_float->cleanFFT();
138+
}
139+
}
140+
void FFT1::clear()
141+
{
142+
this->clearFFT();
143+
if (float_flag)
144+
{
145+
fft_float->clear();
146+
}
147+
if (double_flag)
148+
{
149+
fft_double->clear();
150+
}
151+
}
152+
// access the real space data
153+
template <>
154+
float* FFT1::get_rspace_data() const
155+
{
156+
return fft_float->get_rspace_data();
157+
}
158+
159+
template <>
160+
double* FFT1::get_rspace_data() const
161+
{
162+
return fft_double->get_rspace_data();
163+
}
164+
template <>
165+
std::complex<float>* FFT1::get_auxr_data() const
166+
{
167+
return fft_float->get_auxr_data();
168+
}
169+
template <>
170+
std::complex<double>* FFT1::get_auxr_data() const
171+
{
172+
return fft_double->get_auxr_data();
173+
}
174+
template <>
175+
std::complex<float>* FFT1::get_auxg_data() const
176+
{
177+
return fft_float->get_auxg_data();
178+
}
179+
template <>
180+
std::complex<double>* FFT1::get_auxg_data() const
181+
{
182+
return fft_double->get_auxg_data();
183+
}
184+
template <>
185+
std::complex<float>* FFT1::get_auxr_3d_data() const
186+
{
187+
return fft_float->get_auxr_3d_data();
188+
}
189+
template <>
190+
std::complex<double>* FFT1::get_auxr_3d_data() const
191+
{
192+
return fft_double->get_auxr_3d_data();
193+
}
194+
template <>
195+
void FFT1::fftxyfor(std::complex<float>* in, std::complex<float>* out) const
196+
{
197+
fft_float->fftxyfor(in,out);
198+
}
199+
200+
template <>
201+
void FFT1::fftxyfor(std::complex<double>* in, std::complex<double>* out) const
202+
{
203+
fft_double->fftxyfor(in,out);
204+
}
205+
206+
template <>
207+
void FFT1::fftzfor(std::complex<float>* in, std::complex<float>* out) const
208+
{
209+
fft_float->fftzfor(in,out);
210+
}
211+
template <>
212+
void FFT1::fftzfor(std::complex<double>* in, std::complex<double>* out) const
213+
{
214+
fft_double->fftzfor(in,out);
215+
}
216+
217+
template <>
218+
void FFT1::fftxybac(std::complex<float>* in, std::complex<float>* out) const
219+
{
220+
fft_float->fftxybac(in,out);
221+
}
222+
template <>
223+
void FFT1::fftxybac(std::complex<double>* in, std::complex<double>* out) const
224+
{
225+
fft_double->fftxybac(in,out);
226+
}
227+
228+
template <>
229+
void FFT1::fftzbac(std::complex<float>* in, std::complex<float>* out) const
230+
{
231+
fft_float->fftzbac(in,out);
232+
}
233+
template <>
234+
void FFT1::fftzbac(std::complex<double>* in, std::complex<double>* out) const
235+
{
236+
fft_double->fftzbac(in,out);
237+
}
238+
template <>
239+
void FFT1::fftxyr2c(float* in, std::complex<float>* out) const
240+
{
241+
fft_float->fftxyr2c(in,out);
242+
}
243+
template <>
244+
void FFT1::fftxyr2c(double* in, std::complex<double>* out) const
245+
{
246+
fft_double->fftxyr2c(in,out);
247+
}
248+
249+
template <>
250+
void FFT1::fftxyc2r(std::complex<float>* in, float* out) const
251+
{
252+
fft_float->fftxyc2r(in,out);
253+
}
254+
template <>
255+
void FFT1::fftxyc2r(std::complex<double>* in, double* out) const
256+
{
257+
fft_double->fftxyc2r(in,out);
258+
}
259+
260+
template <>
261+
void FFT1::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex<float>* in, std::complex<float>* out) const
262+
{
263+
fft_float->fft3D_forward(in, out);
264+
}
265+
266+
template <>
267+
void FFT1::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex<double>* in, std::complex<double>* out) const
268+
{
269+
fft_double->fft3D_forward(in, out);
270+
}
271+
template <>
272+
void FFT1::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex<float>* in, std::complex<float>* out) const
273+
{
274+
fft_float->fft3D_backward(in, out);
275+
}
276+
template <>
277+
void FFT1::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex<double>* in, std::complex<double>* out) const
278+
{
279+
fft_double->fft3D_backward(in, out);
280+
}

source/module_basis/module_pw/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
list(APPEND objects
22
fft.cpp
3+
fft_base.cpp
4+
fft_temp.cpp
35
pw_basis.cpp
46
pw_basis_k.cpp
57
pw_basis_sup.cpp
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#include "fft_base.h"
2+
template <typename FPTYPE>
3+
FFT_BASE<FPTYPE>::FFT_BASE()
4+
{
5+
}
6+
template <typename FPTYPE>
7+
FFT_BASE<FPTYPE>::~FFT_BASE()
8+
{
9+
10+
}
11+
template <typename FPTYPE>
12+
void FFT_BASE<FPTYPE>::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in,
13+
int nproc_in, bool gamma_only_in, bool xprime_in, bool mpifft_in)
14+
{
15+
this->gamma_only = gamma_only_in;
16+
this->xprime = xprime_in;
17+
this->fftnx = this->nx = nx_in;
18+
this->fftny = this->ny = ny_in;
19+
if (this->gamma_only)
20+
{
21+
if (xprime)
22+
this->fftnx = int(nx / 2) + 1;
23+
else
24+
this->fftny = int(ny / 2) + 1;
25+
}
26+
this->nz = nz_in;
27+
this->ns = ns_in;
28+
this->lixy = lixy_in;
29+
this->rixy = rixy_in;
30+
this->nplane = nplane_in;
31+
this->nproc = nproc_in;
32+
this->mpifft = mpifft_in;
33+
this->nxy = this->nx * this->ny;
34+
this->fftnxy = this->fftnx * this->fftny;
35+
const int nrxx = this->nxy * this->nplane;
36+
const int nsz = this->nz * this->ns;
37+
this->maxgrids = (nsz > nrxx) ? nsz : nrxx;
38+
}
39+
40+
// Explicit instantiations of the constructor and destructor for the two
// supported precisions.
// NOTE(review): initfft() is not explicitly instantiated here; whether that
// is required depends on its declaration in fft_base.h — confirm.
template FFT_BASE<float>::FFT_BASE();
template FFT_BASE<float>::~FFT_BASE();
template FFT_BASE<double>::FFT_BASE();
template FFT_BASE<double>::~FFT_BASE();

0 commit comments

Comments
 (0)