Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ void XC_Functional::gradcorr(double &etxc, double &vtxc, ModuleBase::matrix &v,
}

gdr2 = new ModuleBase::Vector3<double>[rhopw->nrxx];
h2 = new ModuleBase::Vector3<double>[rhopw->nrxx];
if(!is_stress) h2 = new ModuleBase::Vector3<double>[rhopw->nrxx];

XC_Functional::grad_rho( rhogsum1 , gdr1, rhopw, ucell->tpiba);
XC_Functional::grad_rho( rhogsum2 , gdr2, rhopw, ucell->tpiba);
Expand Down
2 changes: 1 addition & 1 deletion source/module_hamilt_pw/hamilt_pwdft/forces_nl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void Forces<FPTYPE, Device>::cal_force_nl(ModuleBase::matrix& forcenl,
break;
}
}
const int npm = ucell_in.get_npol() * nbands_occ;
const int npm = nbands_occ;
// calculate becp = <psi|beta> for all beta functions
nl_tools.cal_becp(ik, npm);
for (int ipol = 0; ipol < 3; ipol++)
Expand Down
180 changes: 114 additions & 66 deletions source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,6 @@ void FS_Nonlocal_tools<FPTYPE, Device>::allocate_memory(const ModuleBase::matrix
this->atom_na = h_atom_na.data();
this->ppcell_vkb = this->nlpp_->vkb.c;
}

// prepare the memory of the becp and dbecp:
// becp: <Beta(nkb,npw)|psi(nbnd,npw)>
// dbecp: <dBeta(nkb,npw)/dG|psi(nbnd,npw)>
resmem_complex_op()(this->ctx, becp, this->nbands * nkb, "Stress::becp");
resmem_complex_op()(this->ctx, dbecp, 6 * this->nbands * nkb, "Stress::dbecp");
}

template <typename FPTYPE, typename Device>
Expand Down Expand Up @@ -163,9 +157,12 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_becp(int ik, int npm)
{
ModuleBase::TITLE("FS_Nonlocal_tools", "cal_becp");
ModuleBase::timer::tick("FS_Nonlocal_tools", "cal_becp");
const int npol = this->ucell_->get_npol();
const int size_becp = this->nbands * npol * this->nkb;
const int size_becp_act = npm * npol * this->nkb;
if (this->becp == nullptr)
{
resmem_complex_op()(this->ctx, becp, this->nbands * this->nkb);
resmem_complex_op()(this->ctx, becp, size_becp);
}

// prepare math tools
Expand Down Expand Up @@ -249,11 +246,12 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_becp(int ik, int npm)
}
const char transa = 'C';
const char transb = 'N';
int npm_npol = npm * npol;
gemm_op()(this->ctx,
transa,
transb,
nkb,
npm,
npm_npol,
npw,
&ModuleBase::ONE,
ppcell_vkb,
Expand All @@ -268,15 +266,15 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_becp(int ik, int npm)
if (this->device == base_device::GpuDevice)
{
std::complex<FPTYPE>* h_becp = nullptr;
resmem_complex_h_op()(this->cpu_ctx, h_becp, this->nbands * nkb);
syncmem_complex_d2h_op()(this->cpu_ctx, this->ctx, h_becp, becp, this->nbands * nkb);
Parallel_Reduce::reduce_pool(h_becp, this->nbands * nkb);
syncmem_complex_h2d_op()(this->ctx, this->cpu_ctx, becp, h_becp, this->nbands * nkb);
resmem_complex_h_op()(this->cpu_ctx, h_becp, size_becp_act);
syncmem_complex_d2h_op()(this->cpu_ctx, this->ctx, h_becp, becp, size_becp_act);
Parallel_Reduce::reduce_pool(h_becp, size_becp_act);
syncmem_complex_h2d_op()(this->ctx, this->cpu_ctx, becp, h_becp, size_becp_act);
delmem_complex_h_op()(this->cpu_ctx, h_becp);
}
else
{
Parallel_Reduce::reduce_pool(becp, this->nbands * this->nkb);
Parallel_Reduce::reduce_pool(becp, size_becp_act);
}
ModuleBase::timer::tick("FS_Nonlocal_tools", "cal_becp");
}
Expand All @@ -287,9 +285,12 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
{
ModuleBase::TITLE("FS_Nonlocal_tools", "cal_dbecp_s");
ModuleBase::timer::tick("FS_Nonlocal_tools", "cal_dbecp_s");
const int npol = this->ucell_->get_npol();
const int size_becp = this->nbands * npol * this->nkb;
const int npm_npol = npm * npol;
if (this->dbecp == nullptr)
{
resmem_complex_op()(this->ctx, dbecp, this->nbands * this->nkb);
resmem_complex_op()(this->ctx, dbecp, size_becp);
}

// prepare math tools
Expand Down Expand Up @@ -401,7 +402,7 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
transa,
transb,
nkb,
npm,
npm_npol,
npw,
&ModuleBase::ONE,
ppcell_vkb,
Expand All @@ -412,43 +413,68 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
dbecp,
nkb);
// calculate stress for target (ipol, jpol)
const int current_spin = this->kv_->isk[ik];
cal_stress_nl_op()(this->ctx,
nondiagonal,
ipol,
jpol,
nkb,
npm,
this->ntype,
current_spin, // uspp only
this->nbands,
ik,
this->nlpp_->deeq.getBound2(),
this->nlpp_->deeq.getBound3(),
this->nlpp_->deeq.getBound4(),
atom_nh,
atom_na,
d_wg,
d_ekb,
qq_nt,
deeq,
becp,
dbecp,
stress);
if(npol == 1)
{
const int current_spin = this->kv_->isk[ik];
cal_stress_nl_op()(this->ctx,
nondiagonal,
ipol,
jpol,
nkb,
npm,
this->ntype,
current_spin, // uspp only
this->nlpp_->deeq.getBound2(),
this->nlpp_->deeq.getBound3(),
this->nlpp_->deeq.getBound4(),
atom_nh,
atom_na,
d_wg + this->nbands * ik,
d_ekb + this->nbands * ik,
qq_nt,
deeq,
becp,
dbecp,
stress);
}
else
{
cal_stress_nl_op()(this->ctx,
ipol,
jpol,
nkb,
npm,
this->ntype,
this->nlpp_->deeq_nc.getBound2(),
this->nlpp_->deeq_nc.getBound3(),
this->nlpp_->deeq_nc.getBound4(),
atom_nh,
atom_na,
d_wg + this->nbands * ik,
d_ekb + this->nbands * ik,
qq_nt,
this->nlpp_->template get_deeq_nc_data<FPTYPE>(),
becp,
dbecp,
stress);
}
ModuleBase::timer::tick("FS_Nonlocal_tools", "cal_dbecp_s");
}

template <typename FPTYPE, typename Device>
void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_f(int ik, int npm, int ipol)
{
ModuleBase::TITLE("FS_Nonlocal_tools", "cal_dbecp_s");
ModuleBase::TITLE("FS_Nonlocal_tools", "cal_dbecp_f");
ModuleBase::timer::tick("FS_Nonlocal_tools", "cal_dbecp_f");
const int npol = this->ucell_->get_npol();
const int npm_npol = npm * npol;
const int size_becp = this->nbands * npol * this->nkb;
if (this->dbecp == nullptr)
{
resmem_complex_op()(this->ctx, dbecp, 3 * this->nbands * this->nkb);
resmem_complex_op()(this->ctx, dbecp, 3 * size_becp);
}

std::complex<FPTYPE>* dbecp_ptr = this->dbecp + ipol * this->nbands * this->nkb;
std::complex<FPTYPE>* dbecp_ptr = this->dbecp + ipol * size_becp;
const std::complex<FPTYPE>* vkb_ptr = this->ppcell_vkb;
std::complex<FPTYPE>* vkb_deri_ptr = this->ppcell_vkb;

Expand Down Expand Up @@ -481,7 +507,7 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_f(int ik, int npm, int ipol)
transa,
transb,
this->nkb,
npm,
npm_npol,
npw,
&ModuleBase::ONE,
vkb_deri_ptr,
Expand All @@ -491,7 +517,6 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_f(int ik, int npm, int ipol)
&ModuleBase::ZERO,
dbecp_ptr,
nkb);

this->revert_vkb(npw, ipol);
this->pre_ik_f = ik;
ModuleBase::timer::tick("FS_Nonlocal_tools", "cal_dbecp_f");
Expand Down Expand Up @@ -634,29 +659,52 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_force(int ik, int npm, FPTYPE* force
const int current_spin = this->kv_->isk[ik];
const int force_nc = 3;
// calculate the force
cal_force_nl_op<FPTYPE, Device>()(this->ctx,
nondiagonal,
npm,
this->nbands,
this->ntype,
current_spin,
this->nlpp_->deeq.getBound2(),
this->nlpp_->deeq.getBound3(),
this->nlpp_->deeq.getBound4(),
force_nc,
this->nbands,
ik,
nkb,
atom_nh,
atom_na,
this->ucell_->tpiba,
d_wg,
d_ekb,
qq_nt,
deeq,
becp,
dbecp,
force);
if(this->ucell_->get_npol() == 1)
{
cal_force_nl_op<FPTYPE, Device>()(this->ctx,
nondiagonal,
npm,
this->ntype,
current_spin,
this->nlpp_->deeq.getBound2(),
this->nlpp_->deeq.getBound3(),
this->nlpp_->deeq.getBound4(),
force_nc,
this->nbands,
nkb,
atom_nh,
atom_na,
this->ucell_->tpiba,
d_wg + this->nbands * ik,
d_ekb + this->nbands * ik,
qq_nt,
deeq,
becp,
dbecp,
force);
}
else
{
cal_force_nl_op<FPTYPE, Device>()(this->ctx,
npm,
this->ntype,
this->nlpp_->deeq_nc.getBound2(),
this->nlpp_->deeq_nc.getBound3(),
this->nlpp_->deeq_nc.getBound4(),
force_nc,
this->nbands,
nkb,
atom_nh,
atom_na,
this->ucell_->tpiba,
d_wg + this->nbands * ik,
d_ekb + this->nbands * ik,
qq_nt,
this->nlpp_->template get_deeq_nc_data<FPTYPE>(),
becp,
dbecp,
force);
}
}

// template instantiation
Expand Down
Loading
Loading