Skip to content

Commit 965d627

Browse files
committed
add the file of the float_define and the device set
1 parent b40629d commit 965d627

File tree

4 files changed

+12
-12
lines changed

4 files changed

+12
-12
lines changed

source/module_base/module_fft/fft_temp.cpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,26 +65,24 @@ void FFT_TEMP::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in
6565
fft_float = new FFT_CPU<float>();
6666
fft_double = new FFT_CPU<double>();
6767
}
68-
else if (device=="gpu")
69-
{
70-
// #if defined(__ROCM)
71-
// fft_float = new FFT_RCOM<float>();
72-
// fft_double = new FFT_RCOM<double>();
73-
// #elif defined(__CUDA)
74-
// fft_float = new FFT_CUDA<float>();
75-
// fft_double = new FFT_CUDA<double>();
76-
// #endif
77-
}
68+
7869
if (this->precision=="single")
7970
{
8071
float_flag = true;
72+
#ifdef __ENABLE_FLOAT_FFTW
73+
float_define = true;
74+
#endif
75+
float_flag = float_define & float_flag;
8176
double_flag = true;
77+
78+
8279
}
8380
else if (this->precision=="double")
8481
{
8582
double_flag = true;
8683
}
87-
if (float_flag)
84+
85+
if (float_flag && float_define)
8886
{
8987
fft_float->initfftmode(this->fft_mode);
9088
fft_float->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in);

source/module_base/module_fft/fft_temp.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class FFT_TEMP
5555
private:
5656
int fft_mode = 0; ///< fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive
5757
bool float_flag=false;
58+
bool float_define=false;
5859
bool double_flag=false;
5960
FFT_BASE<float>* fft_float=nullptr;
6061
FFT_BASE<double>* fft_double=nullptr;

source/module_basis/module_pw/pw_basis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ PW_Basis::PW_Basis(std::string device_, std::string precision_) : device(std::mo
1717
classname="PW_Basis";
1818
this->ft.set_device(this->device);
1919
this->ft.set_precision(this->precision);
20-
this->ft1.setfft(this->device,this->precision);
20+
this->ft1.setfft("cpu",this->precision);
2121
}
2222

2323
PW_Basis:: ~PW_Basis()

source/module_basis/module_pw/pw_basis_k.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ namespace ModulePW
1212
PW_Basis_K::PW_Basis_K()
1313
{
1414
classname="PW_Basis_K";
15+
this->ft1.setfft("cpu",this->precision);
1516
}
1617
PW_Basis_K::~PW_Basis_K()
1718
{

0 commit comments

Comments
 (0)