Skip to content

Commit e3193c8

Browse files
authored
Refactor: Remove C++17 syntax and optimize DMR calculation in DeePKS. (#6094)
* Use dm_r in LCAO_Deepks to avoid double counting in DeePKS. * Remove 'if constexpr' usage in DeePKS. * Fix a merge bug.
1 parent 6207798 commit e3193c8

24 files changed

+427
-500
lines changed

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 50 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@
3030
#include "module_io/write_wfc_nao.h"
3131
#include "module_parameter/parameter.h"
3232

33-
//be careful of hpp, there may be multiple definitions of functions, 20250302, mohan
33+
// be careful of hpp, there may be multiple definitions of functions, 20250302, mohan
34+
#include "module_hamilt_lcao/hamilt_lcaodft/hs_matrix_k.hpp"
3435
#include "module_io/write_eband_terms.hpp"
3536
#include "module_io/write_vxc.hpp"
3637
#include "module_io/write_vxc_r.hpp"
37-
#include "module_hamilt_lcao/hamilt_lcaodft/hs_matrix_k.hpp"
3838

3939
//--------------temporary----------------------------
4040
#include "module_base/global_function.h"
@@ -179,15 +179,13 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
179179

180180
// 7) initialize exact exchange calculations
181181
#ifdef __EXX
182-
if (PARAM.inp.calculation == "scf"
183-
|| PARAM.inp.calculation == "relax"
184-
|| PARAM.inp.calculation == "cell-relax"
182+
if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax" || PARAM.inp.calculation == "cell-relax"
185183
|| PARAM.inp.calculation == "md")
186184
{
187185
if (GlobalC::exx_info.info_global.cal_exx)
188186
{
189187
if (PARAM.inp.init_wfc != "file")
190-
{ // if init_wfc==file, directly enter the EXX loop
188+
{ // if init_wfc==file, directly enter the EXX loop
191189
XC_Functional::set_xc_first_loop(ucell);
192190
}
193191

@@ -307,7 +305,6 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
307305
return;
308306
}
309307

310-
311308
template <typename TK, typename TR>
312309
double ESolver_KS_LCAO<TK, TR>::cal_energy()
313310
{
@@ -316,7 +313,6 @@ double ESolver_KS_LCAO<TK, TR>::cal_energy()
316313
return this->pelec->f_en.etot;
317314
}
318315

319-
320316
template <typename TK, typename TR>
321317
void ESolver_KS_LCAO<TK, TR>::cal_force(UnitCell& ucell, ModuleBase::matrix& force)
322318
{
@@ -460,7 +456,7 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
460456
this->pelec->ekb,
461457
this->kv);
462458
}
463-
}
459+
}
464460

465461
// 4) write projected band structure by jiyy-2022-4-20
466462
if (PARAM.inp.out_proj_band)
@@ -489,49 +485,50 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
489485
if (PARAM.inp.out_mat_xc)
490486
{
491487
ModuleIO::write_Vxc<TK, TR>(PARAM.inp.nspin,
492-
PARAM.globalv.nlocal,
493-
GlobalV::DRANK,
494-
&this->pv,
495-
*this->psi,
496-
ucell,
497-
this->sf,
498-
this->solvent,
499-
*this->pw_rho,
500-
*this->pw_rhod,
501-
this->locpp.vloc,
502-
this->chr,
503-
this->GG,
504-
this->GK,
505-
this->kv,
506-
orb_.cutoffs(),
507-
this->pelec->wg,
508-
this->gd
488+
PARAM.globalv.nlocal,
489+
GlobalV::DRANK,
490+
&this->pv,
491+
*this->psi,
492+
ucell,
493+
this->sf,
494+
this->solvent,
495+
*this->pw_rho,
496+
*this->pw_rhod,
497+
this->locpp.vloc,
498+
this->chr,
499+
this->GG,
500+
this->GK,
501+
this->kv,
502+
orb_.cutoffs(),
503+
this->pelec->wg,
504+
this->gd
509505
#ifdef __EXX
510-
,
511-
this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr,
512-
this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr
506+
,
507+
this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr,
508+
this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr
513509
#endif
514510
);
515511
}
516512
if (PARAM.inp.out_mat_xc2)
517513
{
518514
ModuleIO::write_Vxc_R<TK, TR>(PARAM.inp.nspin,
519-
&this->pv,
520-
ucell,
521-
this->sf,
522-
this->solvent,
523-
*this->pw_rho,
524-
*this->pw_rhod,
525-
this->locpp.vloc,
526-
this->chr,
527-
this->GG,
528-
this->GK,
529-
this->kv,
530-
orb_.cutoffs(),
531-
this->gd
515+
&this->pv,
516+
ucell,
517+
this->sf,
518+
this->solvent,
519+
*this->pw_rho,
520+
*this->pw_rhod,
521+
this->locpp.vloc,
522+
this->chr,
523+
this->GG,
524+
this->GK,
525+
this->kv,
526+
orb_.cutoffs(),
527+
this->gd
532528
#ifdef __EXX
533-
, this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr,
534-
this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr
529+
,
530+
this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr,
531+
this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr
535532
#endif
536533
);
537534
}
@@ -569,7 +566,6 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
569566
ModuleBase::timer::tick("ESolver_KS_LCAO", "after_all_runners");
570567
}
571568

572-
573569
template <typename TK, typename TR>
574570
void ESolver_KS_LCAO<TK, TR>::iter_init(UnitCell& ucell, const int istep, const int iter)
575571
{
@@ -639,11 +635,10 @@ void ESolver_KS_LCAO<TK, TR>::iter_init(UnitCell& ucell, const int istep, const
639635
if (GlobalC::exx_info.info_global.cal_exx)
640636
{
641637
// the following steps are only needed in the first outer exx loop
642-
exx_two_level_step = GlobalC::exx_info.info_ri.real_number ?
643-
this->exd->two_level_step
644-
: this->exc->two_level_step;
638+
exx_two_level_step
639+
= GlobalC::exx_info.info_ri.real_number ? this->exd->two_level_step : this->exc->two_level_step;
645640
}
646-
#endif
641+
#endif
647642
if (iter == 1 && exx_two_level_step == 0)
648643
{
649644
std::cout << " WAVEFUN -> CHARGE " << std::endl;
@@ -742,6 +737,11 @@ void ESolver_KS_LCAO<TK, TR>::iter_init(UnitCell& ucell, const int istep, const
742737
{
743738
this->p_hamilt->refresh();
744739
}
740+
if (iter == 1 && istep == 0)
741+
{
742+
// initialize DMR
743+
this->ld.init_DMR(ucell, orb_, this->pv, this->gd);
744+
}
745745
#endif
746746

747747
if (PARAM.inp.vl_in_h)
@@ -758,7 +758,6 @@ void ESolver_KS_LCAO<TK, TR>::iter_init(UnitCell& ucell, const int istep, const
758758
}
759759
}
760760

761-
762761
template <typename TK, typename TR>
763762
void ESolver_KS_LCAO<TK, TR>::hamilt2rho_single(UnitCell& ucell, int istep, int iter, double ethr)
764763
{
@@ -822,7 +821,6 @@ void ESolver_KS_LCAO<TK, TR>::hamilt2rho_single(UnitCell& ucell, int istep, int
822821
this->pelec->f_en.deband = this->pelec->cal_delta_eband(ucell);
823822
}
824823

825-
826824
template <typename TK, typename TR>
827825
void ESolver_KS_LCAO<TK, TR>::update_pot(UnitCell& ucell, const int istep, const int iter, const bool conv_esolver)
828826
{
@@ -840,7 +838,6 @@ void ESolver_KS_LCAO<TK, TR>::update_pot(UnitCell& ucell, const int istep, const
840838
}
841839
}
842840

843-
844841
template <typename TK, typename TR>
845842
void ESolver_KS_LCAO<TK, TR>::iter_finish(UnitCell& ucell, const int istep, int& iter, bool& conv_esolver)
846843
{
@@ -878,6 +875,7 @@ void ESolver_KS_LCAO<TK, TR>::iter_finish(UnitCell& ucell, const int istep, int&
878875
= dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM()->get_DMK_vector();
879876

880877
ld.dpks_cal_e_delta_band(dm, this->kv.get_nks());
878+
DeePKS_domain::update_dmr(this->kv.kvec_d, dm, ucell, orb_, this->pv, this->gd, ld.dm_r);
881879
this->pelec->f_en.edeepks_scf = ld.E_delta - ld.e_delta_band;
882880
this->pelec->f_en.edeepks_delta = ld.E_delta;
883881
}

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_gamma.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ void Force_LCAO<double>::allocate(const UnitCell& ucell,
3737
// save the results in dense matrix by now.
3838
// pv.nloc: number of H elements in this proc.
3939

40-
assert(pv.nloc>0);
40+
assert(pv.nloc > 0);
4141
fsr.DSloc_x = new double[pv.nloc];
4242
fsr.DSloc_y = new double[pv.nloc];
4343
fsr.DSloc_z = new double[pv.nloc];
@@ -230,11 +230,9 @@ void Force_LCAO<double>::ftable(const bool isforce,
230230
#ifdef __DEEPKS
231231
if (PARAM.inp.deepks_scf)
232232
{
233-
const std::vector<std::vector<double>>& dm_gamma = dm->get_DMK_vector();
234-
235233
// No need to update E_delta here since it have been done in LCAO_Deepks_Interface in after_scf
236234
const int nks = 1;
237-
DeePKS_domain::cal_f_delta<double>(dm_gamma,
235+
DeePKS_domain::cal_f_delta<double>(ld.dm_r,
238236
ucell,
239237
orb,
240238
gd,

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_k.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,10 +268,8 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
268268
#ifdef __DEEPKS
269269
if (PARAM.inp.deepks_scf)
270270
{
271-
const std::vector<std::vector<std::complex<double>>>& dm_k = dm->get_DMK_vector();
272-
273271
// No need to update E_delta since it have been done in LCAO_Deepks_Interface in after_scf
274-
DeePKS_domain::cal_f_delta<std::complex<double>>(dm_k,
272+
DeePKS_domain::cal_f_delta<std::complex<double>>(ld.dm_r,
275273
ucell,
276274
orb,
277275
gd,

source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/deepks_lcao.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::initialize_HR(const Grid_Driv
8282
for (int iat0 = 0; iat0 < ucell->nat; iat0++)
8383
{
8484
auto tau0 = ucell->get_tau(iat0);
85-
int T0=0;
86-
int I0=0;
85+
int T0 = 0;
86+
int I0 = 0;
8787
ucell->iat2iait(iat0, &I0, &T0);
8888
AdjacentAtomInfo adjs;
8989
GridD->Find_atom(*ucell, tau0, T0, I0, &adjs);
@@ -174,7 +174,7 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::contributeHR()
174174
this->ld->inl2l,
175175
this->ld->inl_index,
176176
this->kvec_d,
177-
this->DM,
177+
this->ld->dm_r,
178178
this->ld->phialpha,
179179
*this->ucell,
180180
*ptr_orb_,
@@ -242,8 +242,8 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::pre_calculate_nlm(
242242
const Parallel_Orbitals* paraV = this->hR->get_paraV();
243243
const int npol = this->ucell->get_npol();
244244
auto tau0 = ucell->get_tau(iat0);
245-
int T0=0;
246-
int I0=0;
245+
int T0 = 0;
246+
int I0 = 0;
247247
ucell->iat2iait(iat0, &I0, &T0);
248248
AdjacentAtomInfo& adjs = this->adjs_all[iat0];
249249
nlm_in.resize(adjs.adj_num + 1);
@@ -307,8 +307,8 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::calculate_HR()
307307
for (int iat0 = 0; iat0 < this->ucell->nat; iat0++)
308308
{
309309
auto tau0 = ucell->get_tau(iat0);
310-
int T0=0;
311-
int I0=0;
310+
int T0 = 0;
311+
int I0 = 0;
312312
ucell->iat2iait(iat0, &I0, &T0);
313313
AdjacentAtomInfo& adjs = this->adjs_all[iat0];
314314

@@ -370,8 +370,8 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::calculate_HR()
370370
this->pre_calculate_nlm(iat0, nlm_on_the_fly);
371371
}
372372

373-
std::vector<std::unordered_map<int, std::vector<double>>>& nlm_iat =
374-
is_on_the_fly ? nlm_on_the_fly : nlm_tot[iat0];
373+
std::vector<std::unordered_map<int, std::vector<double>>>& nlm_iat
374+
= is_on_the_fly ? nlm_on_the_fly : nlm_tot[iat0];
375375

376376
// 2. calculate <phi_I|beta>D<beta|phi_{J,R}> for each pair of <IJR> atoms
377377
for (int ad1 = 0; ad1 < adjs.adj_num + 1; ++ad1)
@@ -500,7 +500,6 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::cal_HR_IJR(const double* hr_i
500500
}
501501
}
502502

503-
504503
// contributeHk()
505504
template <typename TK, typename TR>
506505
void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::contributeHk(int ik)

source/module_hamilt_lcao/module_deepks/LCAO_deepks.cpp

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifdef __DEEPKS
1414

1515
#include "LCAO_deepks.h"
16+
#include "deepks_iterate.h"
1617
#include "module_hamilt_pw/hamilt_pwdft/global.h"
1718

1819
// Constructor of the class
@@ -206,10 +207,58 @@ void LCAO_Deepks<T>::allocate_V_delta(const int nat, const int nks)
206207
return;
207208
}
208209

210+
template <typename T>
211+
void LCAO_Deepks<T>::init_DMR(const UnitCell& ucell,
212+
const LCAO_Orbitals& orb,
213+
const Parallel_Orbitals& pv,
214+
const Grid_Driver& GridD)
215+
{
216+
this->dm_r = new hamilt::HContainer<double>(&pv);
217+
DeePKS_domain::iterate_ad2(
218+
ucell,
219+
GridD,
220+
orb,
221+
false, // no trace_alpha
222+
[&](const int iat,
223+
const ModuleBase::Vector3<double>& tau0,
224+
const int ibt1,
225+
const ModuleBase::Vector3<double>& tau1,
226+
const int start1,
227+
const int nw1_tot,
228+
ModuleBase::Vector3<int> dR1,
229+
const int ibt2,
230+
const ModuleBase::Vector3<double>& tau2,
231+
const int start2,
232+
const int nw2_tot,
233+
ModuleBase::Vector3<int> dR2)
234+
{
235+
auto row_indexes = pv.get_indexes_row(ibt1);
236+
auto col_indexes = pv.get_indexes_col(ibt2);
237+
if (row_indexes.size() * col_indexes.size() == 0)
238+
{
239+
return; // to next loop
240+
}
241+
242+
int dRx = 0;
243+
int dRy = 0;
244+
int dRz = 0;
245+
if (std::is_same<T, std::complex<double>>::value)
246+
{
247+
dRx = (dR1 - dR2).x;
248+
dRy = (dR1 - dR2).y;
249+
dRz = (dR1 - dR2).z;
250+
}
251+
hamilt::AtomPair<double> dm_pair(ibt1, ibt2, dRx, dRy, dRz, &pv);
252+
this->dm_r->insert_pair(dm_pair);
253+
}
254+
);
255+
this->dm_r->allocate(nullptr, true);
256+
}
257+
209258
template <typename T>
210259
void LCAO_Deepks<T>::dpks_cal_e_delta_band(const std::vector<std::vector<T>>& dm, const int nks)
211260
{
212-
DeePKS_domain::cal_e_delta_band(dm, this->V_delta, nks, this->pv, this->e_delta_band);
261+
DeePKS_domain::cal_e_delta_band(dm, this->V_delta, nks, PARAM.inp.nspin, this->pv, this->e_delta_band);
213262
}
214263

215264
template class LCAO_Deepks<double>;

source/module_hamilt_lcao/module_deepks/LCAO_deepks.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ class LCAO_Deepks
8686
// index 0 for itself and index 1-3 for derivatives over x,y,z
8787
std::vector<hamilt::HContainer<double>*> phialpha;
8888

89+
// density matrix in real space
90+
hamilt::HContainer<double>* dm_r = nullptr;
91+
8992
// projected density matrix
9093
// [tot_Inl][2l+1][2l+1], here l is corresponding to inl;
9194
// [nat][nlm*nlm] for equivariant version
@@ -135,6 +138,12 @@ class LCAO_Deepks
135138
/// Allocate memory for correction to Hamiltonian
136139
void allocate_V_delta(const int nat, const int nks = 1);
137140

141+
/// Initialize the dm_r container
142+
void init_DMR(const UnitCell& ucell,
143+
const LCAO_Orbitals& orb,
144+
const Parallel_Orbitals& pv,
145+
const Grid_Driver& GridD);
146+
138147
//! a temporary interface for cal_e_delta_band
139148
void dpks_cal_e_delta_band(const std::vector<std::vector<T>>& dm, const int nks);
140149

0 commit comments

Comments
 (0)