Skip to content

Commit d73b1d2

Browse files
Refactor: Combine gamma-only and multi-k versions of some functions in DeePKS. (#5717)
* Add support for INPUT deepks_v_delta>0 in multi-k points DeePKS calculations * Refactor: Change LCAO_Deepks_Interface to template class. * Remove the h_mat and h_mat_k variables in LCAO_Deepks and change H_V_delta to form consistent with H_V_delta_k. * Change functions in deepks_hmat to template. * Combine gamma-only and multi-k for v_delta_precalc. * Change functions about v_delta_precalc and psialpha in deepks_v_delta calculations to templates. * Change save_npy_h to template. * Change some functions in LCAO_deepks_io to templates. * Remove ld.V_deltaR. * Change cal_orbital_precalc to template. * Remove orbital_precalc_k.cpp. * Change cal_gdmx into template function. * [pre-commit.ci lite] apply automatic fixes * Update LCAO_deepks_interface.cpp * Update FORCE_STRESS.cpp * Update FORCE_gamma.cpp * Update deepks_lcao.cpp * Update LCAO_deepks.cpp * Update LCAO_deepks.cpp * Update LCAO_deepks.h --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent 3b3466e commit d73b1d2

25 files changed

+946
-1888
lines changed

source/Makefile.Objects

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,14 +201,11 @@ OBJS_DEEPKS=LCAO_deepks.o\
201201
deepks_hmat.o\
202202
LCAO_deepks_interface.o\
203203
orbital_precalc.o\
204-
orbital_precalc_k.o\
205204
cal_gdmx.o\
206-
cal_gdmx_k.o\
207205
cal_gedm.o\
208206
cal_gvx.o\
209207
cal_descriptor.o\
210208
v_delta_precalc.o\
211-
v_delta_precalc_k.o\
212209

213210

214211
OBJS_ELECSTAT=elecstate.o\

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -961,7 +961,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
961961
// 6) write Hamiltonian and Overlap matrix
962962
for (int ik = 0; ik < this->kv.get_nks(); ++ik)
963963
{
964-
if (PARAM.inp.out_mat_hs[0] || PARAM.inp.deepks_v_delta)
964+
if (PARAM.inp.out_mat_hs[0])
965965
{
966966
this->p_hamilt->updateHk(ik);
967967
}
@@ -1000,12 +1000,6 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
10001000
this->pv,
10011001
GlobalV::DRANK);
10021002
}
1003-
#ifdef __DEEPKS
1004-
if (PARAM.inp.deepks_out_labels && PARAM.inp.deepks_v_delta)
1005-
{
1006-
DeePKS_domain::save_h_mat(h_mat.p, this->pv.nloc, ik);
1007-
}
1008-
#endif
10091003
}
10101004
}
10111005

@@ -1023,24 +1017,30 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
10231017

10241018
//! 8) Write DeePKS information
10251019
#ifdef __DEEPKS
1026-
std::shared_ptr<LCAO_Deepks> ld_shared_ptr(&GlobalC::ld, [](LCAO_Deepks*) {});
1027-
LCAO_Deepks_Interface LDI = LCAO_Deepks_Interface(ld_shared_ptr);
1028-
ModuleBase::timer::tick("ESolver_KS_LCAO", "out_deepks_labels");
1029-
LDI.out_deepks_labels(this->pelec->f_en.etot,
1030-
this->pelec->klist->get_nks(),
1031-
ucell.nat,
1032-
PARAM.globalv.nlocal,
1033-
this->pelec->ekb,
1034-
this->pelec->klist->kvec_d,
1035-
ucell,
1036-
orb_,
1037-
GlobalC::GridD,
1038-
&(this->pv),
1039-
*(this->psi),
1040-
dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(),
1041-
PARAM.inp.deepks_v_delta);
1042-
1043-
ModuleBase::timer::tick("ESolver_KS_LCAO", "out_deepks_labels");
1020+
if (this->psi != nullptr && (istep % PARAM.inp.out_interval == 0))
1021+
{
1022+
hamilt::HamiltLCAO<TK, TR>* p_ham_deepks
1023+
= dynamic_cast<hamilt::HamiltLCAO<TK, TR>*>(this->p_hamilt);
1024+
std::shared_ptr<LCAO_Deepks> ld_shared_ptr(&GlobalC::ld, [](LCAO_Deepks*) {});
1025+
LCAO_Deepks_Interface<TK, TR> LDI(ld_shared_ptr);
1026+
1027+
ModuleBase::timer::tick("ESolver_KS_LCAO", "out_deepks_labels");
1028+
LDI.out_deepks_labels(this->pelec->f_en.etot,
1029+
this->pelec->klist->get_nks(),
1030+
ucell.nat,
1031+
PARAM.globalv.nlocal,
1032+
this->pelec->ekb,
1033+
this->pelec->klist->kvec_d,
1034+
ucell,
1035+
orb_,
1036+
GlobalC::GridD,
1037+
&(this->pv),
1038+
*(this->psi),
1039+
dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(),
1040+
p_ham_deepks);
1041+
1042+
ModuleBase::timer::tick("ESolver_KS_LCAO", "out_deepks_labels");
1043+
}
10441044
#endif
10451045

10461046
//! 9) Perform RDMFT calculations

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ void Force_Stress_LCAO<T>::getForceStress(const bool isforce,
522522
{
523523
const std::vector<std::vector<double>>& dm_gamma
524524
= dynamic_cast<const elecstate::ElecStateLCAO<double>*>(pelec)->get_DM()->get_DMK_vector();
525-
GlobalC::ld.cal_gdmx(dm_gamma[0], ucell, orb, GlobalC::GridD, isstress);
525+
GlobalC::ld.cal_gdmx(dm_gamma, ucell, orb, GlobalC::GridD, kv.get_nks(), kv.kvec_d, isstress);
526526
}
527527
else
528528
{
@@ -531,13 +531,13 @@ void Force_Stress_LCAO<T>::getForceStress(const bool isforce,
531531
->get_DM()
532532
->get_DMK_vector();
533533

534-
GlobalC::ld.cal_gdmx_k(dm_k,
535-
ucell,
536-
orb,
537-
GlobalC::GridD,
538-
kv.get_nks(),
539-
kv.kvec_d,
540-
isstress);
534+
GlobalC::ld.cal_gdmx(dm_k,
535+
ucell,
536+
orb,
537+
GlobalC::GridD,
538+
kv.get_nks(),
539+
kv.kvec_d,
540+
isstress);
541541
}
542542
if (PARAM.inp.deepks_out_unittest)
543543
{

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_gamma.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,15 +260,16 @@ void Force_LCAO<double>::ftable(const bool isforce,
260260

261261
if (PARAM.inp.deepks_out_unittest)
262262
{
263-
LCAO_deepks_io::print_dm(dm_gamma[0], PARAM.globalv.nlocal, this->ParaV->nrow);
263+
const int nks = 1; // 1 for gamma-only
264+
LCAO_deepks_io::print_dm(nks, PARAM.globalv.nlocal, this->ParaV->nrow, dm_gamma);
264265

265266
GlobalC::ld.check_projected_dm();
266267

267268
GlobalC::ld.check_descriptor(ucell, PARAM.globalv.global_out_dir);
268269

269270
GlobalC::ld.check_gedm();
270271

271-
GlobalC::ld.cal_e_delta_band(dm_gamma);
272+
GlobalC::ld.cal_e_delta_band(dm_gamma,nks);
272273

273274
std::ofstream ofs("E_delta_bands.dat");
274275
ofs << std::setprecision(10) << GlobalC::ld.e_delta_band;

source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/deepks_lcao.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ void DeePKS<OperatorLCAO<std::complex<double>, double>>::contributeHR()
186186
{
187187
ModuleBase::timer::tick("DeePKS", "contributeHR");
188188

189-
GlobalC::ld.cal_projected_DM_k(this->DM, *this->ucell, *ptr_orb_, GlobalC::GridD);
189+
GlobalC::ld.cal_projected_DM(this->DM, *this->ucell, *ptr_orb_, GlobalC::GridD);
190190
GlobalC::ld.cal_descriptor(this->ucell->nat);
191191
// calculate dE/dD
192192
GlobalC::ld.cal_gedm(this->ucell->nat);
@@ -219,7 +219,7 @@ void DeePKS<OperatorLCAO<std::complex<double>, std::complex<double>>>::contribut
219219
{
220220
ModuleBase::timer::tick("DeePKS", "contributeHR");
221221

222-
GlobalC::ld.cal_projected_DM_k(this->DM, *this->ucell, *ptr_orb_, GlobalC::GridD);
222+
GlobalC::ld.cal_projected_DM(this->DM, *this->ucell, *ptr_orb_, GlobalC::GridD);
223223
GlobalC::ld.cal_descriptor(this->ucell->nat);
224224
// calculate dE/dD
225225
GlobalC::ld.cal_gedm(this->ucell->nat);
@@ -497,7 +497,7 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::cal_HR_IJR(const double* hr_i
497497

498498
inline void get_h_delta_k(int ik, double*& h_delta_k)
499499
{
500-
h_delta_k = GlobalC::ld.H_V_delta.data();
500+
h_delta_k = GlobalC::ld.H_V_delta[ik].data();
501501
return;
502502
}
503503
inline void get_h_delta_k(int ik, std::complex<double>*& h_delta_k)

source/module_hamilt_lcao/module_deepks/CMakeLists.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,11 @@ if(ENABLE_DEEPKS)
1313
deepks_hmat.cpp
1414
LCAO_deepks_interface.cpp
1515
orbital_precalc.cpp
16-
orbital_precalc_k.cpp
1716
cal_gdmx.cpp
18-
cal_gdmx_k.cpp
1917
cal_gedm.cpp
2018
cal_gvx.cpp
2119
cal_descriptor.cpp
2220
v_delta_precalc.cpp
23-
v_delta_precalc_k.cpp
2421
)
2522

2623
add_library(

source/module_hamilt_lcao/module_deepks/LCAO_deepks.cpp

Lines changed: 9 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
//4. subroutines that are related to V_delta:
1818
// - allocate_V_delta : allocates H_V_delta; if calculating force, it also calls
1919
// init_gdmx, as well as allocating F_delta
20-
// - allocate_V_deltaR : allcoates H_V_deltaR, for multi-k calculations
2120

2221
#ifdef __DEEPKS
2322

@@ -35,7 +34,6 @@ LCAO_Deepks::LCAO_Deepks()
3534
alpha_index = new ModuleBase::IntArray[1];
3635
inl_index = new ModuleBase::IntArray[1];
3736
inl_l = nullptr;
38-
H_V_deltaR = nullptr;
3937
gedm = nullptr;
4038
}
4139

@@ -45,7 +43,6 @@ LCAO_Deepks::~LCAO_Deepks()
4543
delete[] alpha_index;
4644
delete[] inl_index;
4745
delete[] inl_l;
48-
delete[] H_V_deltaR;
4946

5047
//=======1. to use deepks, pdm is required==========
5148
//delete pdm**
@@ -92,7 +89,10 @@ void LCAO_Deepks::init(
9289

9390
int tot_inl = tot_inl_per_atom * nat;
9491

95-
if(PARAM.inp.deepks_equiv) tot_inl = nat;
92+
if(PARAM.inp.deepks_equiv)
93+
{
94+
tot_inl = nat;
95+
}
9696

9797
this->lmaxd = lm;
9898
this->nmaxd = nm;
@@ -143,25 +143,6 @@ void LCAO_Deepks::init(
143143

144144
this->pv = &pv_in;
145145

146-
if(PARAM.inp.deepks_v_delta)
147-
{
148-
//allocate and init h_mat
149-
if(PARAM.globalv.gamma_only_local)
150-
{
151-
int nloc=this->pv->nloc;
152-
this->h_mat.resize(nloc,0.0);
153-
}
154-
else
155-
{
156-
int nloc=this->pv->nloc;
157-
this->h_mat_k.resize(nks);
158-
for (int ik = 0; ik < nks; ik++)
159-
{
160-
this->h_mat_k[ik].resize(nloc,std::complex<double>(0.0,0.0));
161-
}
162-
}
163-
}
164-
165146
return;
166147
}
167148

@@ -335,8 +316,9 @@ void LCAO_Deepks::allocate_V_delta(const int nat, const int nks)
335316
//initialize the H matrix H_V_delta
336317
if(PARAM.globalv.gamma_only_local)
337318
{
338-
this->H_V_delta.resize(pv->nloc);
339-
ModuleBase::GlobalFunc::ZEROS(this->H_V_delta.data(), pv->nloc);
319+
H_V_delta.resize(1); // the first dimension is for the consistence with H_V_delta_k
320+
this->H_V_delta[0].resize(pv->nloc);
321+
ModuleBase::GlobalFunc::ZEROS(this->H_V_delta[0].data(), pv->nloc);
340322
}
341323
else
342324
{
@@ -387,15 +369,6 @@ void LCAO_Deepks::allocate_V_delta(const int nat, const int nks)
387369
return;
388370
}
389371

390-
void LCAO_Deepks::allocate_V_deltaR(const int nnr)
391-
{
392-
ModuleBase::TITLE("LCAO_Deepks", "allocate_V_deltaR");
393-
GlobalV::ofs_running << nnr << std::endl;
394-
delete[] H_V_deltaR;
395-
H_V_deltaR = new double[nnr];
396-
ModuleBase::GlobalFunc::ZEROS(H_V_deltaR, nnr);
397-
}
398-
399372
void LCAO_Deepks::init_orbital_pdm_shell(const int nks)
400373
{
401374

@@ -541,12 +514,12 @@ void LCAO_Deepks::del_v_delta_pdm_shell(const int nks,const int nlocal)
541514

542515
void LCAO_Deepks::dpks_cal_e_delta_band(const std::vector<std::vector<double>>& dm, const int nks)
543516
{
544-
this->cal_e_delta_band(dm);
517+
this->cal_e_delta_band(dm, nks);
545518
}
546519

547520
void LCAO_Deepks::dpks_cal_e_delta_band(const std::vector<std::vector<std::complex<double>>>& dm, const int nks)
548521
{
549-
this->cal_e_delta_band_k(dm, nks);
522+
this->cal_e_delta_band(dm, nks);
550523
}
551524

552525
#endif

0 commit comments

Comments
 (0)