Skip to content

Commit 4294ee7

Browse files
authored
Feature: make init stochastic WF support GPU (#5365)
* add bcast for energy * add template <Device> for sdft * fix compile * fix compile without mpi * add template for precision * Make init stochastic WF support GPU * fix compile without mpi * fix compile after merge * fix bug of init_com_orbitals
1 parent 705dad7 commit 4294ee7

26 files changed

+979
-776
lines changed
Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,47 @@
11
#include "./elecstate_pw_sdft.h"
2+
3+
#include "module_base/global_function.h"
24
#include "module_base/global_variable.h"
3-
#include "module_parameter/parameter.h"
45
#include "module_base/timer.h"
5-
#include "module_base/global_function.h"
66
#include "module_hamilt_general/module_xc/xc_functional.h"
7+
#include "module_parameter/parameter.h"
78
namespace elecstate
89
{
9-
void ElecStatePW_SDFT::psiToRho(const psi::Psi<std::complex<double>>& psi)
10+
11+
template <typename T, typename Device>
12+
void ElecStatePW_SDFT<T, Device>::psiToRho(const psi::Psi<T>& psi)
13+
{
14+
ModuleBase::TITLE(this->classname, "psiToRho");
15+
ModuleBase::timer::tick(this->classname, "psiToRho");
16+
for (int is = 0; is < PARAM.inp.nspin; is++)
1017
{
11-
ModuleBase::TITLE(this->classname, "psiToRho");
12-
ModuleBase::timer::tick(this->classname, "psiToRho");
13-
for(int is=0; is < PARAM.inp.nspin; is++)
14-
{
15-
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx);
16-
if (XC_Functional::get_func_type() == 3)
17-
{
18-
ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx);
19-
}
20-
}
21-
22-
if(GlobalV::MY_STOGROUP == 0)
23-
{
24-
this->calEBand();
18+
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx);
19+
if (XC_Functional::get_func_type() == 3)
20+
{
21+
ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx);
22+
}
23+
}
24+
25+
if (GlobalV::MY_STOGROUP == 0)
26+
{
27+
this->calEBand();
2528

26-
for(int is=0; is<PARAM.inp.nspin; is++)
27-
{
28-
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx);
29-
}
29+
for (int is = 0; is < PARAM.inp.nspin; is++)
30+
{
31+
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx);
32+
}
3033

31-
for (int ik = 0; ik < psi.get_nk(); ++ik)
32-
{
33-
psi.fix_k(ik);
34-
this->updateRhoK(psi);
35-
}
36-
this->parallelK();
34+
for (int ik = 0; ik < psi.get_nk(); ++ik)
35+
{
36+
psi.fix_k(ik);
37+
this->updateRhoK(psi);
3738
}
38-
ModuleBase::timer::tick(this->classname, "psiToRho");
39-
return;
39+
this->parallelK();
4040
}
41-
}
41+
ModuleBase::timer::tick(this->classname, "psiToRho");
42+
return;
43+
}
44+
45+
// template class ElecStatePW_SDFT<std::complex<float>, base_device::DEVICE_CPU>;
46+
template class ElecStatePW_SDFT<std::complex<double>, base_device::DEVICE_CPU>;
47+
} // namespace elecstate

source/module_elecstate/elecstate_pw_sdft.h

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,24 @@
33
#include "elecstate_pw.h"
44
namespace elecstate
55
{
6-
class ElecStatePW_SDFT : public ElecStatePW<std::complex<double>>
6+
template <typename T, typename Device>
7+
class ElecStatePW_SDFT : public ElecStatePW<T, Device>
8+
{
9+
public:
10+
ElecStatePW_SDFT(ModulePW::PW_Basis_K* wfc_basis_in,
11+
Charge* chg_in,
12+
K_Vectors* pkv_in,
13+
UnitCell* ucell_in,
14+
pseudopot_cell_vnl* ppcell_in,
15+
ModulePW::PW_Basis* rhodpw_in,
16+
ModulePW::PW_Basis* rhopw_in,
17+
ModulePW::PW_Basis_Big* bigpw_in)
18+
: ElecStatePW<T,
19+
Device>(wfc_basis_in, chg_in, pkv_in, ucell_in, ppcell_in, rhodpw_in, rhopw_in, bigpw_in)
720
{
8-
public:
9-
ElecStatePW_SDFT(ModulePW::PW_Basis_K* wfc_basis_in,
10-
Charge* chg_in,
11-
K_Vectors* pkv_in,
12-
UnitCell* ucell_in,
13-
pseudopot_cell_vnl* ppcell_in,
14-
ModulePW::PW_Basis* rhodpw_in,
15-
ModulePW::PW_Basis* rhopw_in,
16-
ModulePW::PW_Basis_Big* bigpw_in)
17-
: ElecStatePW(wfc_basis_in, chg_in, pkv_in, ucell_in, ppcell_in, rhodpw_in, rhopw_in, bigpw_in)
18-
{
19-
this->classname = "ElecStatePW_SDFT";
20-
}
21-
virtual void psiToRho(const psi::Psi<std::complex<double>>& psi) override;
22-
};
23-
}
21+
this->classname = "ElecStatePW_SDFT";
22+
}
23+
virtual void psiToRho(const psi::Psi<T>& psi) override;
24+
};
25+
} // namespace elecstate
2426
#endif

source/module_esolver/esolver.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,17 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
153153
return new ESolver_KS_PW<std::complex<double>, base_device::DEVICE_CPU>();
154154
}
155155
}
156+
else if (esolver_type == "sdft_pw")
157+
{
158+
// if (PARAM.inp.precision == "single")
159+
// {
160+
// return new ESolver_SDFT_PW<std::complex<float>, base_device::DEVICE_CPU>();
161+
// }
162+
// else
163+
// {
164+
return new ESolver_SDFT_PW<std::complex<double>, base_device::DEVICE_CPU>();
165+
// }
166+
}
156167
#ifdef __LCAO
157168
else if (esolver_type == "ksdft_lip")
158169
{
@@ -230,10 +241,6 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
230241
return p_esolver_lr;
231242
}
232243
#endif
233-
else if (esolver_type == "sdft_pw")
234-
{
235-
return new ESolver_SDFT_PW();
236-
}
237244
else if(esolver_type == "ofdft")
238245
{
239246
return new ESolver_OF();

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,9 @@ void ESolver_KS_PW<T, Device>::update_pot(const int istep, const int iter)
423423
}
424424
this->pelec->pot->update_from_charge(this->pelec->charge, &GlobalC::ucell);
425425
this->pelec->f_en.descf = this->pelec->cal_delta_escf();
426+
#ifdef __MPI
427+
MPI_Bcast(&(this->pelec->f_en.descf), 1, MPI_DOUBLE, 0, PARAPW_WORLD);
428+
#endif
426429
}
427430
else
428431
{

0 commit comments

Comments
 (0)