@@ -53,7 +53,11 @@ OperatorEXXPW<T, Device>::OperatorEXXPW(const int* isk_in,
5353 const UnitCell *ucell)
5454 : isk(isk_in), wfcpw(wfcpw_in), rhopw(rhopw_in), kv(kv_in), ucell(ucell)
5555{
56-
56+ gamma_extrapolation = PARAM.inp .exx_gamma_extrapolation ;
57+ if (!kv_in->get_is_mp ())
58+ {
59+ gamma_extrapolation = false ;
60+ }
5761 if (GlobalV::KPAR != 1 )
5862 {
5963 // GlobalV::ofs_running << "EXX Calculation does not support k-point parallelism" << std::endl;
@@ -269,7 +273,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands,
269273 setmem_complex_op ()(psi_mq_real, 0 , wfcpw->nrxx );
270274
271275 } // end of iq
272- auto h_psi_nk = tmhpsi + n_iband * nbasis;
276+ T* h_psi_nk = tmhpsi + n_iband * nbasis;
273277 Real hybrid_alpha = GlobalC::exx_info.info_global .hybrid_alpha ;
274278 wfcpw->real_to_recip (ctx, h_psi_real, h_psi_nk, this ->ik , true , hybrid_alpha);
275279 setmem_complex_op ()(h_psi_real, 0 , rhopw->nrxx );
@@ -293,7 +297,7 @@ void OperatorEXXPW<T, Device>::act_op_ace(const int nbands,
293297
294298// std::cout << "act_op_ace" << std::endl;
295299 // hpsi += -Xi^\dagger * Xi * psi
296- auto Xi_ace = Xi_ace_k[this ->ik ];
300+ T* Xi_ace = Xi_ace_k[this ->ik ];
297301 int nbands_tot = psi.get_nbands ();
298302 int nbasis_max = psi.get_nbasis ();
299303// T* hpsi = nullptr;
@@ -344,13 +348,14 @@ void OperatorEXXPW<T, Device>::act_op_ace(const int nbands,
344348// vec_add_vec_complex_op()(this->ctx, nbands * nbasis, tmhpsi, hpsi, -1, tmhpsi, 1);
345349// delmem_complex_op()(hpsi);
346350 delmem_complex_op ()(Xi_psi);
347- ModuleBase::timer::tick (" OperatorEXXPW" , " act_op " );
351+ ModuleBase::timer::tick (" OperatorEXXPW" , " act_op_ace " );
348352
349353}
350354
351355template <typename T, typename Device>
352356void OperatorEXXPW<T, Device>::construct_ace() const
353357{
358+ ModuleBase::timer::tick (" OperatorEXXPW" , " construct_ace" );
354359// int nkb = p_exx_helper->psi.get_nbands() * p_exx_helper->psi.get_nk();
355360 int nbands = psi.get_nbands ();
356361 int nbasis = psi.get_nbasis ();
@@ -494,6 +499,7 @@ void OperatorEXXPW<T, Device>::construct_ace() const
494499 }
495500
496501 *ik_ = ik_save;
502+ ModuleBase::timer::tick (" OperatorEXXPW" , " construct_ace" );
497503
498504}
499505
@@ -539,7 +545,8 @@ void OperatorEXXPW<T, Device>::multiply_potential(T *density_recip, int ik, int
539545 #endif
540546 for (int ig = 0 ; ig < npw; ig++)
541547 {
542- density_recip[ig] *= pot[ik * nks * npw + iq * npw + ig];
548+ int ig_kq = ik * nks * npw + iq * npw + ig;
549+ density_recip[ig] *= pot[ig_kq];
543550
544551 }
545552
@@ -551,7 +558,7 @@ const T *OperatorEXXPW<T, Device>::get_pw(const int m, const int iq) const
551558{
552559 // return pws[iq].get() + m * wfcpw->npwk[iq];
553560 psi.fix_kb (iq, m);
554- auto psi_mq = psi.get_pointer ();
561+ T* psi_mq = psi.get_pointer ();
555562 return psi_mq;
556563}
557564
@@ -580,22 +587,57 @@ OperatorEXXPW<T, Device>::OperatorEXXPW(const OperatorEXXPW<T_in, Device_in> *op
580587template <typename T, typename Device>
581588void OperatorEXXPW<T, Device>::get_potential() const
582589{
590+ Real nqs_half1 = 0.5 * kv->nmp [0 ];
591+ Real nqs_half2 = 0.5 * kv->nmp [1 ];
592+ Real nqs_half3 = 0.5 * kv->nmp [2 ];
593+
583594 int nks = wfcpw->nks , npw = rhopw->npw ;
584595 double tpiba2 = tpiba * tpiba;
585596 // calculate the pot
586597 for (int ik = 0 ; ik < nks; ik++)
587598 {
588599 for (int iq = 0 ; iq < nks; iq++)
589600 {
590- auto k = wfcpw->kvec_c [ik];
591- auto q = wfcpw->kvec_c [iq];
601+ const ModuleBase::Vector3<double > k_c = wfcpw->kvec_c [ik];
602+ const ModuleBase::Vector3<double > k_d = wfcpw->kvec_d [ik];
603+ const ModuleBase::Vector3<double > q_c = wfcpw->kvec_c [iq];
604+ const ModuleBase::Vector3<double > q_d = wfcpw->kvec_d [iq];
592605
593606 #ifdef _OPENMP
594607 #pragma omp parallel for schedule(static)
595608 #endif
596609 for (int ig = 0 ; ig < rhopw->npw ; ig++)
597610 {
598- Real gg = (k - q + rhopw->gcar [ig]).norm2 () * tpiba2;
611+ const ModuleBase::Vector3<double > g_d = rhopw->gdirect [ig];
612+ const ModuleBase::Vector3<double > kqg_d = k_d - q_d + g_d;
613+ // For gamma_extrapolation (https://doi.org/10.1103/PhysRevB.79.205114)
614+ // 7/8 of the points in the grid are "activated" and 1/8 are disabled.
615+ // grid_factor is designed for the 7/8 of the grid to function like all of the points
616+ Real grid_factor = 1 ;
617+ double extrapolate_grid = 8.0 /7.0 ;
618+ if (gamma_extrapolation)
619+ {
620+ // if isint(kqg_d[0] * nqs_half1) && isint(kqg_d[1] * nqs_half2) && isint(kqg_d[2] * nqs_half3)
621+ auto isint = [](double x)
622+ {
623+ double epsilon = 1e-6 ; // this follows the isint judgement in q-e
624+ return std::abs (x - std::round (x)) < epsilon;
625+ };
626+ if (isint (kqg_d[0 ] * nqs_half1) &&
627+ isint (kqg_d[1 ] * nqs_half2) &&
628+ isint (kqg_d[2 ] * nqs_half3))
629+ {
630+ grid_factor = 0 ;
631+ }
632+ else
633+ {
634+ grid_factor = extrapolate_grid;
635+ }
636+ }
637+
638+ const int ig_kq = ik * nks * npw + iq * npw + ig;
639+
640+ Real gg = (k_c - q_c + rhopw->gcar [ig]).norm2 () * tpiba2;
599641 Real hse_omega2 = GlobalC::exx_info.info_global .hse_omega * GlobalC::exx_info.info_global .hse_omega ;
600642 // if (kqgcar2 > 1e-12) // vasp uses 1/40 of the smallest (k spacing)**2
601643 if (gg >= 1e-8 )
@@ -604,24 +646,29 @@ void OperatorEXXPW<T, Device>::get_potential() const
604646 // if (PARAM.inp.dft_functional == "hse")
605647 if (GlobalC::exx_info.info_global .ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc)
606648 {
607- pot[ik * nks * npw + iq * npw + ig] = fac * (1.0 - std::exp (-gg / 4.0 / hse_omega2));
649+ pot[ig_kq] = fac * (1.0 - std::exp (-gg / 4.0 / hse_omega2)) * grid_factor;
650+ }
651+ else if (GlobalC::exx_info.info_global .ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erf)
652+ {
653+ pot[ig_kq] = fac * (std::exp (-gg / 4.0 / hse_omega2)) * grid_factor;
608654 }
609655 else
610656 {
611- pot[ik * nks * npw + iq * npw + ig ] = fac;
657+ pot[ig_kq ] = fac * grid_factor ;
612658 }
613659 }
614660 // }
615661 else
616662 {
617663 // if (PARAM.inp.dft_functional == "hse")
618- if (GlobalC::exx_info.info_global .ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc)
664+ if (GlobalC::exx_info.info_global .ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc &&
665+ !gamma_extrapolation)
619666 {
620- pot[ik * nks * npw + iq * npw + ig ] = exx_div - ModuleBase::PI * ModuleBase::e2 / hse_omega2;
667+ pot[ig_kq ] = exx_div - ModuleBase::PI * ModuleBase::e2 / hse_omega2;
621668 }
622669 else
623670 {
624- pot[ik * nks * npw + iq * npw + ig ] = exx_div;
671+ pot[ig_kq ] = exx_div;
625672 }
626673 }
627674 // assert(is_finite(density_recip[ig]));
@@ -638,6 +685,10 @@ void OperatorEXXPW<T, Device>::exx_divergence()
638685 return ;
639686 }
640687
688+ Real nqs_half1 = 0.5 * kv->nmp [0 ];
689+ Real nqs_half2 = 0.5 * kv->nmp [1 ];
690+ Real nqs_half3 = 0.5 * kv->nmp [2 ];
691+
641692 // here we follow the exx_divergence subroutine in q-e (PW/src/exx_base.f90)
642693 double alpha = 10.0 / wfcpw->gk_ecut ;
643694 double tpiba2 = tpiba * tpiba;
@@ -647,25 +698,51 @@ void OperatorEXXPW<T, Device>::exx_divergence()
647698 // temporarily for all k points, should be replaced to q points later
648699 for (int ik = 0 ; ik < wfcpw->nks ; ik++)
649700 {
650- auto k = wfcpw->kvec_c [ik];
701+ const ModuleBase::Vector3<double > k_c = wfcpw->kvec_c [ik];
702+ const ModuleBase::Vector3<double > k_d = wfcpw->kvec_d [ik];
651703#ifdef _OPENMP
652704#pragma omp parallel for reduction(+:div)
653705#endif
654706 for (int ig = 0 ; ig < rhopw->npw ; ig++)
655707 {
656- auto q = k + rhopw->gcar [ig];
657- double qq = q.norm2 ();
708+ const ModuleBase::Vector3<double > q_c = k_c + rhopw->gcar [ig];
709+ const ModuleBase::Vector3<double > q_d = k_d + rhopw->gdirect [ig];
710+ double qq = q_c.norm2 ();
711+ // For gamma_extrapolation (https://doi.org/10.1103/PhysRevB.79.205114)
712+ // 7/8 of the points in the grid are "activated" and 1/8 are disabled.
713+ // grid_factor is designed for the 7/8 of the grid to function like all of the points
714+ Real grid_factor = 1 ;
715+ double extrapolate_grid = 8.0 /7.0 ;
716+ if (gamma_extrapolation)
717+ {
718+ auto isint = [](double x)
719+ {
720+ double epsilon = 1e-6 ; // this follows the isint judgement in q-e
721+ return std::abs (x - std::round (x)) < epsilon;
722+ };
723+ if (isint (q_d[0 ] * nqs_half1) &&
724+ isint (q_d[1 ] * nqs_half2) &&
725+ isint (q_d[2 ] * nqs_half3))
726+ {
727+ grid_factor = 0 ;
728+ }
729+ else
730+ {
731+ grid_factor = extrapolate_grid;
732+ }
733+ }
734+
658735 if (qq <= 1e-8 ) continue ;
659736 // else if (PARAM.inp.dft_functional == "hse")
660737 else if (GlobalC::exx_info.info_global .ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc)
661738 {
662739 double omega = GlobalC::exx_info.info_global .hse_omega ;
663740 double omega2 = omega * omega;
664- div += std::exp (-alpha * qq) / qq * (1.0 - std::exp (-qq*tpiba2 / 4.0 / omega2));
741+ div += std::exp (-alpha * qq) / qq * (1.0 - std::exp (-qq*tpiba2 / 4.0 / omega2)) * grid_factor ;
665742 }
666743 else
667744 {
668- div += std::exp (-alpha * qq) / qq;
745+ div += std::exp (-alpha * qq) / qq * grid_factor ;
669746 }
670747 }
671748 }
@@ -674,14 +751,18 @@ void OperatorEXXPW<T, Device>::exx_divergence()
674751 // std::cout << "EXX div: " << div << std::endl;
675752
676753 // if (PARAM.inp.dft_functional == "hse")
677- if (GlobalC::exx_info.info_global .ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc)
678- {
679- double omega = GlobalC::exx_info.info_global .hse_omega ;
680- div += tpiba2 / 4.0 / omega / omega; // compensate for the finite value when qq = 0
681- }
682- else
754+ if (!gamma_extrapolation)
683755 {
684- div -= alpha;
756+ if (GlobalC::exx_info.info_global .ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc)
757+ {
758+ double omega = GlobalC::exx_info.info_global .hse_omega ;
759+ div += tpiba2 / 4.0 / omega / omega; // compensate for the finite value when qq = 0
760+ }
761+ else
762+ {
763+ div -= alpha;
764+ }
765+
685766 }
686767
687768 div *= ModuleBase::e2 * ModuleBase::FOUR_PI / tpiba2 / wfcpw->nks ;
@@ -716,7 +797,6 @@ void OperatorEXXPW<T, Device>::exx_divergence()
716797 // std::cout << "EXX divergence: " << exx_div << std::endl;
717798
718799 return ;
719-
720800}
721801
722802template <typename T, typename Device>
@@ -745,14 +825,14 @@ double OperatorEXXPW<T, Device>::cal_exx_energy_ace(psi::Psi<T, Device> *ppsi_)
745825 setmem_complex_op ()(h_psi_ace, 0 , psi_.get_nbands () * psi_.get_nbasis ());
746826 *ik_ = i;
747827 psi_.fix_kb (i, 0 );
748- auto psi_i = psi_.get_pointer ();
828+ T* psi_i = psi_.get_pointer ();
749829 act_op_ace (psi_.get_nbands (), psi_.get_nbasis (), 1 , psi_i, h_psi_ace, 0 , true );
750830
751831 for (int nband = 0 ; nband < psi_.get_nbands (); nband++)
752832 {
753833 psi_.fix_kb (i, nband);
754- auto psi_i_n = psi_.get_pointer ();
755- auto hpsi_i_n = h_psi_ace + nband * psi_.get_nbasis ();
834+ T* psi_i_n = psi_.get_pointer ();
835+ T* hpsi_i_n = h_psi_ace + nband * psi_.get_nbasis ();
756836 double wg_i_n = (*wg)(i, nband);
757837 // Eexx += dot(psi_i_n, h_psi_i_n)
758838 Eexx += dot_op ()(psi_.get_nbasis (), psi_i_n, hpsi_i_n, false ) * wg_i_n * 2 ;
@@ -881,6 +961,13 @@ double OperatorEXXPW<T, Device>::cal_exx_energy_op(psi::Psi<T, Device> *ppsi_) c
881961 Parallel_Reduce::reduce_pool (Eexx_ik_real);
882962 // std::cout << "omega = " << this_->pelec->omega << " tpiba = " << this_->pw_rho->tpiba2 << " exx_div = " << exx_div << std::endl;
883963
964+ delete[] psi_nk_real;
965+ delete[] psi_mq_real;
966+ delete[] h_psi_recip;
967+ delete[] h_psi_real;
968+ delete[] density_real;
969+ delete[] density_recip;
970+
884971 double Eexx = Eexx_ik_real;
885972 return Eexx;
886973}
0 commit comments