Skip to content

Commit 2e13c7f

Browse files
committed
update eslover before all runners
1 parent 385b010 commit 2e13c7f

File tree

4 files changed

+58
-55
lines changed

4 files changed

+58
-55
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,7 @@ abacus.json
2424
*.npy
2525
toolchain/install/
2626
toolchain/abacus_env.sh
27+
*.sh
28+
*.pyc
29+
*.txt
30+
*.py

source/module_basis/module_pw/pw_basis.cpp

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,7 @@ PW_Basis::PW_Basis()
1515

1616
PW_Basis::PW_Basis(std::string device_, std::string precision_) : device(std::move(device_)), precision(std::move(precision_)) {
1717
classname="PW_Basis";
18-
std::string fft_precison;
19-
if ((this->precision=="single") || (this->precision=="mixing"))
20-
{
21-
fft_precison = "mixing";
22-
}
23-
else if (this->precision=="double")
24-
{
25-
fft_precison = "double";
26-
}
27-
#if (not defined(__ENABLE_FLOAT_FFTW) and (defined(__CUDA) || defined(__RCOM)))
28-
if (this->device == "gpu")
29-
{
30-
fft_precison = "double";
31-
}
32-
#endif
33-
this->fft_bundle.setfft("cpu",fft_precison);
18+
this->fft_bundle.setfft("cpu",this->precision);
3419
this->double_data_ = (this->precision == "double") || (this->precision == "mixing");
3520
this->float_data_ = (this->precision == "single") || (this->precision == "mixing");
3621
}

source/module_esolver/esolver_fp.cpp

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,43 +23,55 @@ namespace ModuleESolver
2323

2424
ESolver_FP::ESolver_FP()
2525
{
26+
}
27+
28+
ESolver_FP::~ESolver_FP()
29+
{
30+
delete pw_rho;
31+
if ( PARAM.globalv.double_grid)
32+
{
33+
delete pw_rhod;
34+
}
35+
delete this->pelec;
36+
}
37+
38+
void ESolver_FP::before_all_runners(UnitCell& ucell, const Input_para& inp)
39+
{
40+
ModuleBase::TITLE("ESolver_FP", "before_all_runners");
2641
std::string fft_device = PARAM.inp.device;
42+
std::string fft_precison = PARAM.inp.precision;
2743
// LCAO basis doesn't support GPU acceleration on FFT currently
2844
if(PARAM.inp.basis_type == "lcao")
2945
{
3046
fft_device = "cpu";
3147
}
32-
pw_rho = new ModulePW::PW_Basis_Big(fft_device, PARAM.inp.precision);
48+
if ((PARAM.inp.precision=="single") || (PARAM.inp.precision=="mixing"))
49+
{
50+
fft_precison = "mixing";
51+
}
52+
else if (PARAM.inp.precision=="double")
53+
{
54+
fft_precison = "double";
55+
}
56+
#if (not defined(__ENABLE_FLOAT_FFTW) and (defined(__CUDA) || defined(__RCOM)))
57+
if (this->device == "gpu")
58+
{
59+
fft_precison = "double";
60+
}
61+
#endif
62+
pw_rho = new ModulePW::PW_Basis_Big(fft_device, fft_precison);
3363
if (PARAM.globalv.double_grid)
3464
{
35-
pw_rhod = new ModulePW::PW_Basis_Big(fft_device, PARAM.inp.precision);
65+
pw_rhod = new ModulePW::PW_Basis_Big(fft_device, fft_precison);
3666
}
3767
else
3868
{
3969
pw_rhod = pw_rho;
4070
}
41-
42-
// temporary, it will be removed
4371
pw_big = static_cast<ModulePW::PW_Basis_Big*>(pw_rhod);
4472
pw_big->setbxyz(PARAM.inp.bx, PARAM.inp.by, PARAM.inp.bz);
4573
sf.set(pw_rhod, PARAM.inp.nbspline);
4674

47-
}
48-
49-
ESolver_FP::~ESolver_FP()
50-
{
51-
delete pw_rho;
52-
if ( PARAM.globalv.double_grid)
53-
{
54-
delete pw_rhod;
55-
}
56-
delete this->pelec;
57-
}
58-
59-
void ESolver_FP::before_all_runners(UnitCell& ucell, const Input_para& inp)
60-
{
61-
ModuleBase::TITLE("ESolver_FP", "before_all_runners");
62-
6375
//! 1) read pseudopotentials
6476
if (!PARAM.inp.use_paw)
6577
{

source/module_esolver/esolver_ks.cpp

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,27 @@ namespace ModuleESolver
3636
template <typename T, typename Device>
3737
ESolver_KS<T, Device>::ESolver_KS()
3838
{
39+
}
40+
41+
42+
template <typename T, typename Device>
43+
ESolver_KS<T, Device>::~ESolver_KS()
44+
{
45+
delete this->psi;
46+
delete this->pw_wfc;
47+
delete this->p_hamilt;
48+
delete this->p_chgmix;
49+
this->ppcell.release_memory();
50+
}
51+
52+
53+
template <typename T, typename Device>
54+
void ESolver_KS<T, Device>::before_all_runners(UnitCell& ucell, const Input_para& inp)
55+
{
56+
ModuleBase::TITLE("ESolver_KS", "before_all_runners");
57+
//! 1) initialize "before_all_runniers" in ESolver_FP
58+
ESolver_FP::before_all_runners(ucell, inp);
59+
3960
classname = "ESolver_KS";
4061
basisname = "PLEASE ADD BASISNAME FOR CURRENT ESOLVER.";
4162

@@ -75,27 +96,8 @@ ESolver_KS<T, Device>::ESolver_KS()
7596

7697
// cell_factor
7798
this->ppcell.cell_factor = PARAM.inp.cell_factor;
78-
}
79-
80-
81-
template <typename T, typename Device>
82-
ESolver_KS<T, Device>::~ESolver_KS()
83-
{
84-
delete this->psi;
85-
delete this->pw_wfc;
86-
delete this->p_hamilt;
87-
delete this->p_chgmix;
88-
this->ppcell.release_memory();
89-
}
9099

91100

92-
template <typename T, typename Device>
93-
void ESolver_KS<T, Device>::before_all_runners(UnitCell& ucell, const Input_para& inp)
94-
{
95-
ModuleBase::TITLE("ESolver_KS", "before_all_runners");
96-
97-
//! 1) initialize "before_all_runniers" in ESolver_FP
98-
ESolver_FP::before_all_runners(ucell, inp);
99101

100102
/// PAW Section
101103
#ifdef USE_PAW

0 commit comments

Comments
 (0)