Skip to content

Commit 9c1a22d

Browse files
committed
update the file
1 parent 61bb766 commit 9c1a22d

File tree

7 files changed

+515
-198
lines changed

7 files changed

+515
-198
lines changed
Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
#include "fft_base.h"
22
namespace ModulePW
33
{
4-
template <typename FPTYPE>
5-
FFT_BASE<FPTYPE>::FFT_BASE()
6-
{
7-
}
8-
template <typename FPTYPE>
9-
FFT_BASE<FPTYPE>::~FFT_BASE()
10-
{
11-
}
4+
// template <typename FPTYPE>
5+
// FFT_BASE<FPTYPE>::FFT_BASE()
6+
// {
7+
// }
8+
// template <typename FPTYPE>
9+
// FFT_BASE<FPTYPE>::~FFT_BASE()
10+
// {
11+
// }
1212

13-
template FFT_BASE<float>::FFT_BASE();
14-
template FFT_BASE<double>::FFT_BASE();
15-
template FFT_BASE<float>::~FFT_BASE();
16-
template FFT_BASE<double>::~FFT_BASE();
13+
// template FFT_BASE<float>::FFT_BASE();
14+
// template FFT_BASE<double>::FFT_BASE();
15+
// template FFT_BASE<float>::~FFT_BASE();
16+
// template FFT_BASE<double>::~FFT_BASE();
1717
}

source/module_basis/module_pw/module_fft/fft_base.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ class FFT_BASE
1010
{
1111
public:
1212

13-
FFT_BASE();
14-
virtual ~FFT_BASE();
13+
FFT_BASE(){};
14+
virtual ~FFT_BASE(){};
1515

1616
/**
1717
* @brief Initialize the fft parameters As virtual function.
@@ -159,5 +159,9 @@ class FFT_BASE
159159
int ny=0;
160160
int nz=0;
161161
};
162+
template FFT_BASE<float>::FFT_BASE();
163+
template FFT_BASE<double>::FFT_BASE();
164+
template FFT_BASE<float>::~FFT_BASE();
165+
template FFT_BASE<double>::~FFT_BASE();
162166
}
163167
#endif // FFT_BASE_H

source/module_basis/module_pw/module_fft/fft_bundle.cpp

Lines changed: 78 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,29 @@ void FFT_Bundle::initfft(int nx_in,
7979
fft_double = make_unique<FFT_CPU<double>>(this->fft_mode);
8080
if (float_flag)
8181
{
82-
fft_float->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in);
82+
fft_float->initfft(nx_in,
83+
ny_in,
84+
nz_in,
85+
lixy_in,
86+
rixy_in,
87+
ns_in,
88+
nplane_in,
89+
nproc_in,
90+
gamma_only_in,
91+
xprime_in);
8392
}
8493
if (double_flag)
8594
{
86-
fft_double->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in);
95+
fft_double->initfft(nx_in,
96+
ny_in,
97+
nz_in,
98+
lixy_in,
99+
rixy_in,
100+
ns_in,
101+
nplane_in,
102+
nproc_in,
103+
gamma_only_in,
104+
xprime_in);
87105
}
88106
}
89107
if (device=="gpu")
@@ -138,133 +156,134 @@ void FFT_Bundle::clear()
138156
fft_double->clear();
139157
}
140158
}
141-
// access the real space data
142-
template <>
143-
float* FFT_Bundle::get_rspace_data() const
144-
{
145-
return fft_float->get_rspace_data();
146-
}
159+
147160

148161
template <>
149-
double* FFT_Bundle::get_rspace_data() const
150-
{
151-
return fft_double->get_rspace_data();
152-
}
153-
template <>
154-
std::complex<float>* FFT_Bundle::get_auxr_data() const
155-
{
156-
return fft_float->get_auxr_data();
157-
}
158-
template <>
159-
std::complex<double>* FFT_Bundle::get_auxr_data() const
160-
{
161-
return fft_double->get_auxr_data();
162-
}
163-
template <>
164-
std::complex<float>* FFT_Bundle::get_auxg_data() const
165-
{
166-
return fft_float->get_auxg_data();
167-
}
168-
template <>
169-
std::complex<double>* FFT_Bundle::get_auxg_data() const
170-
{
171-
return fft_double->get_auxg_data();
172-
}
173-
template <>
174-
std::complex<float>* FFT_Bundle::get_auxr_3d_data() const
175-
{
176-
return fft_float->get_auxr_3d_data();
177-
}
178-
template <>
179-
std::complex<double>* FFT_Bundle::get_auxr_3d_data() const
180-
{
181-
return fft_double->get_auxr_3d_data();
182-
}
183-
template <>
184-
void FFT_Bundle::fftxyfor(std::complex<float>* in, std::complex<float>* out) const
162+
void FFT_Bundle::fftxyfor(std::complex<float>* in,
163+
std::complex<float>* out) const
185164
{
186165
fft_float->fftxyfor(in,out);
187166
}
188167

189168
template <>
190-
void FFT_Bundle::fftxyfor(std::complex<double>* in, std::complex<double>* out) const
169+
void FFT_Bundle::fftxyfor(std::complex<double>* in,
170+
std::complex<double>* out) const
191171
{
192172
fft_double->fftxyfor(in,out);
193173
}
194174

195175
template <>
196-
void FFT_Bundle::fftzfor(std::complex<float>* in, std::complex<float>* out) const
176+
void FFT_Bundle::fftzfor(std::complex<float>* in,
177+
std::complex<float>* out) const
197178
{
198179
fft_float->fftzfor(in,out);
199180
}
200181
template <>
201-
void FFT_Bundle::fftzfor(std::complex<double>* in, std::complex<double>* out) const
182+
void FFT_Bundle::fftzfor(std::complex<double>* in,
183+
std::complex<double>* out) const
202184
{
203185
fft_double->fftzfor(in,out);
204186
}
205187

206188
template <>
207-
void FFT_Bundle::fftxybac(std::complex<float>* in, std::complex<float>* out) const
189+
void FFT_Bundle::fftxybac(std::complex<float>* in,
190+
std::complex<float>* out) const
208191
{
209192
fft_float->fftxybac(in,out);
210193
}
211194
template <>
212-
void FFT_Bundle::fftxybac(std::complex<double>* in, std::complex<double>* out) const
195+
void FFT_Bundle::fftxybac(std::complex<double>* in,
196+
std::complex<double>* out) const
213197
{
214198
fft_double->fftxybac(in,out);
215199
}
216200

217201
template <>
218-
void FFT_Bundle::fftzbac(std::complex<float>* in, std::complex<float>* out) const
202+
void FFT_Bundle::fftzbac(std::complex<float>* in,
203+
std::complex<float>* out) const
219204
{
220205
fft_float->fftzbac(in,out);
221206
}
222207
template <>
223-
void FFT_Bundle::fftzbac(std::complex<double>* in, std::complex<double>* out) const
208+
void FFT_Bundle::fftzbac(std::complex<double>* in,
209+
std::complex<double>* out) const
224210
{
225211
fft_double->fftzbac(in,out);
226212
}
227213
template <>
228-
void FFT_Bundle::fftxyr2c(float* in, std::complex<float>* out) const
214+
void FFT_Bundle::fftxyr2c(float* in,
215+
std::complex<float>* out) const
229216
{
230217
fft_float->fftxyr2c(in,out);
231218
}
232219
template <>
233-
void FFT_Bundle::fftxyr2c(double* in, std::complex<double>* out) const
220+
void FFT_Bundle::fftxyr2c(double* in,
221+
std::complex<double>* out) const
234222
{
235223
fft_double->fftxyr2c(in,out);
236224
}
237225

238226
template <>
239-
void FFT_Bundle::fftxyc2r(std::complex<float>* in, float* out) const
227+
void FFT_Bundle::fftxyc2r(std::complex<float>* in,
228+
float* out) const
240229
{
241230
fft_float->fftxyc2r(in,out);
242231
}
243232
template <>
244-
void FFT_Bundle::fftxyc2r(std::complex<double>* in, double* out) const
233+
void FFT_Bundle::fftxyc2r(std::complex<double>* in,
234+
double* out) const
245235
{
246236
fft_double->fftxyc2r(in,out);
247237
}
248238

249239
template <>
250-
void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex<float>* in, std::complex<float>* out) const
240+
void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx,
241+
std::complex<float>* in,
242+
std::complex<float>* out) const
251243
{
252244
fft_float->fft3D_forward(in, out);
253245
}
254246

255247
template <>
256-
void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, std::complex<double>* in, std::complex<double>* out) const
248+
void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx,
249+
std::complex<double>* in,
250+
std::complex<double>* out) const
257251
{
258252
fft_double->fft3D_forward(in, out);
259253
}
260254
template <>
261-
void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex<float>* in, std::complex<float>* out) const
255+
void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx,
256+
std::complex<float>* in,
257+
std::complex<float>* out) const
262258
{
263259
fft_float->fft3D_backward(in, out);
264260
}
265261
template <>
266-
void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, std::complex<double>* in, std::complex<double>* out) const
262+
void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx,
263+
std::complex<double>* in,
264+
std::complex<double>* out) const
267265
{
268266
fft_double->fft3D_backward(in, out);
269267
}
268+
269+
// access the real space data
270+
template <> float*
271+
FFT_Bundle::get_rspace_data() const {return fft_float->get_rspace_data();}
272+
template <> double*
273+
FFT_Bundle::get_rspace_data() const {return fft_double->get_rspace_data();}
274+
275+
template <> std::complex<float>*
276+
FFT_Bundle::get_auxr_data() const {return fft_float->get_auxr_data();}
277+
template <> std::complex<double>*
278+
FFT_Bundle::get_auxr_data() const{return fft_double->get_auxr_data();}
279+
280+
template <> std::complex<float>*
281+
FFT_Bundle::get_auxg_data() const{return fft_float->get_auxg_data();}
282+
template <> std::complex<double>*
283+
FFT_Bundle::get_auxg_data() const{return fft_double->get_auxg_data();}
284+
285+
template <> std::complex<float>*
286+
FFT_Bundle::get_auxr_3d_data() const{return fft_float->get_auxr_3d_data();}
287+
template <> std::complex<double>*
288+
FFT_Bundle::get_auxr_3d_data() const {return fft_double->get_auxr_3d_data();}
270289
}

source/module_basis/module_pw/module_fft/fft_bundle.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ class FFT_Bundle
206206
bool float_flag=false;
207207
bool float_define=true;
208208
bool double_flag=false;
209-
std::shared_ptr<FFT_BASE<float>> fft_float=nullptr;
210-
std::shared_ptr<FFT_BASE<double>> fft_double=nullptr;
209+
std::unique_ptr<FFT_BASE<float>> fft_float=nullptr;
210+
std::unique_ptr<FFT_BASE<double>> fft_double=nullptr;
211211

212212
std::string device = "cpu";
213213
std::string precision = "double";

0 commit comments

Comments
 (0)