Skip to content

Commit a3d4613

Browse files
committed
update the FFT
1 parent b1c7538 commit a3d4613

File tree

9 files changed

+16
-43
lines changed

9 files changed

+16
-43
lines changed

source/module_basis/module_pw/pw_basis.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ PW_Basis::PW_Basis()
1515

1616
PW_Basis::PW_Basis(std::string device_, std::string precision_) : device(std::move(device_)), precision(std::move(precision_)) {
1717
classname="PW_Basis";
18-
this->ft.set_device(this->device);
19-
this->ft.set_precision(this->precision);
2018
this->fft_bundle.setfft("cpu",this->precision);
2119
}
2220

@@ -57,19 +55,15 @@ void PW_Basis::setuptransform()
5755
this->distribute_r();
5856
this->distribute_g();
5957
this->getstartgr();
60-
this->ft.clear();
6158
this->fft_bundle.clear();
6259
if(this->xprime)
6360
{
64-
this->ft.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime);
6561
this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime);
6662
}
6763
else
6864
{
69-
this->ft.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime);
7065
this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime);
7166
}
72-
this->ft.setupFFT();
7367
this->fft_bundle.setupFFT();
7468
ModuleBase::timer::tick(this->classname, "setuptransform");
7569
}

source/module_basis/module_pw/pw_basis.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ class PW_Basis
242242
int ng_xeq0 = 0; //only used when xprime = true, number of g whose gx = 0
243243
int nmaxgr=0; // Gamma_only: max between npw and (nrxx+1)/2, others: max between npw and nrxx
244244
// Thus complex<double>[nmaxgr] is able to contain either reciprocal or real data
245-
FFT ft;
245+
// FFT ft;
246246
FFT_Bundle fft_bundle;
247247
//The position of pointer in and out can be equal(in-place transform) or different(out-of-place transform).
248248

source/module_basis/module_pw/pw_basis_k.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace ModulePW
1212
PW_Basis_K::PW_Basis_K()
1313
{
1414
classname="PW_Basis_K";
15-
this->fft_bundle.setfft("cpu",this->precision);
15+
this->fft_bundle.setfft(this->device,this->precision);
1616
}
1717
PW_Basis_K::~PW_Basis_K()
1818
{
@@ -184,16 +184,13 @@ void PW_Basis_K::setuptransform()
184184
this->distribute_g();
185185
this->getstartgr();
186186
this->setupIndGk();
187-
this->ft.clear();
188187
this->fft_bundle.clear();
188+
this->fft_bundle.setfft(this->device,this->precision);
189189
if(this->xprime){
190-
this->ft.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime);
191190
this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime);
192191
}else{
193-
this->ft.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime);
194192
this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime);
195193
}
196-
this->ft.setupFFT();
197194
this->fft_bundle.setupFFT();
198195
ModuleBase::timer::tick(this->classname, "setuptransform");
199196
}

source/module_basis/module_pw/pw_basis_sup.cpp

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,9 @@ void PW_Basis_Sup::setuptransform(const ModulePW::PW_Basis* pw_rho)
1919
this->distribute_r();
2020
this->distribute_g(pw_rho);
2121
this->getstartgr();
22-
this->ft.clear();
2322
this->fft_bundle.clear();
2423
if (this->xprime)
2524
{
26-
this->ft.initfft(this->nx,
27-
this->ny,
28-
this->nz,
29-
this->lix,
30-
this->rix,
31-
this->nst,
32-
this->nplane,
33-
this->poolnproc,
34-
this->gamma_only,
35-
this->xprime);
3625
this->fft_bundle.initfft(this->nx,
3726
this->ny,
3827
this->nz,
@@ -46,16 +35,6 @@ void PW_Basis_Sup::setuptransform(const ModulePW::PW_Basis* pw_rho)
4635
}
4736
else
4837
{
49-
this->ft.initfft(this->nx,
50-
this->ny,
51-
this->nz,
52-
this->liy,
53-
this->riy,
54-
this->nst,
55-
this->nplane,
56-
this->poolnproc,
57-
this->gamma_only,
58-
this->xprime);
5938
this->fft_bundle.initfft(this->nx,
6039
this->ny,
6140
this->nz,
@@ -67,7 +46,6 @@ void PW_Basis_Sup::setuptransform(const ModulePW::PW_Basis* pw_rho)
6746
this->gamma_only,
6847
this->xprime);
6948
}
70-
this->ft.setupFFT();
7149
this->fft_bundle.setupFFT();
7250
ModuleBase::timer::tick(this->classname, "setuptransform");
7351
}

source/module_basis/module_pw/test_serial/pw_basis_k_test.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@ TEST_F(PWBasisKTEST,Constructor)
4646
EXPECT_EQ(basis_k1.classname,"PW_Basis_K");
4747
EXPECT_EQ(basis_k2.classname,"PW_Basis_K");
4848
EXPECT_EQ(basis_k2.device,"cpu");
49-
EXPECT_EQ(basis_k2.ft.device,"cpu");
49+
EXPECT_EQ(basis_k2.fft_bundle.device,"cpu");
5050
EXPECT_EQ(basis_k2.precision,"double");
51-
EXPECT_EQ(basis_k2.ft.precision,"double");
51+
EXPECT_EQ(basis_k2.fft_bundle.precision,"double");
5252
ModulePW::PW_Basis_K basis_k3(device_flag, precision_single);
53-
EXPECT_EQ(basis_k3.ft.precision,"single");
53+
EXPECT_EQ(basis_k3.fft_bundle.precision,"single");
5454
}
5555

5656
TEST_F(PWBasisKTEST,Initgrids1)

source/module_basis/module_pw/test_serial/pw_basis_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ TEST_F(PWBasisTEST,Constructor)
5858
EXPECT_EQ(pwb2.classname,"PW_Basis");
5959
EXPECT_EQ(pwb2.device,"cpu");
6060
EXPECT_EQ(pwb2.precision,"double");
61-
EXPECT_EQ(pwb2.ft.device,"cpu");
62-
EXPECT_EQ(pwb2.ft.precision,"double");
61+
EXPECT_EQ(pwb2.fft_bundle.device,"cpu");
62+
EXPECT_EQ(pwb2.fft_bundle.precision,"double");
6363
}
6464

6565
TEST_F(PWBasisTEST,Initgrids1)

source/module_esolver/esolver_fp.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ void ESolver_FP::before_all_runners(const Input_para& inp, UnitCell& cell)
8383
}
8484

8585
this->pw_rho->initparameters(false, 4.0 * inp.ecutwfc);
86-
this->pw_rho->ft.fft_mode = inp.fft_mode;
8786
this->pw_rho->fft_bundle.initfftmode(inp.fft_mode);
8887
this->pw_rho->setuptransform();
8988
this->pw_rho->collect_local_pw();
@@ -109,7 +108,6 @@ void ESolver_FP::before_all_runners(const Input_para& inp, UnitCell& cell)
109108
this->pw_rhod->initgrids(inp.ref_cell_factor * cell.lat0, cell.latvec, inp.ndx, inp.ndy, inp.ndz);
110109
}
111110
this->pw_rhod->initparameters(false, inp.ecutrho);
112-
this->pw_rhod->ft.fft_mode = inp.fft_mode;
113111
this->pw_rhod->fft_bundle.initfftmode(inp.fft_mode);
114112
pw_rhod_sup->setuptransform(this->pw_rho);
115113
this->pw_rhod->collect_local_pw();

source/module_esolver/esolver_ks.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,6 @@ void ESolver_KS<T, Device>::before_all_runners(const Input_para& inp, UnitCell&
246246
// results
247247
#endif
248248

249-
this->pw_wfc->ft.fft_mode = inp.fft_mode;
250249
this->pw_wfc->fft_bundle.initfftmode(inp.fft_mode);
251250
this->pw_wfc->setuptransform();
252251

source/module_hamilt_general/module_xc/test/CMakeLists.txt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@ AddTest(
2121
../xc_functional_libxc_wrapper_gcxc.cpp ../xc_functional_libxc_wrapper_xc.cpp ../xc_functional_libxc.cpp
2222
)
2323

24-
24+
if (USE_CUDA)
25+
list(APPEND FFT_SRC ../../../module_basis/module_pw/module_fft/fft_cuda.cpp)
26+
endif()
27+
if (USE_ROCM)
28+
list(APPEND FFT_SRC ../../../module_basis/module_pw/module_fft/fft_rocm.cpp)
29+
endif()
2530
AddTest(
2631
TARGET XCTest_GRADCORR
2732
LIBS parameter MPI::MPI_CXX Libxc::xc ${math_libs} psi device container
@@ -41,6 +46,7 @@ AddTest(
4146
../../../module_basis/module_pw/module_fft/fft_base.cpp
4247
../../../module_basis/module_pw/module_fft/fft_bundle.cpp
4348
../../../module_basis/module_pw/module_fft/fft_cpu.cpp
49+
${FFT_SRC}
4450
)
4551

4652
AddTest(
@@ -79,4 +85,5 @@ AddTest(
7985
../../../module_basis/module_pw/module_fft/fft_base.cpp
8086
../../../module_basis/module_pw/module_fft/fft_bundle.cpp
8187
../../../module_basis/module_pw/module_fft/fft_cpu.cpp
88+
${FFT_SRC}
8289
)

0 commit comments

Comments
 (0)