Skip to content

Commit 3b096c6

Browse files
committed
add template <Device> for sdft
1 parent d7d52c5 commit 3b096c6

File tree

8 files changed

+212
-192
lines changed

8 files changed

+212
-192
lines changed
Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,46 @@
11
#include "./elecstate_pw_sdft.h"
2+
3+
#include "module_base/global_function.h"
24
#include "module_base/global_variable.h"
3-
#include "module_parameter/parameter.h"
45
#include "module_base/timer.h"
5-
#include "module_base/global_function.h"
66
#include "module_hamilt_general/module_xc/xc_functional.h"
7+
#include "module_parameter/parameter.h"
78
namespace elecstate
89
{
9-
void ElecStatePW_SDFT::psiToRho(const psi::Psi<std::complex<double>>& psi)
10+
11+
template <typename Device>
12+
void ElecStatePW_SDFT<Device>::psiToRho(const psi::Psi<std::complex<double>>& psi)
13+
{
14+
ModuleBase::TITLE(this->classname, "psiToRho");
15+
ModuleBase::timer::tick(this->classname, "psiToRho");
16+
for (int is = 0; is < PARAM.inp.nspin; is++)
1017
{
11-
ModuleBase::TITLE(this->classname, "psiToRho");
12-
ModuleBase::timer::tick(this->classname, "psiToRho");
13-
for(int is=0; is < PARAM.inp.nspin; is++)
14-
{
15-
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx);
16-
if (XC_Functional::get_func_type() == 3)
17-
{
18-
ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx);
19-
}
20-
}
21-
22-
if(GlobalV::MY_STOGROUP == 0)
23-
{
24-
this->calEBand();
18+
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx);
19+
if (XC_Functional::get_func_type() == 3)
20+
{
21+
ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx);
22+
}
23+
}
24+
25+
if (GlobalV::MY_STOGROUP == 0)
26+
{
27+
this->calEBand();
2528

26-
for(int is=0; is<PARAM.inp.nspin; is++)
27-
{
28-
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx);
29-
}
29+
for (int is = 0; is < PARAM.inp.nspin; is++)
30+
{
31+
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx);
32+
}
3033

31-
for (int ik = 0; ik < psi.get_nk(); ++ik)
32-
{
33-
psi.fix_k(ik);
34-
this->updateRhoK(psi);
35-
}
36-
this->parallelK();
34+
for (int ik = 0; ik < psi.get_nk(); ++ik)
35+
{
36+
psi.fix_k(ik);
37+
this->updateRhoK(psi);
3738
}
38-
ModuleBase::timer::tick(this->classname, "psiToRho");
39-
return;
39+
this->parallelK();
4040
}
41-
}
41+
ModuleBase::timer::tick(this->classname, "psiToRho");
42+
return;
43+
}
44+
45+
template class ElecStatePW_SDFT<base_device::DEVICE_CPU>;
46+
} // namespace elecstate

source/module_elecstate/elecstate_pw_sdft.h

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,24 @@
33
#include "elecstate_pw.h"
44
namespace elecstate
55
{
6-
class ElecStatePW_SDFT : public ElecStatePW<std::complex<double>>
6+
template <typename Device>
7+
class ElecStatePW_SDFT : public ElecStatePW<std::complex<double>, Device>
8+
{
9+
public:
10+
ElecStatePW_SDFT(ModulePW::PW_Basis_K* wfc_basis_in,
11+
Charge* chg_in,
12+
K_Vectors* pkv_in,
13+
UnitCell* ucell_in,
14+
pseudopot_cell_vnl* ppcell_in,
15+
ModulePW::PW_Basis* rhodpw_in,
16+
ModulePW::PW_Basis* rhopw_in,
17+
ModulePW::PW_Basis_Big* bigpw_in)
18+
: ElecStatePW<std::complex<double>,
19+
Device>(wfc_basis_in, chg_in, pkv_in, ucell_in, ppcell_in, rhodpw_in, rhopw_in, bigpw_in)
720
{
8-
public:
9-
ElecStatePW_SDFT(ModulePW::PW_Basis_K* wfc_basis_in,
10-
Charge* chg_in,
11-
K_Vectors* pkv_in,
12-
UnitCell* ucell_in,
13-
pseudopot_cell_vnl* ppcell_in,
14-
ModulePW::PW_Basis* rhodpw_in,
15-
ModulePW::PW_Basis* rhopw_in,
16-
ModulePW::PW_Basis_Big* bigpw_in)
17-
: ElecStatePW(wfc_basis_in, chg_in, pkv_in, ucell_in, ppcell_in, rhodpw_in, rhopw_in, bigpw_in)
18-
{
19-
this->classname = "ElecStatePW_SDFT";
20-
}
21-
virtual void psiToRho(const psi::Psi<std::complex<double>>& psi) override;
22-
};
23-
}
21+
this->classname = "ElecStatePW_SDFT";
22+
}
23+
virtual void psiToRho(const psi::Psi<std::complex<double>>& psi) override;
24+
};
25+
} // namespace elecstate
2426
#endif

source/module_esolver/esolver.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
153153
return new ESolver_KS_PW<std::complex<double>, base_device::DEVICE_CPU>();
154154
}
155155
}
156+
else if (esolver_type == "sdft_pw")
157+
{
158+
return new ESolver_SDFT_PW<base_device::DEVICE_CPU>();
159+
}
156160
#ifdef __LCAO
157161
else if (esolver_type == "ksdft_lip")
158162
{
@@ -230,10 +234,6 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
230234
return p_esolver_lr;
231235
}
232236
#endif
233-
else if (esolver_type == "sdft_pw")
234-
{
235-
return new ESolver_SDFT_PW();
236-
}
237237
else if(esolver_type == "ofdft")
238238
{
239239
return new ESolver_OF();

0 commit comments

Comments
 (0)