Skip to content

Commit 155f47a

Browse files
Reconstruct: move exx_helper to hamilt_pwdft
1 parent 36b7f71 commit 155f47a

File tree

9 files changed

+115
-116
lines changed

9 files changed

+115
-116
lines changed

source/module_esolver/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ add_library(
2727
esolver
2828
OBJECT
2929
${objects}
30-
module_exx_helper/exx_helper.cpp
30+
../module_hamilt_pw/hamilt_pwdft/module_exx_helper/exx_helper.cpp
31+
../module_hamilt_pw/hamilt_pwdft/module_exx_helper/exx_helper.h
3132
)
3233

3334
if(ENABLE_COVERAGE)

source/module_esolver/esolver_ks.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -661,22 +661,6 @@ void ESolver_KS<T, Device>::iter_finish(UnitCell& ucell, const int istep, int& i
661661
this->pelec->cal_energies(1);
662662
this->pelec->cal_energies(2);
663663

664-
// // for separate loop in hybrid functionals in pw
665-
// // this gives correct energy for the first "pure" functional loop
666-
// // and update the functional for the next loop
667-
// if (PARAM.inp.basis_type == "pw")
668-
// {
669-
// auto p_esolver_ks_pw = dynamic_cast<ESolver_KS_PW<T, Device>*>(this);
670-
//
671-
// if (conv_esolver && GlobalC::exx_info.info_global.cal_exx && GlobalC::exx_info.info_global.separate_loop && p_esolver_ks_pw->exx_helper.first_iter)
672-
// {
673-
// conv_esolver = false;
674-
// XC_Functional::set_xc_type(ucell.atoms[0].ncpp.xc_func);
675-
// this->update_pot(ucell, istep, iter, conv_esolver);
676-
// conv_esolver = true;
677-
// }
678-
// }
679-
680664
if (iter == 1)
681665
{
682666
this->pelec->f_en.etot_old = this->pelec->f_en.etot;

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -513,18 +513,6 @@ void ESolver_KS_PW<T, Device>::hamilt2density_single(UnitCell& ucell,
513513
srho.begin(is, this->chr, this->pw_rhod, ucell.symm);
514514
}
515515

516-
#ifdef __EXX
517-
if (GlobalC::exx_info.info_global.cal_exx && !exx_helper.first_iter)
518-
{
519-
this->pelec->set_exx(exx_helper.cal_exx_energy(this->kspw_psi[0], this));
520-
}
521-
#endif
522-
523-
// deband is calculated from "output" charge density calculated
524-
// in sum_band
525-
// need 'rho(out)' and 'vr (v_h(in) and v_xc(in))'
526-
this->pelec->f_en.deband = this->pelec->cal_delta_eband(ucell);
527-
528516
ModuleBase::timer::tick("ESolver_KS_PW", "hamilt2density_single");
529517
}
530518

@@ -550,6 +538,18 @@ void ESolver_KS_PW<T, Device>::update_pot(UnitCell& ucell, const int istep, cons
550538
template <typename T, typename Device>
551539
void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int& iter, bool& conv_esolver)
552540
{
541+
#ifdef __EXX
542+
if (GlobalC::exx_info.info_global.cal_exx && !exx_helper.first_iter)
543+
{
544+
this->pelec->set_exx(exx_helper.cal_exx_energy(this->ctx, this->kspw_psi[0], this->pw_wfc, this->pw_rho, &ucell, &this->kv));
545+
}
546+
#endif
547+
548+
// deband is calculated from "output" charge density calculated
549+
// in sum_band
550+
// need 'rho(out)' and 'vr (v_h(in) and v_xc(in))'
551+
this->pelec->f_en.deband = this->pelec->cal_delta_eband(ucell);
552+
553553
// 1) Call iter_finish() of ESolver_KS
554554
ESolver_KS<T, Device>::iter_finish(ucell, istep, iter, conv_esolver);
555555

source/module_esolver/esolver_ks_pw.h

Lines changed: 2 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include "./esolver_ks.h"
44
#include "module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.h"
55
#include "module_psi/psi_init.h"
6-
6+
#include "module_hamilt_pw/hamilt_pwdft/module_exx_helper/exx_helper.h"
77
#include "module_hamilt_pw/hamilt_pwdft/global.h"
88

99
#include <memory>
@@ -33,58 +33,10 @@ class ESolver_KS_PW : public ESolver_KS<T, Device>
3333

3434
void after_all_runners(UnitCell& ucell) override;
3535

36-
#ifdef __EXX
37-
struct Exx_Helper
38-
{
39-
public:
40-
Exx_Helper() = default;
41-
ModuleBase::matrix * wf_wg;
42-
psi::Psi<T, base_device::DEVICE_CPU> psi;
43-
static constexpr double DIV_UNDEFINED = 0x0d000721;
44-
double div = DIV_UNDEFINED;
45-
bool construct_ace = false;
46-
47-
bool exx_after_converge(int &iter)
48-
{
49-
if (first_iter)
50-
{
51-
first_iter = false;
52-
}
53-
else if (!GlobalC::exx_info.info_global.separate_loop)
54-
{
55-
return true;
56-
}
57-
else if (iter == 1)
58-
{
59-
return true;
60-
}
61-
GlobalV::ofs_running << "Updating EXX and rerun SCF" << std::endl;
62-
iter = 0;
63-
return false;
64-
65-
}
66-
67-
void set_psi(psi::Psi<T, Device> &psi_)
68-
{
69-
this->psi = psi_;
70-
construct_ace = true;
71-
}
72-
73-
void reset_div()
74-
{
75-
this->div = DIV_UNDEFINED;
76-
}
77-
78-
double cal_exx_energy(psi::Psi<T, Device> &psi, ESolver_KS_PW<T, Device> *this_);
79-
80-
bool first_iter = false;
81-
};
82-
#endif
83-
8436
// EXX Todo: verify current implementation for after_converge
8537
// virtual bool do_after_converge(int &iter) override;
8638
#ifdef __EXX
87-
Exx_Helper exx_helper;
39+
Exx_Helper<T, Device> exx_helper;
8840
#endif
8941

9042

source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ void HamiltPW<T, Device>::sPsi(const T* psi_in, // psi
397397
}
398398

399399
template<typename T, typename Device>
400-
void HamiltPW<T, Device>::set_exx_helper(typename ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper *exx_helper)
400+
void HamiltPW<T, Device>::set_exx_helper(Exx_Helper<T, Device> *exx_helper)
401401
{
402402
auto op = this->ops;
403403
while (op != nullptr)

source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "module_hamilt_general/hamilt.h"
99
#include "module_hamilt_pw/hamilt_pwdft/VNL_in_pw.h"
1010
#include "module_base/kernels/math_kernel_op.h"
11+
#include "module_hamilt_pw/hamilt_pwdft/module_exx_helper/exx_helper.h"
1112

1213
namespace hamilt
1314
{
@@ -37,8 +38,8 @@ class HamiltPW : public Hamilt<T, Device>
3738
) const override;
3839

3940
#ifdef __EXX
40-
void set_exx_helper(typename ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper* exx_helper_in);
41-
typename ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper* p_exx_helper;
41+
void set_exx_helper(Exx_Helper<T, Device>* exx_helper_in);
42+
Exx_Helper<T, Device>* p_exx_helper;
4243
#endif
4344

4445
protected:

source/module_esolver/module_exx_helper/exx_helper.cpp renamed to source/module_hamilt_pw/hamilt_pwdft/module_exx_helper/exx_helper.cpp

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,24 @@
1-
#include "module_esolver/esolver_ks_pw.h"
1+
#include "exx_helper.h"
22

33
template <typename T, typename Device>
4-
double ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper::cal_exx_energy(psi::Psi<T, Device>& psi, ESolver_KS_PW<T, Device>* this_)
4+
double Exx_Helper<T, Device>::cal_exx_energy(const Device *ctx, psi::Psi<T, Device>& psi, ModulePW::PW_Basis_K* pw_wfc, ModulePW::PW_Basis* pw_rho, UnitCell* ucell, K_Vectors *kv)
55
{
66
ModuleBase::timer::tick("ESolver_KS_PW", "cal_exx_energy");
77

88
using setmem_complex_op = base_device::memory::set_memory_op<T, Device>;
99
using delmem_complex_op = base_device::memory::delete_memory_op<T, Device>;
10-
T* psi_nk_real = new T[this_->pw_wfc->nrxx];
11-
T* psi_mq_real = new T[this_->pw_wfc->nrxx];
12-
T* h_psi_recip = new T[this_->pw_wfc->npwk_max];
13-
T* h_psi_real = new T[this_->pw_wfc->nrxx];
14-
T* density_real = new T[this_->pw_wfc->nrxx];
15-
auto rhopw = this_->pw_rho;
10+
T* psi_nk_real = new T[pw_wfc->nrxx];
11+
T* psi_mq_real = new T[pw_wfc->nrxx];
12+
T* h_psi_recip = new T[pw_wfc->npwk_max];
13+
T* h_psi_real = new T[pw_wfc->nrxx];
14+
T* density_real = new T[pw_wfc->nrxx];
15+
auto rhopw = pw_rho;
1616
T* density_recip = new T[rhopw->npw];
17-
auto *kv = &this_->kv;
1817

1918
// lambda
2019
auto exx_divergence = [&]() -> double
2120
{
22-
auto wfcpw = this_->pw_wfc;
21+
auto wfcpw = pw_wfc;
2322
// if (GlobalC::exx_info.info_lip.lambda == 0.0)
2423
// {
2524
// return 0;
@@ -28,7 +27,7 @@ double ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper::cal_exx_energy(psi::
2827
// here we follow the exx_divergence subroutine in q-e (PW/src/exx_base.f90)
2928
// double alpha = GlobalC::exx_info.info_lip.lambda;
3029
double alpha = 10.0 / wfcpw->gk_ecut;
31-
double tpiba2 = this_->pw_rhod->tpiba2;
30+
double tpiba2 = ucell->tpiba2;
3231
double div = 0;
3332

3433
// this is the \sum_q F(q) part
@@ -96,7 +95,7 @@ double ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper::cal_exx_energy(psi::
9695
aa *= 8 / ModuleBase::FOUR_PI;
9796
aa += 1.0 / std::sqrt(alpha * ModuleBase::PI);
9897

99-
double omega = this_->pelec->omega;
98+
double omega = ucell->omega;
10099
div -= ModuleBase::e2 * omega * aa;
101100
return div * wfcpw->nks;
102101

@@ -111,14 +110,14 @@ double ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper::cal_exx_energy(psi::
111110
if (wf_wg == nullptr) return 0.0;
112111
// evaluate the Eexx
113112
// T Eexx_ik = 0.0;
114-
Real Eexx_ik_real = 0.0;
115-
for (int ik = 0; ik < this_->pw_wfc->nks; ik++)
113+
double Eexx_ik_real = 0.0;
114+
for (int ik = 0; ik < pw_wfc->nks; ik++)
116115
{
117116
// auto k = this->pw_wfc->kvec_c[ik];
118117
// std::cout << k << std::endl;
119118
for (int n_iband = 0; n_iband < psi.get_nbands(); n_iband++)
120119
{
121-
setmem_complex_op()(h_psi_recip, 0, this_->pw_wfc->npwk_max);
120+
setmem_complex_op()(h_psi_recip, 0, pw_wfc->npwk_max);
122121
setmem_complex_op()(h_psi_real, 0, rhopw->nrxx);
123122
setmem_complex_op()(density_real, 0, rhopw->nrxx);
124123
setmem_complex_op()(density_recip, 0, rhopw->npw);
@@ -137,16 +136,16 @@ double ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper::cal_exx_energy(psi::
137136
psi.fix_kb(ik, n_iband);
138137
const T* psi_nk = psi.get_pointer();
139138
// retrieve \psi_nk in real space
140-
this_->pw_wfc->recip_to_real(this_->ctx, psi_nk, psi_nk_real, ik);
139+
pw_wfc->recip_to_real(ctx, psi_nk, psi_nk_real, ik);
141140

142141
// for \psi_nk, get the pw of iq and band m
143142
// q_points is a vector of integers, 0 to nks-1
144143
std::vector<int> q_points;
145-
for (int iq = 0; iq < this_->pw_wfc->nks; iq++)
144+
for (int iq = 0; iq < pw_wfc->nks; iq++)
146145
{
147146
q_points.push_back(iq);
148147
}
149-
Real nqs = q_points.size();
148+
double nqs = q_points.size();
150149

151150
// std::cout << "ik = " << ik << " ib = " << n_iband << " wg_kb = " << wg_ikb_real << " wk_ik = " << kv->wk[ik] << std::endl;
152151
for (int iq: q_points)
@@ -168,15 +167,15 @@ double ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper::cal_exx_energy(psi::
168167
psi.fix_kb(iq, m_iband);
169168
const T* psi_mq = psi.get_pointer();
170169
// const T* psi_mq = get_pw(m_iband, iq);
171-
this_->pw_wfc->recip_to_real(this_->ctx, psi_mq, psi_mq_real, iq);
170+
pw_wfc->recip_to_real(ctx, psi_mq, psi_mq_real, iq);
172171

173-
Real omega_inv = 1.0 / this_->pelec->omega;
172+
T omega_inv = 1.0 / ucell->omega;
174173

175174
// direct multiplication in real space, \psi_nk(r) * \psi_mq(r)
176175
#ifdef _OPENMP
177176
#pragma omp parallel for
178177
#endif
179-
for (int ir = 0; ir < this_->pw_wfc->nrxx; ir++)
178+
for (int ir = 0; ir < pw_wfc->nrxx; ir++)
180179
{
181180
// assert(is_finite(psi_nk_real[ir]));
182181
// assert(is_finite(psi_mq_real[ir]));
@@ -187,17 +186,17 @@ double ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper::cal_exx_energy(psi::
187186
// bring the density to recip space
188187
rhopw->real2recip(density_real, density_recip);
189188

190-
Real tpiba2 = this_->pw_rho->tpiba2;
189+
double tpiba2 = pw_rho->tpiba2;
191190
// std::cout << tpiba2 << std::endl;
192-
Real hse_omega2 = GlobalC::exx_info.info_global.hse_omega * GlobalC::exx_info.info_global.hse_omega;
191+
double hse_omega2 = GlobalC::exx_info.info_global.hse_omega * GlobalC::exx_info.info_global.hse_omega;
193192

194193
#ifdef _OPENMP
195194
#pragma omp parallel for reduction(+:Eexx_ik_real) reduction(min:min_gg) reduction(max:max_gg)
196195
#endif
197196
for (int ig = 0; ig < rhopw->npw; ig++)
198197
{
199-
auto k = this_->pw_wfc->kvec_c[ik];// * latvec;
200-
auto q = this_->pw_wfc->kvec_c[iq];// * latvec;
198+
auto k = pw_wfc->kvec_c[ik];// * latvec;
199+
auto q = pw_wfc->kvec_c[iq];// * latvec;
201200
auto gcar = rhopw->gcar[ig];
202201
double gg = (k - q + gcar).norm2() * tpiba2;
203202

@@ -236,18 +235,18 @@ double ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper::cal_exx_energy(psi::
236235
} // n_iband
237236

238237
} // ik
239-
Eexx_ik_real *= 0.5 * this_->pelec->omega;
238+
Eexx_ik_real *= 0.5 * ucell->omega;
240239
Parallel_Reduce::reduce_pool(Eexx_ik_real);
241240
// std::cout << "omega = " << this_->pelec->omega << " tpiba = " << this_->pw_rho->tpiba2 << " exx_div = " << exx_div << std::endl;
242241

243-
Real Eexx = Eexx_ik_real;
242+
double Eexx = Eexx_ik_real;
244243
ModuleBase::timer::tick("ESolver_KS_PW", "cal_exx_energy");
245244
return Eexx;
246245
}
247246

248-
template class ModuleESolver::ESolver_KS_PW<std::complex<float>, base_device::DEVICE_CPU>;
249-
template class ModuleESolver::ESolver_KS_PW<std::complex<double>, base_device::DEVICE_CPU>;
247+
template class Exx_Helper<std::complex<float>, base_device::DEVICE_CPU>;
248+
template class Exx_Helper<std::complex<double>, base_device::DEVICE_CPU>;
250249
#if ((defined __CUDA) || (defined __ROCM))
251-
template class ModuleESolver::ESolver_KS_PW<std::complex<float>, base_device::DEVICE_GPU>;
252-
template class ModuleESolver::ESolver_KS_PW<std::complex<double>, base_device::DEVICE_GPU>;
250+
template class Exx_Helper<std::complex<float>, base_device::DEVICE_GPU>;
251+
template class Exx_Helper<std::complex<double>, base_device::DEVICE_GPU>;
253252
#endif
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
//
2+
// For EXX in PW.
3+
//
4+
#include "module_psi/psi.h"
5+
#include "module_base/matrix.h"
6+
#include "module_hamilt_pw/hamilt_pwdft/global.h"
7+
8+
#ifndef EXX_HELPER_H
9+
#define EXX_HELPER_H
10+
template <typename T, typename Device>
11+
struct Exx_Helper
12+
{
13+
public:
14+
Exx_Helper() = default;
15+
ModuleBase::matrix * wf_wg;
16+
psi::Psi<T, Device> psi;
17+
static constexpr double DIV_UNDEFINED = 0x0d000721;
18+
double div = DIV_UNDEFINED;
19+
bool construct_ace = false;
20+
21+
bool exx_after_converge(int &iter)
22+
{
23+
if (first_iter)
24+
{
25+
first_iter = false;
26+
}
27+
else if (!GlobalC::exx_info.info_global.separate_loop)
28+
{
29+
return true;
30+
}
31+
else if (iter == 1)
32+
{
33+
return true;
34+
}
35+
GlobalV::ofs_running << "Updating EXX and rerun SCF" << std::endl;
36+
iter = 0;
37+
return false;
38+
39+
}
40+
41+
void set_psi(psi::Psi<T, Device> &psi_)
42+
{
43+
this->psi = psi_;
44+
construct_ace = true;
45+
}
46+
47+
void reset_div()
48+
{
49+
this->div = DIV_UNDEFINED;
50+
}
51+
52+
double cal_exx_energy(const Device *ctx,
53+
psi::Psi<T, Device>& psi,
54+
ModulePW::PW_Basis_K* pw_wfc,
55+
ModulePW::PW_Basis* pw_rho,
56+
UnitCell* ucell,
57+
K_Vectors *kv);
58+
59+
60+
bool first_iter = false;
61+
};
62+
#endif // EXX_HELPER_H

0 commit comments

Comments
 (0)