Skip to content
66 changes: 36 additions & 30 deletions source/module_elecstate/elecstate_pw_sdft.cpp
Original file line number Diff line number Diff line change
@@ -1,41 +1,47 @@
#include "./elecstate_pw_sdft.h"

#include "module_base/global_function.h"
#include "module_base/global_variable.h"
#include "module_parameter/parameter.h"
#include "module_base/timer.h"
#include "module_base/global_function.h"
#include "module_hamilt_general/module_xc/xc_functional.h"
#include "module_parameter/parameter.h"
namespace elecstate
{
void ElecStatePW_SDFT::psiToRho(const psi::Psi<std::complex<double>>& psi)

template <typename T, typename Device>
void ElecStatePW_SDFT<T, Device>::psiToRho(const psi::Psi<T>& psi)
{
ModuleBase::TITLE(this->classname, "psiToRho");
ModuleBase::timer::tick(this->classname, "psiToRho");
for (int is = 0; is < PARAM.inp.nspin; is++)
{
ModuleBase::TITLE(this->classname, "psiToRho");
ModuleBase::timer::tick(this->classname, "psiToRho");
for(int is=0; is < PARAM.inp.nspin; is++)
{
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx);
if (XC_Functional::get_func_type() == 3)
{
ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx);
}
}

if(GlobalV::MY_STOGROUP == 0)
{
this->calEBand();
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx);
if (XC_Functional::get_func_type() == 3)
{
ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx);
}
}

if (GlobalV::MY_STOGROUP == 0)
{
this->calEBand();

for(int is=0; is<PARAM.inp.nspin; is++)
{
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx);
}
for (int is = 0; is < PARAM.inp.nspin; is++)
{
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx);
}

for (int ik = 0; ik < psi.get_nk(); ++ik)
{
psi.fix_k(ik);
this->updateRhoK(psi);
}
this->parallelK();
for (int ik = 0; ik < psi.get_nk(); ++ik)
{
psi.fix_k(ik);
this->updateRhoK(psi);
}
ModuleBase::timer::tick(this->classname, "psiToRho");
return;
this->parallelK();
}
}
ModuleBase::timer::tick(this->classname, "psiToRho");
return;
}

// template class ElecStatePW_SDFT<std::complex<float>, base_device::DEVICE_CPU>;
template class ElecStatePW_SDFT<std::complex<double>, base_device::DEVICE_CPU>;
} // namespace elecstate
36 changes: 19 additions & 17 deletions source/module_elecstate/elecstate_pw_sdft.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,24 @@
#include "elecstate_pw.h"
namespace elecstate
{
class ElecStatePW_SDFT : public ElecStatePW<std::complex<double>>
template <typename T, typename Device>
class ElecStatePW_SDFT : public ElecStatePW<T, Device>
{
public:
ElecStatePW_SDFT(ModulePW::PW_Basis_K* wfc_basis_in,
Charge* chg_in,
K_Vectors* pkv_in,
UnitCell* ucell_in,
pseudopot_cell_vnl* ppcell_in,
ModulePW::PW_Basis* rhodpw_in,
ModulePW::PW_Basis* rhopw_in,
ModulePW::PW_Basis_Big* bigpw_in)
: ElecStatePW<T,
Device>(wfc_basis_in, chg_in, pkv_in, ucell_in, ppcell_in, rhodpw_in, rhopw_in, bigpw_in)
{
public:
ElecStatePW_SDFT(ModulePW::PW_Basis_K* wfc_basis_in,
Charge* chg_in,
K_Vectors* pkv_in,
UnitCell* ucell_in,
pseudopot_cell_vnl* ppcell_in,
ModulePW::PW_Basis* rhodpw_in,
ModulePW::PW_Basis* rhopw_in,
ModulePW::PW_Basis_Big* bigpw_in)
: ElecStatePW(wfc_basis_in, chg_in, pkv_in, ucell_in, ppcell_in, rhodpw_in, rhopw_in, bigpw_in)
{
this->classname = "ElecStatePW_SDFT";
}
virtual void psiToRho(const psi::Psi<std::complex<double>>& psi) override;
};
}
this->classname = "ElecStatePW_SDFT";
}
virtual void psiToRho(const psi::Psi<T>& psi) override;
};
} // namespace elecstate
#endif
15 changes: 11 additions & 4 deletions source/module_esolver/esolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,17 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
return new ESolver_KS_PW<std::complex<double>, base_device::DEVICE_CPU>();
}
}
else if (esolver_type == "sdft_pw")
{
// if (PARAM.inp.precision == "single")
// {
// return new ESolver_SDFT_PW<std::complex<float>, base_device::DEVICE_CPU>();
// }
// else
// {
return new ESolver_SDFT_PW<std::complex<double>, base_device::DEVICE_CPU>();
// }
}
#ifdef __LCAO
else if (esolver_type == "ksdft_lip")
{
Expand Down Expand Up @@ -230,10 +241,6 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
return p_esolver_lr;
}
#endif
else if (esolver_type == "sdft_pw")
{
return new ESolver_SDFT_PW();
}
else if(esolver_type == "ofdft")
{
return new ESolver_OF();
Expand Down
3 changes: 3 additions & 0 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,9 @@ void ESolver_KS_PW<T, Device>::update_pot(const int istep, const int iter)
}
this->pelec->pot->update_from_charge(this->pelec->charge, &GlobalC::ucell);
this->pelec->f_en.descf = this->pelec->cal_delta_escf();
#ifdef __MPI
MPI_Bcast(&(this->pelec->f_en.descf), 1, MPI_DOUBLE, 0, PARAPW_WORLD);
#endif
}
else
{
Expand Down
Loading
Loading