1- #include " module_esolver/esolver_ks_pw .h"
1+ #include " exx_helper .h"
22
33template <typename T, typename Device>
4- double ModuleESolver::ESolver_KS_PW <T, Device>::Exx_Helper:: cal_exx_energy(psi::Psi<T, Device>& psi, ESolver_KS_PW<T, Device>* this_ )
4+ double Exx_Helper <T, Device>::cal_exx_energy(const Device *ctx, psi::Psi<T, Device>& psi, ModulePW::PW_Basis_K* pw_wfc, ModulePW::PW_Basis* pw_rho, UnitCell* ucell, K_Vectors *kv )
55{
66 ModuleBase::timer::tick (" ESolver_KS_PW" , " cal_exx_energy" );
77
88 using setmem_complex_op = base_device::memory::set_memory_op<T, Device>;
99 using delmem_complex_op = base_device::memory::delete_memory_op<T, Device>;
10- T* psi_nk_real = new T[this_-> pw_wfc ->nrxx ];
11- T* psi_mq_real = new T[this_-> pw_wfc ->nrxx ];
12- T* h_psi_recip = new T[this_-> pw_wfc ->npwk_max ];
13- T* h_psi_real = new T[this_-> pw_wfc ->nrxx ];
14- T* density_real = new T[this_-> pw_wfc ->nrxx ];
15- auto rhopw = this_-> pw_rho ;
10+ T* psi_nk_real = new T[pw_wfc->nrxx ];
11+ T* psi_mq_real = new T[pw_wfc->nrxx ];
12+ T* h_psi_recip = new T[pw_wfc->npwk_max ];
13+ T* h_psi_real = new T[pw_wfc->nrxx ];
14+ T* density_real = new T[pw_wfc->nrxx ];
15+ auto rhopw = pw_rho;
1616 T* density_recip = new T[rhopw->npw ];
17- auto *kv = &this_->kv ;
1817
1918 // lambda
2019 auto exx_divergence = [&]() -> double
2120 {
22- auto wfcpw = this_-> pw_wfc ;
21+ auto wfcpw = pw_wfc;
2322 // if (GlobalC::exx_info.info_lip.lambda == 0.0)
2423 // {
2524 // return 0;
@@ -28,7 +27,7 @@ double ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper::cal_exx_energy(psi::
2827 // here we follow the exx_divergence subroutine in q-e (PW/src/exx_base.f90)
2928 // double alpha = GlobalC::exx_info.info_lip.lambda;
3029 double alpha = 10.0 / wfcpw->gk_ecut ;
31- double tpiba2 = this_-> pw_rhod ->tpiba2 ;
30+ double tpiba2 = ucell ->tpiba2 ;
3231 double div = 0 ;
3332
3433 // this is the \sum_q F(q) part
@@ -96,7 +95,7 @@ double ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper::cal_exx_energy(psi::
9695 aa *= 8 / ModuleBase::FOUR_PI;
9796 aa += 1.0 / std::sqrt (alpha * ModuleBase::PI);
9897
99- double omega = this_-> pelec ->omega ;
98+ double omega = ucell ->omega ;
10099 div -= ModuleBase::e2 * omega * aa;
101100 return div * wfcpw->nks ;
102101
@@ -111,14 +110,14 @@ double ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper::cal_exx_energy(psi::
111110 if (wf_wg == nullptr ) return 0.0 ;
112111 // evaluate the Eexx
113112 // T Eexx_ik = 0.0;
114- Real Eexx_ik_real = 0.0 ;
115- for (int ik = 0 ; ik < this_-> pw_wfc ->nks ; ik++)
113+ double Eexx_ik_real = 0.0 ;
114+ for (int ik = 0 ; ik < pw_wfc->nks ; ik++)
116115 {
117116 // auto k = this->pw_wfc->kvec_c[ik];
118117 // std::cout << k << std::endl;
119118 for (int n_iband = 0 ; n_iband < psi.get_nbands (); n_iband++)
120119 {
121- setmem_complex_op ()(h_psi_recip, 0 , this_-> pw_wfc ->npwk_max );
120+ setmem_complex_op ()(h_psi_recip, 0 , pw_wfc->npwk_max );
122121 setmem_complex_op ()(h_psi_real, 0 , rhopw->nrxx );
123122 setmem_complex_op ()(density_real, 0 , rhopw->nrxx );
124123 setmem_complex_op ()(density_recip, 0 , rhopw->npw );
@@ -137,16 +136,16 @@ double ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper::cal_exx_energy(psi::
137136 psi.fix_kb (ik, n_iband);
138137 const T* psi_nk = psi.get_pointer ();
139138 // retrieve \psi_nk in real space
140- this_-> pw_wfc ->recip_to_real (this_-> ctx , psi_nk, psi_nk_real, ik);
139+ pw_wfc->recip_to_real (ctx, psi_nk, psi_nk_real, ik);
141140
142141 // for \psi_nk, get the pw of iq and band m
143142 // q_points is a vector of integers, 0 to nks-1
144143 std::vector<int > q_points;
145- for (int iq = 0 ; iq < this_-> pw_wfc ->nks ; iq++)
144+ for (int iq = 0 ; iq < pw_wfc->nks ; iq++)
146145 {
147146 q_points.push_back (iq);
148147 }
149- Real nqs = q_points.size ();
148+ double nqs = q_points.size ();
150149
151150 // std::cout << "ik = " << ik << " ib = " << n_iband << " wg_kb = " << wg_ikb_real << " wk_ik = " << kv->wk[ik] << std::endl;
152151 for (int iq: q_points)
@@ -168,15 +167,15 @@ double ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper::cal_exx_energy(psi::
168167 psi.fix_kb (iq, m_iband);
169168 const T* psi_mq = psi.get_pointer ();
170169 // const T* psi_mq = get_pw(m_iband, iq);
171- this_-> pw_wfc ->recip_to_real (this_-> ctx , psi_mq, psi_mq_real, iq);
170+ pw_wfc->recip_to_real (ctx, psi_mq, psi_mq_real, iq);
172171
173- Real omega_inv = 1.0 / this_-> pelec ->omega ;
172+ T omega_inv = 1.0 / ucell ->omega ;
174173
175174 // direct multiplication in real space, \psi_nk(r) * \psi_mq(r)
176175 #ifdef _OPENMP
177176 #pragma omp parallel for
178177 #endif
179- for (int ir = 0 ; ir < this_-> pw_wfc ->nrxx ; ir++)
178+ for (int ir = 0 ; ir < pw_wfc->nrxx ; ir++)
180179 {
181180 // assert(is_finite(psi_nk_real[ir]));
182181 // assert(is_finite(psi_mq_real[ir]));
@@ -187,17 +186,17 @@ double ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper::cal_exx_energy(psi::
187186 // bring the density to recip space
188187 rhopw->real2recip (density_real, density_recip);
189188
190- Real tpiba2 = this_-> pw_rho ->tpiba2 ;
189+ double tpiba2 = pw_rho->tpiba2 ;
191190 // std::cout << tpiba2 << std::endl;
192- Real hse_omega2 = GlobalC::exx_info.info_global .hse_omega * GlobalC::exx_info.info_global .hse_omega ;
191+ double hse_omega2 = GlobalC::exx_info.info_global .hse_omega * GlobalC::exx_info.info_global .hse_omega ;
193192
194193 #ifdef _OPENMP
195194 #pragma omp parallel for reduction(+:Eexx_ik_real) reduction(min:min_gg) reduction(max:max_gg)
196195 #endif
197196 for (int ig = 0 ; ig < rhopw->npw ; ig++)
198197 {
199- auto k = this_-> pw_wfc ->kvec_c [ik];// * latvec;
200- auto q = this_-> pw_wfc ->kvec_c [iq];// * latvec;
198+ auto k = pw_wfc->kvec_c [ik];// * latvec;
199+ auto q = pw_wfc->kvec_c [iq];// * latvec;
201200 auto gcar = rhopw->gcar [ig];
202201 double gg = (k - q + gcar).norm2 () * tpiba2;
203202
@@ -236,18 +235,18 @@ double ModuleESolver::ESolver_KS_PW<T, Device>::Exx_Helper::cal_exx_energy(psi::
236235 } // n_iband
237236
238237 } // ik
239- Eexx_ik_real *= 0.5 * this_-> pelec ->omega ;
238+ Eexx_ik_real *= 0.5 * ucell ->omega ;
240239 Parallel_Reduce::reduce_pool (Eexx_ik_real);
241240// std::cout << "omega = " << this_->pelec->omega << " tpiba = " << this_->pw_rho->tpiba2 << " exx_div = " << exx_div << std::endl;
242241
243- Real Eexx = Eexx_ik_real;
242+ double Eexx = Eexx_ik_real;
244243 ModuleBase::timer::tick (" ESolver_KS_PW" , " cal_exx_energy" );
245244 return Eexx;
246245}
247246
248- template class ModuleESolver ::ESolver_KS_PW <std::complex <float >, base_device::DEVICE_CPU>;
249- template class ModuleESolver ::ESolver_KS_PW <std::complex <double >, base_device::DEVICE_CPU>;
247+ template class Exx_Helper <std::complex <float >, base_device::DEVICE_CPU>;
248+ template class Exx_Helper <std::complex <double >, base_device::DEVICE_CPU>;
250249#if ((defined __CUDA) || (defined __ROCM))
251- template class ModuleESolver ::ESolver_KS_PW <std::complex <float >, base_device::DEVICE_GPU>;
252- template class ModuleESolver ::ESolver_KS_PW <std::complex <double >, base_device::DEVICE_GPU>;
250+ template class Exx_Helper <std::complex <float >, base_device::DEVICE_GPU>;
251+ template class Exx_Helper <std::complex <double >, base_device::DEVICE_GPU>;
253252#endif
0 commit comments