Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 65 additions & 29 deletions source/module_elecstate/elecstate_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,36 +31,55 @@ ElecStatePW<T, Device>::ElecStatePW(ModulePW::PW_Basis_K* wfc_basis_in,
template<typename T, typename Device>
ElecStatePW<T, Device>::~ElecStatePW()
{
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single")
{
delmem_var_op()(this->ctx, this->rho_data);
delete[] this->rho;

if (PARAM.globalv.double_grid || PARAM.globalv.use_uspp)
{
delmem_complex_op()(this->ctx, this->rhog_data);
delete[] this->rhog;
}
if (get_xc_func_type() == 3 || PARAM.inp.out_elf[0] > 0)
{
delmem_var_op()(this->ctx, this->kin_r_data);
delete[] this->kin_r;
}
}
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") {
delete[] this->rho;
delete[] this->kin_r;
if (PARAM.globalv.use_uspp)
{
delmem_var_op()(this->ctx, this->becsum);
}
delmem_var_op()(this->ctx, becsum);
delmem_complex_op()(this->ctx, this->wfcr);
delmem_complex_op()(this->ctx, this->wfcr_another_spin);
}

template<typename T, typename Device>
void ElecStatePW<T, Device>::init_rho_data()
{
if(this->init_rho) {
if (this->init_rho)
{
return;
}

if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") {

if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single")
{
this->rho = new Real*[this->charge->nspin];
resmem_var_op()(this->ctx, this->rho_data, this->charge->nspin * this->charge->nrxx);
for (int ii = 0; ii < this->charge->nspin; ii++) {
for (int ii = 0; ii < this->charge->nspin; ii++)
{
this->rho[ii] = this->rho_data + ii * this->charge->nrxx;
}
if (PARAM.globalv.double_grid || PARAM.globalv.use_uspp)
{
this->rhog = new T*[this->charge->nspin];
resmem_complex_op()(this->ctx, this->rhog_data, this->charge->nspin * this->charge->rhopw->npw);
for (int ii = 0; ii < this->charge->nspin; ii++)
{
this->rhog[ii] = this->rhog_data + ii * this->charge->rhopw->npw;
}
}
if (get_xc_func_type() == 3 || PARAM.inp.out_elf[0] > 0)
{
this->kin_r = new Real*[this->charge->nspin];
Expand All @@ -70,8 +89,13 @@ void ElecStatePW<T, Device>::init_rho_data()
}
}
}
else {
else
{
this->rho = reinterpret_cast<Real **>(this->charge->rho);
if (PARAM.globalv.double_grid || PARAM.globalv.use_uspp)
{
this->rhog = reinterpret_cast<T**>(this->charge->rhog);
}
if (get_xc_func_type() == 3 || PARAM.inp.out_elf[0] > 0)
{
this->kin_r = reinterpret_cast<Real **>(this->charge->kin_r);
Expand Down Expand Up @@ -100,19 +124,24 @@ void ElecStatePW<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
// ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx);
setmem_var_op()(this->ctx, this->kin_r[is], 0, this->charge->nrxx);
}
}
if (PARAM.globalv.double_grid || PARAM.globalv.use_uspp)
{
setmem_complex_op()(this->ctx, this->rhog[is], 0, this->charge->rhopw->npw);
}
}

for (int ik = 0; ik < psi.get_nk(); ++ik)
{
psi.fix_k(ik);
this->updateRhoK(psi);
}
if (PARAM.globalv.use_uspp)

this->add_usrho(psi);

if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single")
{
this->add_usrho(psi);
}
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") {
for (int ii = 0; ii < PARAM.inp.nspin; ii++) {
for (int ii = 0; ii < PARAM.inp.nspin; ii++)
{
castmem_var_d2h_op()(cpu_ctx, this->ctx, this->charge->rho[ii], this->rho[ii], this->charge->nrxx);
if (get_xc_func_type() == 3)
{
Expand Down Expand Up @@ -397,32 +426,39 @@ void ElecStatePW<T, Device>::cal_becsum(const psi::Psi<T, Device>& psi)
template <typename T, typename Device>
void ElecStatePW<T, Device>::add_usrho(const psi::Psi<T, Device>& psi)
{
this->cal_becsum(psi);
if (PARAM.globalv.use_uspp)
{
this->cal_becsum(psi);
}

// transform soft charge to recip space using smooth grids
T* rhog = nullptr;
resmem_complex_op()(this->ctx, rhog, this->charge->rhopw->npw * PARAM.inp.nspin, "ElecState<PW>::rhog");
setmem_complex_op()(this->ctx, rhog, 0, this->charge->rhopw->npw * PARAM.inp.nspin);
for (int is = 0; is < PARAM.inp.nspin; is++)
if (PARAM.globalv.double_grid || PARAM.globalv.use_uspp)
{
this->rhopw_smooth->real2recip(this->rho[is], &rhog[is * this->charge->rhopw->npw]);
for (int is = 0; is < PARAM.inp.nspin; is++)
{
this->rhopw_smooth->real2recip(this->rho[is], this->rhog[is]);
}
}

// \sum_lm Q_lm(r) \sum_i <psi_i|beta_l><beta_m|psi_i> w_i
// add to the charge density in reciprocal space the part which is due to the US augmentation.
this->addusdens_g(becsum, rhog);
if (PARAM.globalv.use_uspp)
{
this->addusdens_g(becsum, rhog);
}

// transform back to real space using dense grids
for (int is = 0; is < PARAM.inp.nspin; is++)
if (PARAM.globalv.double_grid || PARAM.globalv.use_uspp)
{
this->charge->rhopw->recip2real(&rhog[is * this->charge->rhopw->npw], this->rho[is]);
for (int is = 0; is < PARAM.inp.nspin; is++)
{
this->charge->rhopw->recip2real(this->rhog[is], this->rho[is]);
}
}

delmem_complex_op()(this->ctx, rhog);
}

template <typename T, typename Device>
void ElecStatePW<T, Device>::addusdens_g(const Real* becsum, T* rhog)
void ElecStatePW<T, Device>::addusdens_g(const Real* becsum, T** rhog)
{
const T one{1, 0};
const T zero{0, 0};
Expand Down Expand Up @@ -506,7 +542,7 @@ void ElecStatePW<T, Device>::addusdens_g(const Real* becsum, T* rhog)
this->ppcell->radial_fft_q(this->ctx, npw, ih, jh, it, qmod, ylmk0, qgm);
for (int ig = 0; ig < npw; ig++)
{
rhog[is * npw + ig] += qgm[ig] * aux2[ijh * npw + ig];
rhog[is][ig] += qgm[ig] * aux2[ijh * npw + ig];
}
ijh++;
}
Expand Down
10 changes: 6 additions & 4 deletions source/module_elecstate/elecstate_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ class ElecStatePW : public ElecState

//! init rho_data and kin_r_data
void init_rho_data();
Real** rho = nullptr;
Real** kin_r = nullptr; //[Device] [spin][nrxx] rho and kin_r
Real** rho = nullptr; // [Device] [spin][nrxx] rho
T** rhog = nullptr; // [Device] [spin][nrxx] rhog
Real** kin_r = nullptr; // [Device] [spin][nrxx] kin_r

protected:

Expand All @@ -70,15 +71,16 @@ class ElecStatePW : public ElecState

//! Non-local pseudopotentials
//! \sum_lm Q_lm(r) \sum_i <psi_i|beta_l><beta_m|psi_i> w_i
void addusdens_g(const Real* becsum, T* rhog);
void addusdens_g(const Real* becsum, T** rhog);

Device * ctx = {};

bool init_rho = false;

mutable T* vkb = nullptr;

Real* rho_data = nullptr;
Real* rho_data = nullptr;
T* rhog_data = nullptr;
Real* kin_r_data = nullptr;
T* wfcr = nullptr;
T* wfcr_another_spin = nullptr;
Expand Down
1 change: 1 addition & 0 deletions tests/integrate/101_PW_upf201_Al_pseudopots/INPUT
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pseudo_dir ../../PP_ORB

#Parameters (2.Iteration)
ecutwfc 30
ecutrho 160
scf_thr 1e-9
scf_nmax 100

Expand Down
2 changes: 2 additions & 0 deletions tests/integrate/101_PW_upf201_Al_pseudopots/README
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ This test for:
*upf201 pseudopotential
*mixing_type broyden-kerker
*mixing_beta 0.7
*ecutwfc 30
*ecutrho 160
2 changes: 1 addition & 1 deletion tests/integrate/101_PW_upf201_Al_pseudopots/jd
Original file line number Diff line number Diff line change
@@ -1 +1 @@
test upf201 pseudopotential, symmetry=on
test upf201 pseudopotential with ecutrho/ecutwfc > 4, symmetry=on
6 changes: 3 additions & 3 deletions tests/integrate/101_PW_upf201_Al_pseudopots/result.ref
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
etotref -57.66876692615671
etotperatomref -57.6687669262
etotref -57.66876684738178
etotperatomref -57.6687668474
pointgroupref O_h
spacegroupref O_h
nksibzref 3
totaltimeref
totaltimeref 0.29
Loading