Skip to content

Commit 9f41593

Browse files
authored
Merge branch 'develop' into fft12
2 parents e67db12 + 28d8902 commit 9f41593

File tree

9 files changed

+272
-82
lines changed

9 files changed

+272
-82
lines changed

source/module_cell/klist.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,12 @@ class K_Vectors
140140
{
141141
this->nkstot_full = value;
142142
}
143+
144+
/// Read-only accessor: returns the current value of the is_mp member.
/// NOTE(review): is_mp presumably marks a Monkhorst-Pack generated k-point grid - confirm where it is set.
bool get_is_mp() const
145+
{
146+
return is_mp;
147+
}
148+
143149
std::vector<int> ik2iktot; ///<[nks] map ik to the global index of k points
144150

145151
private:

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
#include "module_base/kernels/dsp/dsp_connector.h"
5050
#endif
5151

52-
52+
#include <chrono>
5353

5454
namespace ModuleESolver
5555
{
@@ -578,13 +578,18 @@ void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int
578578
{
579579
if (conv_esolver)
580580
{
581+
auto start = std::chrono::high_resolution_clock::now();
582+
exx_helper.set_firstiter(false);
581583
exx_helper.set_psi(this->kspw_psi);
582584

583585
conv_esolver = exx_helper.exx_after_converge(iter);
584586

585587
if (!conv_esolver)
586588
{
587-
std::cout << " Setting Psi for EXX PW Inner Loop" << std::endl;
589+
auto duration = std::chrono::high_resolution_clock::now() - start;
590+
std::cout << " Setting Psi for EXX PW Inner Loop took "
591+
<< std::chrono::duration_cast<std::chrono::milliseconds>(duration).count() / 1000.0
592+
<< "s" << std::endl;
588593
exx_helper.op_exx->first_iter = false;
589594
XC_Functional::set_xc_type(ucell.atoms[0].ncpp.xc_func);
590595
update_pot(ucell, istep, iter, conv_esolver);

source/module_hamilt_pw/hamilt_pwdft/module_exx_helper/exx_helper.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ struct Exx_Helper
1818
Exx_Helper() = default;
1919
OperatorEXX *op_exx = nullptr;
2020

21-
void set_firstiter() { first_iter = true; }
21+
/// Sets first_iter to `flag`; the default of true means a call with no argument stores true.
void set_firstiter(bool flag = true) { first_iter = flag; }
2222
void set_wg(const ModuleBase::matrix *wg_) { wg = wg_; }
2323
void set_psi(psi::Psi<T, Device> *psi_);
2424

@@ -34,7 +34,7 @@ struct Exx_Helper
3434
double cal_exx_energy(psi::Psi<T, Device> *psi_);
3535

3636
private:
37-
bool first_iter;
37+
bool first_iter = false;
3838
psi::Psi<T, Device> *psi = nullptr;
3939
const ModuleBase::matrix *wg = nullptr;
4040

source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp

Lines changed: 117 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,11 @@ OperatorEXXPW<T, Device>::OperatorEXXPW(const int* isk_in,
5353
const UnitCell *ucell)
5454
: isk(isk_in), wfcpw(wfcpw_in), rhopw(rhopw_in), kv(kv_in), ucell(ucell)
5555
{
56-
56+
gamma_extrapolation = PARAM.inp.exx_gamma_extrapolation;
57+
if (!kv_in->get_is_mp())
58+
{
59+
gamma_extrapolation = false;
60+
}
5761
if (GlobalV::KPAR != 1)
5862
{
5963
// GlobalV::ofs_running << "EXX Calculation does not support k-point parallelism" << std::endl;
@@ -269,7 +273,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands,
269273
setmem_complex_op()(psi_mq_real, 0, wfcpw->nrxx);
270274

271275
} // end of iq
272-
auto h_psi_nk = tmhpsi + n_iband * nbasis;
276+
T* h_psi_nk = tmhpsi + n_iband * nbasis;
273277
Real hybrid_alpha = GlobalC::exx_info.info_global.hybrid_alpha;
274278
wfcpw->real_to_recip(ctx, h_psi_real, h_psi_nk, this->ik, true, hybrid_alpha);
275279
setmem_complex_op()(h_psi_real, 0, rhopw->nrxx);
@@ -293,7 +297,7 @@ void OperatorEXXPW<T, Device>::act_op_ace(const int nbands,
293297

294298
// std::cout << "act_op_ace" << std::endl;
295299
// hpsi += -Xi^\dagger * Xi * psi
296-
auto Xi_ace = Xi_ace_k[this->ik];
300+
T* Xi_ace = Xi_ace_k[this->ik];
297301
int nbands_tot = psi.get_nbands();
298302
int nbasis_max = psi.get_nbasis();
299303
// T* hpsi = nullptr;
@@ -344,13 +348,14 @@ void OperatorEXXPW<T, Device>::act_op_ace(const int nbands,
344348
// vec_add_vec_complex_op()(this->ctx, nbands * nbasis, tmhpsi, hpsi, -1, tmhpsi, 1);
345349
// delmem_complex_op()(hpsi);
346350
delmem_complex_op()(Xi_psi);
347-
ModuleBase::timer::tick("OperatorEXXPW", "act_op");
351+
ModuleBase::timer::tick("OperatorEXXPW", "act_op_ace");
348352

349353
}
350354

351355
template <typename T, typename Device>
352356
void OperatorEXXPW<T, Device>::construct_ace() const
353357
{
358+
ModuleBase::timer::tick("OperatorEXXPW", "construct_ace");
354359
// int nkb = p_exx_helper->psi.get_nbands() * p_exx_helper->psi.get_nk();
355360
int nbands = psi.get_nbands();
356361
int nbasis = psi.get_nbasis();
@@ -494,6 +499,7 @@ void OperatorEXXPW<T, Device>::construct_ace() const
494499
}
495500

496501
*ik_ = ik_save;
502+
ModuleBase::timer::tick("OperatorEXXPW", "construct_ace");
497503

498504
}
499505

@@ -539,7 +545,8 @@ void OperatorEXXPW<T, Device>::multiply_potential(T *density_recip, int ik, int
539545
#endif
540546
for (int ig = 0; ig < npw; ig++)
541547
{
542-
density_recip[ig] *= pot[ik * nks * npw + iq * npw + ig];
548+
int ig_kq = ik * nks * npw + iq * npw + ig;
549+
density_recip[ig] *= pot[ig_kq];
543550

544551
}
545552

@@ -551,7 +558,7 @@ const T *OperatorEXXPW<T, Device>::get_pw(const int m, const int iq) const
551558
{
552559
// return pws[iq].get() + m * wfcpw->npwk[iq];
553560
psi.fix_kb(iq, m);
554-
auto psi_mq = psi.get_pointer();
561+
T* psi_mq = psi.get_pointer();
555562
return psi_mq;
556563
}
557564

@@ -580,22 +587,57 @@ OperatorEXXPW<T, Device>::OperatorEXXPW(const OperatorEXXPW<T_in, Device_in> *op
580587
template <typename T, typename Device>
581588
void OperatorEXXPW<T, Device>::get_potential() const
582589
{
590+
Real nqs_half1 = 0.5 * kv->nmp[0];
591+
Real nqs_half2 = 0.5 * kv->nmp[1];
592+
Real nqs_half3 = 0.5 * kv->nmp[2];
593+
583594
int nks = wfcpw->nks, npw = rhopw->npw;
584595
double tpiba2 = tpiba * tpiba;
585596
// calculate the pot
586597
for (int ik = 0; ik < nks; ik++)
587598
{
588599
for (int iq = 0; iq < nks; iq++)
589600
{
590-
auto k = wfcpw->kvec_c[ik];
591-
auto q = wfcpw->kvec_c[iq];
601+
const ModuleBase::Vector3<double> k_c = wfcpw->kvec_c[ik];
602+
const ModuleBase::Vector3<double> k_d = wfcpw->kvec_d[ik];
603+
const ModuleBase::Vector3<double> q_c = wfcpw->kvec_c[iq];
604+
const ModuleBase::Vector3<double> q_d = wfcpw->kvec_d[iq];
592605

593606
#ifdef _OPENMP
594607
#pragma omp parallel for schedule(static)
595608
#endif
596609
for (int ig = 0; ig < rhopw->npw; ig++)
597610
{
598-
Real gg = (k - q + rhopw->gcar[ig]).norm2() * tpiba2;
611+
const ModuleBase::Vector3<double> g_d = rhopw->gdirect[ig];
612+
const ModuleBase::Vector3<double> kqg_d = k_d - q_d + g_d;
613+
// For gamma_extrapolation (https://doi.org/10.1103/PhysRevB.79.205114)
614+
// 7/8 of the points in the grid are "activated" and 1/8 are disabled.
615+
// grid_factor rescales the remaining 7/8 of the points by 8/7 so that they carry the weight of the full grid
616+
Real grid_factor = 1;
617+
double extrapolate_grid = 8.0/7.0;
618+
if (gamma_extrapolation)
619+
{
620+
// if isint(kqg_d[0] * nqs_half1) && isint(kqg_d[1] * nqs_half2) && isint(kqg_d[2] * nqs_half3)
621+
auto isint = [](double x)
622+
{
623+
double epsilon = 1e-6; // this follows the isint judgement in q-e
624+
return std::abs(x - std::round(x)) < epsilon;
625+
};
626+
if (isint(kqg_d[0] * nqs_half1) &&
627+
isint(kqg_d[1] * nqs_half2) &&
628+
isint(kqg_d[2] * nqs_half3))
629+
{
630+
grid_factor = 0;
631+
}
632+
else
633+
{
634+
grid_factor = extrapolate_grid;
635+
}
636+
}
637+
638+
const int ig_kq = ik * nks * npw + iq * npw + ig;
639+
640+
Real gg = (k_c - q_c + rhopw->gcar[ig]).norm2() * tpiba2;
599641
Real hse_omega2 = GlobalC::exx_info.info_global.hse_omega * GlobalC::exx_info.info_global.hse_omega;
600642
// if (kqgcar2 > 1e-12) // vasp uses 1/40 of the smallest (k spacing)**2
601643
if (gg >= 1e-8)
@@ -604,24 +646,29 @@ void OperatorEXXPW<T, Device>::get_potential() const
604646
// if (PARAM.inp.dft_functional == "hse")
605647
if (GlobalC::exx_info.info_global.ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc)
606648
{
607-
pot[ik * nks * npw + iq * npw + ig] = fac * (1.0 - std::exp(-gg / 4.0 / hse_omega2));
649+
pot[ig_kq] = fac * (1.0 - std::exp(-gg / 4.0 / hse_omega2)) * grid_factor;
650+
}
651+
else if (GlobalC::exx_info.info_global.ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erf)
652+
{
653+
pot[ig_kq] = fac * (std::exp(-gg / 4.0 / hse_omega2)) * grid_factor;
608654
}
609655
else
610656
{
611-
pot[ik * nks * npw + iq * npw + ig] = fac;
657+
pot[ig_kq] = fac * grid_factor;
612658
}
613659
}
614660
// }
615661
else
616662
{
617663
// if (PARAM.inp.dft_functional == "hse")
618-
if (GlobalC::exx_info.info_global.ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc)
664+
if (GlobalC::exx_info.info_global.ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc &&
665+
!gamma_extrapolation)
619666
{
620-
pot[ik * nks * npw + iq * npw + ig] = exx_div - ModuleBase::PI * ModuleBase::e2 / hse_omega2;
667+
pot[ig_kq] = exx_div - ModuleBase::PI * ModuleBase::e2 / hse_omega2;
621668
}
622669
else
623670
{
624-
pot[ik * nks * npw + iq * npw + ig] = exx_div;
671+
pot[ig_kq] = exx_div;
625672
}
626673
}
627674
// assert(is_finite(density_recip[ig]));
@@ -638,6 +685,10 @@ void OperatorEXXPW<T, Device>::exx_divergence()
638685
return;
639686
}
640687

688+
Real nqs_half1 = 0.5 * kv->nmp[0];
689+
Real nqs_half2 = 0.5 * kv->nmp[1];
690+
Real nqs_half3 = 0.5 * kv->nmp[2];
691+
641692
// here we follow the exx_divergence subroutine in q-e (PW/src/exx_base.f90)
642693
double alpha = 10.0 / wfcpw->gk_ecut;
643694
double tpiba2 = tpiba * tpiba;
@@ -647,25 +698,51 @@ void OperatorEXXPW<T, Device>::exx_divergence()
647698
// temporarily loops over all k points; should be replaced by q points later
648699
for (int ik = 0; ik < wfcpw->nks; ik++)
649700
{
650-
auto k = wfcpw->kvec_c[ik];
701+
const ModuleBase::Vector3<double> k_c = wfcpw->kvec_c[ik];
702+
const ModuleBase::Vector3<double> k_d = wfcpw->kvec_d[ik];
651703
#ifdef _OPENMP
652704
#pragma omp parallel for reduction(+:div)
653705
#endif
654706
for (int ig = 0; ig < rhopw->npw; ig++)
655707
{
656-
auto q = k + rhopw->gcar[ig];
657-
double qq = q.norm2();
708+
const ModuleBase::Vector3<double> q_c = k_c + rhopw->gcar[ig];
709+
const ModuleBase::Vector3<double> q_d = k_d + rhopw->gdirect[ig];
710+
double qq = q_c.norm2();
711+
// For gamma_extrapolation (https://doi.org/10.1103/PhysRevB.79.205114)
712+
// 7/8 of the points in the grid are "activated" and 1/8 are disabled.
713+
// grid_factor rescales the remaining 7/8 of the points by 8/7 so that they carry the weight of the full grid
714+
Real grid_factor = 1;
715+
double extrapolate_grid = 8.0/7.0;
716+
if (gamma_extrapolation)
717+
{
718+
auto isint = [](double x)
719+
{
720+
double epsilon = 1e-6; // this follows the isint judgement in q-e
721+
return std::abs(x - std::round(x)) < epsilon;
722+
};
723+
if (isint(q_d[0] * nqs_half1) &&
724+
isint(q_d[1] * nqs_half2) &&
725+
isint(q_d[2] * nqs_half3))
726+
{
727+
grid_factor = 0;
728+
}
729+
else
730+
{
731+
grid_factor = extrapolate_grid;
732+
}
733+
}
734+
658735
if (qq <= 1e-8) continue;
659736
// else if (PARAM.inp.dft_functional == "hse")
660737
else if (GlobalC::exx_info.info_global.ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc)
661738
{
662739
double omega = GlobalC::exx_info.info_global.hse_omega;
663740
double omega2 = omega * omega;
664-
div += std::exp(-alpha * qq) / qq * (1.0 - std::exp(-qq*tpiba2 / 4.0 / omega2));
741+
div += std::exp(-alpha * qq) / qq * (1.0 - std::exp(-qq*tpiba2 / 4.0 / omega2)) * grid_factor;
665742
}
666743
else
667744
{
668-
div += std::exp(-alpha * qq) / qq;
745+
div += std::exp(-alpha * qq) / qq * grid_factor;
669746
}
670747
}
671748
}
@@ -674,14 +751,18 @@ void OperatorEXXPW<T, Device>::exx_divergence()
674751
// std::cout << "EXX div: " << div << std::endl;
675752

676753
// if (PARAM.inp.dft_functional == "hse")
677-
if (GlobalC::exx_info.info_global.ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc)
678-
{
679-
double omega = GlobalC::exx_info.info_global.hse_omega;
680-
div += tpiba2 / 4.0 / omega / omega; // compensate for the finite value when qq = 0
681-
}
682-
else
754+
if (!gamma_extrapolation)
683755
{
684-
div -= alpha;
756+
if (GlobalC::exx_info.info_global.ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc)
757+
{
758+
double omega = GlobalC::exx_info.info_global.hse_omega;
759+
div += tpiba2 / 4.0 / omega / omega; // compensate for the finite value when qq = 0
760+
}
761+
else
762+
{
763+
div -= alpha;
764+
}
765+
685766
}
686767

687768
div *= ModuleBase::e2 * ModuleBase::FOUR_PI / tpiba2 / wfcpw->nks;
@@ -716,7 +797,6 @@ void OperatorEXXPW<T, Device>::exx_divergence()
716797
// std::cout << "EXX divergence: " << exx_div << std::endl;
717798

718799
return;
719-
720800
}
721801

722802
template <typename T, typename Device>
@@ -745,14 +825,14 @@ double OperatorEXXPW<T, Device>::cal_exx_energy_ace(psi::Psi<T, Device> *ppsi_)
745825
setmem_complex_op()(h_psi_ace, 0, psi_.get_nbands() * psi_.get_nbasis());
746826
*ik_ = i;
747827
psi_.fix_kb(i, 0);
748-
auto psi_i = psi_.get_pointer();
828+
T* psi_i = psi_.get_pointer();
749829
act_op_ace(psi_.get_nbands(), psi_.get_nbasis(), 1, psi_i, h_psi_ace, 0, true);
750830

751831
for (int nband = 0; nband < psi_.get_nbands(); nband++)
752832
{
753833
psi_.fix_kb(i, nband);
754-
auto psi_i_n = psi_.get_pointer();
755-
auto hpsi_i_n = h_psi_ace + nband * psi_.get_nbasis();
834+
T* psi_i_n = psi_.get_pointer();
835+
T* hpsi_i_n = h_psi_ace + nband * psi_.get_nbasis();
756836
double wg_i_n = (*wg)(i, nband);
757837
// Eexx += dot(psi_i_n, h_psi_i_n)
758838
Eexx += dot_op()(psi_.get_nbasis(), psi_i_n, hpsi_i_n, false) * wg_i_n * 2;
@@ -881,6 +961,13 @@ double OperatorEXXPW<T, Device>::cal_exx_energy_op(psi::Psi<T, Device> *ppsi_) c
881961
Parallel_Reduce::reduce_pool(Eexx_ik_real);
882962
// std::cout << "omega = " << this_->pelec->omega << " tpiba = " << this_->pw_rho->tpiba2 << " exx_div = " << exx_div << std::endl;
883963

964+
delete[] psi_nk_real;
965+
delete[] psi_mq_real;
966+
delete[] h_psi_recip;
967+
delete[] h_psi_real;
968+
delete[] density_real;
969+
delete[] density_recip;
970+
884971
double Eexx = Eexx_ik_real;
885972
return Eexx;
886973
}

source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ class OperatorEXXPW : public OperatorPW<T, Device>
143143
using gemm_complex_op = ModuleBase::gemm_op<T, Device>;
144144
using vec_add_vec_complex_op = ModuleBase::vector_add_vector_op<T, Device>;
145145
using dot_op = ModuleBase::dot_real_op<T, Device>;
146+
147+
// Enables the gamma-extrapolation grid_factor weighting in get_potential() and exx_divergence().
// The constructor overwrites this default from PARAM.inp.exx_gamma_extrapolation and forces it
// to false when kv->get_is_mp() returns false.
bool gamma_extrapolation = true;
148+
146149
};
147150

148151
} // namespace hamilt

0 commit comments

Comments
 (0)