Skip to content

Commit 5c815a1

Browse files
authored
Refactor: Change some functions in module_deepks into template. (#5731)
* Change cal_o_delta and cal_e_delta_band into templates. * Change cal_projected_DM to template. * Temporarily add the function of cal_f_delta_k into cal_f_delta_gamma. * Update FORCE_gamma.cpp * Update deepks_fgamma.cpp * Update deepks_force.h * Update deepks_force.h * Update deepks_force.h
1 parent c099d03 commit 5c815a1

File tree

13 files changed

+314
-469
lines changed

13 files changed

+314
-469
lines changed

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_gamma.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -237,18 +237,21 @@ void Force_LCAO<double>::ftable(const bool isforce,
237237

238238
GlobalC::ld.cal_gedm(ucell.nat);
239239

240-
DeePKS_domain::cal_f_delta_gamma(dm_gamma,
241-
ucell,
242-
orb,
243-
gd,
240+
const int nks=1;
241+
DeePKS_domain::cal_f_delta_gamma(dm_gamma,
242+
ucell,
243+
orb,
244+
gd,
244245
*this->ParaV,
245246
GlobalC::ld.lmaxd,
247+
nks,
248+
kv->kvec_d,
246249
GlobalC::ld.nlm_save,
247250
GlobalC::ld.gedm,
248251
GlobalC::ld.inl_index,
249252
GlobalC::ld.F_delta,
250-
isstress,
251-
svnl_dalpha);
253+
isstress,
254+
svnl_dalpha);
252255

253256
#ifdef __MPI
254257
Parallel_Reduce::reduce_all(GlobalC::ld.F_delta.c, GlobalC::ld.F_delta.nr * GlobalC::ld.F_delta.nc);

source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/deepks_lcao.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ void DeePKS<OperatorLCAO<double, double>>::contributeHR()
158158
{
159159
ModuleBase::timer::tick("DeePKS", "contributeHR");
160160
const Parallel_Orbitals* pv = this->hsk->get_pv();
161-
GlobalC::ld.cal_projected_DM(this->DM, *this->ucell, *ptr_orb_, *(this->gd));
161+
GlobalC::ld.cal_projected_DM<double>(this->DM, *this->ucell, *ptr_orb_, *(this->gd));
162162
GlobalC::ld.cal_descriptor(this->ucell->nat);
163163
GlobalC::ld.cal_gedm(this->ucell->nat);
164164
// recalculate the H_V_delta
@@ -186,7 +186,7 @@ void DeePKS<OperatorLCAO<std::complex<double>, double>>::contributeHR()
186186
{
187187
ModuleBase::timer::tick("DeePKS", "contributeHR");
188188

189-
GlobalC::ld.cal_projected_DM(this->DM, *this->ucell, *ptr_orb_, *this->gd);
189+
GlobalC::ld.cal_projected_DM<std::complex<double>>(this->DM, *this->ucell, *ptr_orb_, *this->gd);
190190
GlobalC::ld.cal_descriptor(this->ucell->nat);
191191
// calculate dE/dD
192192
GlobalC::ld.cal_gedm(this->ucell->nat);
@@ -219,7 +219,7 @@ void DeePKS<OperatorLCAO<std::complex<double>, std::complex<double>>>::contribut
219219
{
220220
ModuleBase::timer::tick("DeePKS", "contributeHR");
221221

222-
GlobalC::ld.cal_projected_DM(this->DM, *this->ucell, *ptr_orb_, *this->gd);
222+
GlobalC::ld.cal_projected_DM<std::complex<double>>(this->DM, *this->ucell, *ptr_orb_, *this->gd);
223223
GlobalC::ld.cal_descriptor(this->ucell->nat);
224224
// calculate dE/dD
225225
GlobalC::ld.cal_gedm(this->ucell->nat);

source/module_hamilt_lcao/module_deepks/LCAO_deepks.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -512,14 +512,13 @@ void LCAO_Deepks::del_v_delta_pdm_shell(const int nks,const int nlocal)
512512
return;
513513
}
514514

515-
void LCAO_Deepks::dpks_cal_e_delta_band(const std::vector<std::vector<double>>& dm, const int nks)
515+
template <typename TK>
516+
void LCAO_Deepks::dpks_cal_e_delta_band(const std::vector<std::vector<TK>>& dm, const int nks)
516517
{
517518
this->cal_e_delta_band(dm, nks);
518519
}
519520

520-
void LCAO_Deepks::dpks_cal_e_delta_band(const std::vector<std::vector<std::complex<double>>>& dm, const int nks)
521-
{
522-
this->cal_e_delta_band(dm, nks);
523-
}
521+
template void LCAO_Deepks::dpks_cal_e_delta_band<double>(const std::vector<std::vector<double>>& dm, const int nks);
522+
template void LCAO_Deepks::dpks_cal_e_delta_band<std::complex<double>>(const std::vector<std::vector<std::complex<double>>>& dm, const int nks);
524523

525524
#endif

source/module_hamilt_lcao/module_deepks/LCAO_deepks.h

Lines changed: 13 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -284,11 +284,10 @@ class LCAO_Deepks
284284

285285
// There are 6 subroutines in this file:
286286
// 1. cal_projected_DM, which is used for calculating pdm for gamma point calculation
287-
// 2. cal_projected_DM_k, counterpart of 1, for multi-k
288-
// 3. check_projected_dm, which prints pdm to descriptor.dat
287+
// 2. check_projected_dm, which prints pdm to descriptor.dat
289288

290-
// 4. cal_gdmx, calculating gdmx (and optionally gdm_epsl for stress) for gamma point
291-
// 5. check_gdmx, which prints gdmx to a series of .dat files
289+
// 3. cal_gdmx, calculating gdmx (and optionally gdm_epsl for stress) for gamma point
290+
// 4. check_gdmx, which prints gdmx to a series of .dat files
292291

293292
public:
294293
/**
@@ -299,28 +298,14 @@ class LCAO_Deepks
299298
* 2. SCF calculation of DeePKS with init_chg = file and pdm has been read for restarting SCF
300299
* 3. Relax/Cell-Relax/MD calculation, non-first step will use the convergence pdm from the last step as initial pdm
301300
*/
302-
void cal_projected_DM(const elecstate::DensityMatrix<double, double>* dm,
303-
const UnitCell& ucell,
304-
const LCAO_Orbitals& orb,
305-
const Grid_Driver& GridD);
306-
307-
void cal_projected_DM(const elecstate::DensityMatrix<std::complex<double>, double>* dm,
301+
template <typename TK>
302+
void cal_projected_DM(const elecstate::DensityMatrix<TK, double>* dm,
308303
const UnitCell& ucell,
309304
const LCAO_Orbitals& orb,
310305
const Grid_Driver& GridD);
311306

312307
void check_projected_dm();
313308

314-
void cal_projected_DM_equiv(const elecstate::DensityMatrix<double, double>* dm,
315-
const UnitCell& ucell,
316-
const LCAO_Orbitals& orb,
317-
const Grid_Driver& GridD);
318-
319-
void cal_projected_DM_k_equiv(const elecstate::DensityMatrix<std::complex<double>, double>* dm,
320-
const UnitCell& ucell,
321-
const LCAO_Orbitals& orb,
322-
const Grid_Driver& GridD);
323-
324309
// calculate the gradient of pdm with regard to atomic positions
325310
// d/dX D_{Inl,mm'}
326311
template <typename TK>
@@ -358,21 +343,18 @@ class LCAO_Deepks
358343
// tr (rho * V_delta)
359344

360345
// Four subroutines are contained in the file:
361-
// 5. cal_e_delta_band : calculates e_delta_bands for gamma only
362-
// 6. cal_e_delta_band_k : counterpart of 4, for multi-k
346+
// 5. cal_e_delta_band : calculates e_delta_bands
363347

364348
public:
365349
/// calculate tr(\rho V_delta)
366350
// void cal_e_delta_band(const std::vector<ModuleBase::matrix>& dm/**<[in] density matrix*/);
367-
void cal_e_delta_band(const std::vector<std::vector<double>>& dm /**<[in] density matrix*/, const int /*nks*/);
368-
// void cal_e_delta_band_k(const std::vector<ModuleBase::ComplexMatrix>& dm/**<[in] density matrix*/,
369-
// const int nks);
370-
void cal_e_delta_band(const std::vector<std::vector<std::complex<double>>>& dm /**<[in] density matrix*/,
371-
const int nks);
351+
template <typename TK>
352+
void cal_e_delta_band(const std::vector<std::vector<TK>>& dm /**<[in] density matrix*/, const int nks);
372353

373354
//! a temporary interface for cal_e_delta_band and cal_e_delta_band_k
374-
void dpks_cal_e_delta_band(const std::vector<std::vector<double>>& dm, const int nks);
375-
void dpks_cal_e_delta_band(const std::vector<std::vector<std::complex<double>>>& dm, const int nks);
355+
template <typename TK>
356+
void dpks_cal_e_delta_band(const std::vector<std::vector<TK>>& dm, const int nks);
357+
376358

377359
//-------------------
378360
// LCAO_deepks_odelta.cpp
@@ -381,17 +363,11 @@ class LCAO_Deepks
381363
// This file contains subroutines for calculating O_delta,
382364
// which corresponds to the correction of the band gap.
383365

384-
// There are two subroutines in this file:
385-
// 1. cal_o_delta, which is used for gamma point calculation
386-
// 2. cal_o_delta_k, which is used for multi-k calculation
387-
388366
public:
389-
void cal_o_delta(const std::vector<std::vector<ModuleBase::matrix>>&
367+
template <typename TK, typename TH>
368+
void cal_o_delta(const std::vector<std::vector<TH>>&
390369
dm_hl /**<[in] modified density matrix that contains HOMO and LUMO only*/,
391370
const int nks);
392-
void cal_o_delta(const std::vector<std::vector<ModuleBase::ComplexMatrix>>&
393-
dm_hl /**<[in] modified density matrix that contains HOMO and LUMO only*/,
394-
const int nks);
395371

396372
//-------------------
397373
// LCAO_deepks_torch.cpp

source/module_hamilt_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
117117
}
118118

119119
ld->cal_orbital_precalc<TK,TH>(dm_bandgap, nat, nks, kvec_d, ucell, orb, GridD);
120-
ld->cal_o_delta(dm_bandgap, nks);
120+
ld->cal_o_delta<TK,TH>(dm_bandgap, nks);
121121

122122
// save obase and orbital_precalc
123123
LCAO_deepks_io::save_npy_orbital_precalc(nat,
@@ -249,7 +249,7 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
249249
// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
250250
if(!PARAM.inp.deepks_scf)
251251
{
252-
ld->cal_projected_DM(dm, ucell, orb, GridD);
252+
ld->cal_projected_DM<TK>(dm, ucell, orb, GridD);
253253
}
254254

255255
ld->check_projected_dm(); // print out the projected dm for NSCF calculaiton

source/module_hamilt_lcao/module_deepks/LCAO_deepks_odelta.cpp

Lines changed: 39 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -4,63 +4,29 @@
44
//which is defind as sum_mu,nu rho^{hl}_mu,nu <chi_mu|alpha>V(D)<alpha|chi_nu>
55
//where rho^{hl}_mu,nu = C_{L\mu}C_{L\nu} - C_{H\mu}C_{H\nu}, L for LUMO, H for HOMO
66

7-
//There are two subroutines in this file:
8-
//1. cal_o_delta, which is used for gamma point calculation
9-
//2. cal_o_delta_k, which is used for multi-k calculation
10-
117
#ifdef __DEEPKS
128

139
#include "LCAO_deepks.h"
1410
#include "module_base/parallel_reduce.h"
1511

16-
void LCAO_Deepks::cal_o_delta(const std::vector<std::vector<ModuleBase::matrix>>& dm_hl, const int nks)
12+
template <typename TK, typename TH>
13+
void LCAO_Deepks::cal_o_delta(const std::vector<std::vector<TH>>& dm_hl, const int nks)
1714
{
1815
ModuleBase::TITLE("LCAO_Deepks", "cal_o_delta");
19-
this->o_delta.zero_out();
20-
for (int hl = 0; hl < 1; ++hl)
21-
{
22-
for (int i = 0; i < PARAM.globalv.nlocal; ++i)
23-
{
24-
for (int j = 0; j < PARAM.globalv.nlocal; ++j)
25-
{
26-
const int mu = pv->global2local_row(j);
27-
const int nu = pv->global2local_col(i);
28-
29-
if (mu >= 0 && nu >= 0)
30-
{
31-
const int index = nu * pv->nrow + mu;
32-
for (int is = 0; is < PARAM.inp.nspin; ++is)
33-
{
34-
this->o_delta(0,hl) += dm_hl[hl][is](nu, mu) * this->H_V_delta[0][index];
35-
}
36-
}
37-
}
38-
}
39-
Parallel_Reduce::reduce_all(this->o_delta(0, hl));
40-
}
41-
return;
42-
}
4316

44-
45-
//calculating the correction of (LUMO-HOMO) energies, i.e., band gap corrections
46-
//for multi_k calculations
47-
void LCAO_Deepks::cal_o_delta(const std::vector<std::vector<ModuleBase::ComplexMatrix>>& dm_hl,
48-
const int nks)
49-
{
50-
ModuleBase::TITLE("LCAO_Deepks", "cal_o_delta_k");
51-
52-
for(int ik=0; ik<nks; ik++)
17+
this->o_delta.zero_out();
18+
for (int ik = 0; ik < nks; ik++)
5319
{
54-
for (int hl=0; hl<1; hl++)
20+
for (int hl = 0; hl < 1; ++hl)
5521
{
56-
std::complex<double> o_delta_k=std::complex<double>(0.0,0.0);
22+
TK o_delta_tmp = TK(0.0);
5723
for (int i = 0; i < PARAM.globalv.nlocal; ++i)
5824
{
5925
for (int j = 0; j < PARAM.globalv.nlocal; ++j)
6026
{
6127
const int mu = pv->global2local_row(j);
6228
const int nu = pv->global2local_col(i);
63-
29+
6430
if (mu >= 0 && nu >= 0)
6531
{
6632
int iic;
@@ -72,16 +38,40 @@ void LCAO_Deepks::cal_o_delta(const std::vector<std::vector<ModuleBase::ComplexM
7238
{
7339
iic = mu * pv->ncol + nu;
7440
}
75-
o_delta_k += dm_hl[hl][ik](nu, mu) * this->H_V_delta_k[ik][iic];
41+
if constexpr (std::is_same<TK, double>::value)
42+
{
43+
for (int is = 0; is < PARAM.inp.nspin; ++is)
44+
{
45+
o_delta_tmp += dm_hl[hl][is](nu, mu) * this->H_V_delta[0][iic];
46+
}
47+
}
48+
else
49+
{
50+
o_delta_tmp += dm_hl[hl][ik](nu, mu) * this->H_V_delta_k[ik][iic];
51+
}
7652
}
77-
} //end j
78-
} //end i
79-
Parallel_Reduce::reduce_all(o_delta_k);
80-
this->o_delta(ik,hl) = o_delta_k.real();
81-
}// end hl
82-
}// end nks
83-
53+
}
54+
}
55+
Parallel_Reduce::reduce_all(o_delta_tmp);
56+
if constexpr (std::is_same<TK, double>::value)
57+
{
58+
this->o_delta(ik,hl) = o_delta_tmp;
59+
}
60+
else
61+
{
62+
this->o_delta(ik,hl) = o_delta_tmp.real();
63+
}
64+
}
65+
}
8466
return;
8567
}
8668

69+
template void LCAO_Deepks::cal_o_delta<double, ModuleBase::matrix>(
70+
const std::vector<std::vector<ModuleBase::matrix>>& dm_hl,
71+
const int nks);
72+
73+
template void LCAO_Deepks::cal_o_delta<std::complex<double>, ModuleBase::ComplexMatrix>(
74+
const std::vector<std::vector<ModuleBase::ComplexMatrix>>& dm_hl,
75+
const int nks);
76+
8777
#endif

0 commit comments

Comments
 (0)