1+ #include < cassert>
2+ #include " fft_temp.h"
3+ // #include "fft_cpu.h"
4+ #if defined(__CUDA)
5+ #include " fft_cuda.h"
6+ #endif
7+ #if defined(__ROCM)
8+ #include " fft_rcom.h"
9+ #endif
10+ #include " module_base/module_device/device.h"
11+ // #include "fft_gpu.h"
12+ FFT1::FFT1 ()
13+ {
14+ fft_float = nullptr ;
15+ fft_double = nullptr ;
16+ }
17+ FFT1::FFT1 (std::string device_in,std::string precision_in)
18+ {
19+ assert (device_in==" cpu" || device_in==" gpu" );
20+ assert (precision_in==" single" || precision_in==" double" || precision_in==" mixing" );
21+ this ->device = device_in;
22+ this ->precision = precision_in;
23+ if (device==" cpu" )
24+ {
25+ fft_float = new FFT_CPU<float >();
26+ fft_double = new FFT_CPU<double >();
27+ }
28+ else if (device==" gpu" )
29+ {
30+ #if defined(__ROCM)
31+ fft_float = new FFT_RCOM<float >();
32+ fft_double = new FFT_RCOM<double >();
33+ #elif defined(__CUDA)
34+ fft_float = new FFT_CUDA<float >();
35+ fft_double = new FFT_CUDA<double >();
36+ #endif
37+ }
38+ }
39+
40+ FFT1::~FFT1 ()
41+ {
42+ if (fft_float!=nullptr )
43+ {
44+ delete fft_float;
45+ fft_float=nullptr ;
46+ }
47+ if (fft_double!=nullptr )
48+ {
49+ delete fft_double;
50+ fft_double=nullptr ;
51+ }
52+ }
53+
54+ void FFT1::set_device (std::string device_in)
55+ {
56+ this ->device = device_in;
57+ }
58+
59+ void FFT1::set_precision (std::string precision_in)
60+ {
61+ this ->precision = precision_in;
62+ }
63+ void FFT1::setfft (std::string device_in,std::string precision_in)
64+ {
65+ assert (device_in==" cpu" || device_in==" gpu" );
66+ assert (precision_in==" single" || precision_in==" double" || precision_in==" mixing" );
67+ this ->device = device_in;
68+ this ->precision = precision_in;
69+ if (device==" cpu" )
70+ {
71+ fft_float = new FFT_CPU<float >();
72+ fft_double = new FFT_CPU<double >();
73+ }
74+ else if (device==" gpu" )
75+ {
76+ #if defined(__ROCM)
77+ fft_float = new FFT_RCOM<float >();
78+ fft_double = new FFT_RCOM<double >();
79+ #elif defined(__CUDA)
80+ fft_float = new FFT_CUDA<float >();
81+ fft_double = new FFT_CUDA<double >();
82+ #endif
83+ }
84+ }
85+ void FFT1::initfft (int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in,
86+ int nproc_in, bool gamma_only_in, bool xprime_in , bool mpifft_in)
87+ {
88+ if (this ->precision ==" single" )
89+ {
90+ float_flag = 1 ;
91+ }
92+ else if (this ->precision ==" double" )
93+ {
94+ double_flag = 1 ;
95+ }
96+ else if (this ->precision ==" mixing" )
97+ {
98+ float_flag = 1 ;
99+ double_flag = 1 ;
100+ }
101+ if (float_flag)
102+ {
103+ fft_float->initfftmode (this ->fft_mode );
104+ fft_float->initfft (nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in);
105+ }
106+ if (double_flag)
107+ {
108+ fft_double->initfftmode (this ->fft_mode );
109+ fft_double->initfft (nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in);
110+ }
111+ }
112+ void FFT1::initfftmode (int fft_mode_in)
113+ {
114+ this ->fft_mode = fft_mode_in;
115+ }
116+
117+ void FFT1::setupFFT ()
118+ {
119+ if (double_flag)
120+ {
121+ fft_double->setupFFT ();
122+ }
123+ if (float_flag)
124+ {
125+ fft_float->setupFFT ();
126+ }
127+ }
128+
129+ void FFT1::clearFFT ()
130+ {
131+ if (double_flag)
132+ {
133+ fft_double->cleanFFT ();
134+ }
135+ if (float_flag)
136+ {
137+ fft_float->cleanFFT ();
138+ }
139+ }
140+ void FFT1::clear ()
141+ {
142+ this ->clearFFT ();
143+ if (float_flag)
144+ {
145+ fft_float->clear ();
146+ }
147+ if (double_flag)
148+ {
149+ fft_double->clear ();
150+ }
151+ }
152+ // access the real space data
153+ template <>
154+ float * FFT1::get_rspace_data () const
155+ {
156+ return fft_float->get_rspace_data ();
157+ }
158+
159+ template <>
160+ double * FFT1::get_rspace_data () const
161+ {
162+ return fft_double->get_rspace_data ();
163+ }
164+ template <>
165+ std::complex <float >* FFT1::get_auxr_data () const
166+ {
167+ return fft_float->get_auxr_data ();
168+ }
169+ template <>
170+ std::complex <double >* FFT1::get_auxr_data () const
171+ {
172+ return fft_double->get_auxr_data ();
173+ }
174+ template <>
175+ std::complex <float >* FFT1::get_auxg_data () const
176+ {
177+ return fft_float->get_auxg_data ();
178+ }
179+ template <>
180+ std::complex <double >* FFT1::get_auxg_data () const
181+ {
182+ return fft_double->get_auxg_data ();
183+ }
184+ template <>
185+ std::complex <float >* FFT1::get_auxr_3d_data () const
186+ {
187+ return fft_float->get_auxr_3d_data ();
188+ }
189+ template <>
190+ std::complex <double >* FFT1::get_auxr_3d_data () const
191+ {
192+ return fft_double->get_auxr_3d_data ();
193+ }
194+ template <>
195+ void FFT1::fftxyfor (std::complex <float >* in, std::complex <float >* out) const
196+ {
197+ fft_float->fftxyfor (in,out);
198+ }
199+
200+ template <>
201+ void FFT1::fftxyfor (std::complex <double >* in, std::complex <double >* out) const
202+ {
203+ fft_double->fftxyfor (in,out);
204+ }
205+
206+ template <>
207+ void FFT1::fftzfor (std::complex <float >* in, std::complex <float >* out) const
208+ {
209+ fft_float->fftzfor (in,out);
210+ }
211+ template <>
212+ void FFT1::fftzfor (std::complex <double >* in, std::complex <double >* out) const
213+ {
214+ fft_double->fftzfor (in,out);
215+ }
216+
217+ template <>
218+ void FFT1::fftxybac (std::complex <float >* in, std::complex <float >* out) const
219+ {
220+ fft_float->fftxybac (in,out);
221+ }
222+ template <>
223+ void FFT1::fftxybac (std::complex <double >* in, std::complex <double >* out) const
224+ {
225+ fft_double->fftxybac (in,out);
226+ }
227+
228+ template <>
229+ void FFT1::fftzbac (std::complex <float >* in, std::complex <float >* out) const
230+ {
231+ fft_float->fftzbac (in,out);
232+ }
233+ template <>
234+ void FFT1::fftzbac (std::complex <double >* in, std::complex <double >* out) const
235+ {
236+ fft_double->fftzbac (in,out);
237+ }
238+ template <>
239+ void FFT1::fftxyr2c (float * in, std::complex <float >* out) const
240+ {
241+ fft_float->fftxyr2c (in,out);
242+ }
243+ template <>
244+ void FFT1::fftxyr2c (double * in, std::complex <double >* out) const
245+ {
246+ fft_double->fftxyr2c (in,out);
247+ }
248+
249+ template <>
250+ void FFT1::fftxyc2r (std::complex <float >* in, float * out) const
251+ {
252+ fft_float->fftxyc2r (in,out);
253+ }
254+ template <>
255+ void FFT1::fftxyc2r (std::complex <double >* in, double * out) const
256+ {
257+ fft_double->fftxyc2r (in,out);
258+ }
259+
260+ template <>
261+ void FFT1::fft3D_forward (const base_device::DEVICE_GPU* ctx, std::complex <float >* in, std::complex <float >* out) const
262+ {
263+ fft_float->fft3D_forward (in, out);
264+ }
265+
266+ template <>
267+ void FFT1::fft3D_forward (const base_device::DEVICE_GPU* ctx, std::complex <double >* in, std::complex <double >* out) const
268+ {
269+ fft_double->fft3D_forward (in, out);
270+ }
271+ template <>
272+ void FFT1::fft3D_backward (const base_device::DEVICE_GPU* ctx, std::complex <float >* in, std::complex <float >* out) const
273+ {
274+ fft_float->fft3D_backward (in, out);
275+ }
276+ template <>
277+ void FFT1::fft3D_backward (const base_device::DEVICE_GPU* ctx, std::complex <double >* in, std::complex <double >* out) const
278+ {
279+ fft_double->fft3D_backward (in, out);
280+ }
0 commit comments