Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks_lcaopw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ namespace ModuleESolver
}

hsolver::HSolverLIP<T> hsolver_lip_obj(this->pw_wfc);
hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, psig.lock().get()[0], skip_charge);
hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, psig.lock().get()[0], skip_charge,ucell.tpiba,ucell.nat);

// add exx
#ifdef __EXX
Expand Down
4 changes: 3 additions & 1 deletion source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,9 @@ void ESolver_KS_PW<T, Device>::hamilt2density_single(UnitCell& ucell,
this->pelec->ekb.c,
GlobalV::RANK_IN_POOL,
GlobalV::NPROC_IN_POOL,
skip_charge);
skip_charge,
ucell.tpiba,
ucell.nat);

Symmetry_rho srho;
for (int is = 0; is < PARAM.inp.nspin; is++)
Expand Down
23 changes: 14 additions & 9 deletions source/module_hsolver/hsolver_lcaopw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace hsolver

#ifdef USE_PAW
template <typename T>
void HSolverLIP<T>::paw_func_in_kloop(const int ik)
void HSolverLIP<T>::paw_func_in_kloop(const int ik,const double tpiba)
{
if (PARAM.inp.use_paw)
{
Expand Down Expand Up @@ -64,7 +64,7 @@ void HSolverLIP<T>::paw_func_in_kloop(const int ik)
this->wfc_basis->get_ig2iy(ik).data(),
this->wfc_basis->get_ig2iz(ik).data(),
(const double**)kpg,
GlobalC::ucell.tpiba,
tpiba,
(const double**)gcar);

std::vector<double>().swap(kpt);
Expand All @@ -83,7 +83,10 @@ void HSolverLIP<T>::paw_func_in_kloop(const int ik)
}

template <typename T>
void HSolverLIP<T>::paw_func_after_kloop(psi::Psi<T>& psi, elecstate::ElecState* pes)
void HSolverLIP<T>::paw_func_after_kloop(psi::Psi<T>& psi,
elecstate::ElecState* pes,
const double tpiba,
const int nat)
{
if (PARAM.inp.use_paw)
{
Expand Down Expand Up @@ -131,7 +134,7 @@ void HSolverLIP<T>::paw_func_after_kloop(psi::Psi<T>& psi, elecstate::ElecState*
this->wfc_basis->get_ig2iy(ik).data(),
this->wfc_basis->get_ig2iz(ik).data(),
(const double**)kpg,
GlobalC::ucell.tpiba,
tpiba,
(const double**)gcar);

std::vector<double>().swap(kpt);
Expand Down Expand Up @@ -164,7 +167,7 @@ void HSolverLIP<T>::paw_func_after_kloop(psi::Psi<T>& psi, elecstate::ElecState*
{
GlobalC::paw_cell.get_rhoijp(rhoijp, rhoijselect, nrhoijsel);

for (int iat = 0; iat < GlobalC::ucell.nat; iat++)
for (int iat = 0; iat < nat; iat++)
{
GlobalC::paw_cell.set_rhoij(iat,
nrhoijsel[iat],
Expand All @@ -176,7 +179,7 @@ void HSolverLIP<T>::paw_func_after_kloop(psi::Psi<T>& psi, elecstate::ElecState*
#else
GlobalC::paw_cell.get_rhoijp(rhoijp, rhoijselect, nrhoijsel);

for (int iat = 0; iat < GlobalC::ucell.nat; iat++)
for (int iat = 0; iat < nat; iat++)
{
GlobalC::paw_cell.set_rhoij(iat,
nrhoijsel[iat],
Expand All @@ -201,7 +204,9 @@ void HSolverLIP<T>::solve(hamilt::Hamilt<T>* pHamilt, // ESolver_KS_PW::p_hamilt
psi::Psi<T>& psi, // ESolver_KS_PW::kspw_psi
elecstate::ElecState* pes, // ESolver_KS_PW::pes
psi::Psi<T>& transform,
const bool skip_charge)
const bool skip_charge,
const double tpiba,
const int nat)
{
ModuleBase::TITLE("HSolverLIP", "solve");
ModuleBase::timer::tick("HSolverLIP", "solve");
Expand All @@ -212,7 +217,7 @@ void HSolverLIP<T>::solve(hamilt::Hamilt<T>* pHamilt, // ESolver_KS_PW::p_hamilt
pHamilt->updateHk(ik);

#ifdef USE_PAW
this->paw_func_in_kloop(ik);
this->paw_func_in_kloop(ik,tpiba);
#endif

psi.fix_k(ik);
Expand Down Expand Up @@ -282,7 +287,7 @@ void HSolverLIP<T>::solve(hamilt::Hamilt<T>* pHamilt, // ESolver_KS_PW::p_hamilt
reinterpret_cast<elecstate::ElecStatePW<T>*>(pes)->psiToRho(psi);

#ifdef USE_PAW
this->paw_func_after_kloop(psi, pes);
this->paw_func_after_kloop(psi, pes,tpiba,nat);
#endif

ModuleBase::timer::tick("HSolverLIP", "solve");
Expand Down
12 changes: 9 additions & 3 deletions source/module_hsolver/hsolver_lcaopw.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,21 @@ class HSolverLIP
psi::Psi<T>& psi,
elecstate::ElecState* pes,
psi::Psi<T>& transform,
const bool skip_charge);
const bool skip_charge,
const double tpiba,
const int nat);

private:
ModulePW::PW_Basis_K* wfc_basis;

#ifdef USE_PAW
void paw_func_in_kloop(const int ik);
void paw_func_in_kloop(const int ik,
const double tpiba);

void paw_func_after_kloop(psi::Psi<T>& psi, elecstate::ElecState* pes);
void paw_func_after_kloop(psi::Psi<T>& psi,
elecstate::ElecState* pes,
const double tpiba,
const int nat);
#endif
};

Expand Down
24 changes: 15 additions & 9 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ namespace hsolver

#ifdef USE_PAW
template <typename T, typename Device>
void HSolverPW<T, Device>::paw_func_in_kloop(const int ik)
void HSolverPW<T, Device>::paw_func_in_kloop(const int ik,
const double tpiba)
{
if (this->use_paw)
{
Expand Down Expand Up @@ -68,7 +69,7 @@ void HSolverPW<T, Device>::paw_func_in_kloop(const int ik)
this->wfc_basis->get_ig2iy(ik).data(),
this->wfc_basis->get_ig2iz(ik).data(),
(const double**)kpg,
GlobalC::ucell.tpiba,
tpiba,
(const double**)gcar);

std::vector<double>().swap(kpt);
Expand Down Expand Up @@ -96,7 +97,10 @@ void HSolverPW<T, Device>::call_paw_cell_set_currentk(const int ik)
}

template <typename T, typename Device>
void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecstate::ElecState* pes)
void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi,
elecstate::ElecState* pes,
const double tpiba,
const int nat)
{
if (this->use_paw)
{
Expand Down Expand Up @@ -144,7 +148,7 @@ void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecst
this->wfc_basis->get_ig2iy(ik).data(),
this->wfc_basis->get_ig2iz(ik).data(),
(const double**)kpg,
GlobalC::ucell.tpiba,
tpiba,
(const double**)gcar);

std::vector<double>().swap(kpt);
Expand Down Expand Up @@ -177,7 +181,7 @@ void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecst
{
GlobalC::paw_cell.get_rhoijp(rhoijp, rhoijselect, nrhoijsel);

for (int iat = 0; iat < GlobalC::ucell.nat; iat++)
for (int iat = 0; iat < nat; iat++)
{
GlobalC::paw_cell.set_rhoij(iat,
nrhoijsel[iat],
Expand All @@ -189,7 +193,7 @@ void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecst
#else
GlobalC::paw_cell.get_rhoijp(rhoijp, rhoijselect, nrhoijsel);

for (int iat = 0; iat < GlobalC::ucell.nat; iat++)
for (int iat = 0; iat < nat; iat++)
{
GlobalC::paw_cell.set_rhoij(iat,
nrhoijsel[iat],
Expand Down Expand Up @@ -255,7 +259,9 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
double* out_eigenvalues,
const int rank_in_pool_in,
const int nproc_in_pool_in,
const bool skip_charge)
const bool skip_charge,
const double tpiba,
const int nat)
{
ModuleBase::TITLE("HSolverPW", "solve");
ModuleBase::timer::tick("HSolverPW", "solve");
Expand All @@ -282,7 +288,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
pHamilt->updateHk(ik);

#ifdef USE_PAW
this->paw_func_in_kloop(ik);
this->paw_func_in_kloop(ik,tpiba);
#endif

/// update psi pointer for each k point
Expand Down Expand Up @@ -341,7 +347,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
reinterpret_cast<elecstate::ElecStatePW<T, Device>*>(pes)->psiToRho(psi);

#ifdef USE_PAW
this->paw_func_after_kloop(psi, pes);
this->paw_func_after_kloop(psi, pes,tpiba,nat);
#endif

ModuleBase::timer::tick("HSolverPW", "solve");
Expand Down
9 changes: 6 additions & 3 deletions source/module_hsolver/hsolver_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ class HSolverPW
double* out_eigenvalues,
const int rank_in_pool_in,
const int nproc_in_pool_in,
const bool skip_charge);
const bool skip_charge,
const double tpiba,
const int nat);

protected:
// diago caller
Expand Down Expand Up @@ -89,11 +91,12 @@ class HSolverPW
std::vector<double> ethr_band;

#ifdef USE_PAW
void paw_func_in_kloop(const int ik);
void paw_func_in_kloop(const int ik,
const double tpiba);

void call_paw_cell_set_currentk(const int ik);

void paw_func_after_kloop(psi::Psi<T, Device>& psi, elecstate::ElecState* pes);
void paw_func_after_kloop(psi::Psi<T, Device>& psi, elecstate::ElecState* pes,const double tpiba,const int nat);
#endif
};

Expand Down
5 changes: 2 additions & 3 deletions source/module_hsolver/test/test_hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,7 @@ TEST_F(TestHSolverPW, SolveLcaoInPW) {
= hsolver::HSolverLIP<std::complex<float>>(&pwbk);
hsolver::HSolverLIP<std::complex<double>> hs_d_lip
= hsolver::HSolverLIP<std::complex<double>>(&pwbk);
hs_f_lip.solve(&hamilt_test_f, psi_test_cf, &elecstate_test,
transform_test_cf, true);
hs_f_lip.solve(&hamilt_test_f, psi_test_cf, &elecstate_test,transform_test_cf, true,0.0,0);
EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist<std::complex<float>>::avg_iter, 0.0);
for (int i = 0; i < psi_test_cf.size(); i++)
{
Expand All @@ -261,7 +260,7 @@ TEST_F(TestHSolverPW, SolveLcaoInPW) {

elecstate_test.ekb.c[0] = 1.0;
elecstate_test.ekb.c[1] = 2.0;
hs_d_lip.solve(&hamilt_test_d, psi_test_cd, &elecstate_test, transform_test_cd, true);
hs_d_lip.solve(&hamilt_test_d, psi_test_cd, &elecstate_test, transform_test_cd, true,0.0,0);
EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist<std::complex<double>>::avg_iter, 0.0);
for (int i = 0; i < psi_test_cd.size(); i++)
{
Expand Down
Loading