Skip to content

Commit 0aba423

Browse files
ErjieWuQianruipku
andauthored
Refactor: Remove GlobalV in DeePKS and simplify some functions. (#5952)
* Remove GlobalV in module_deepks. * Simplify some functions in DeePKS. Change LCAO_Deepks class into template. * Combine H_V_delta and H_V_delta_k and rename as V_delta. * Rearrange the output order for Output_HContainer. * Update output_hcontainer.cpp. * Fix size mismatch for equivariant version DeePKS. * update esolver_ks_lcao.cpp. * Fix merge bug. * Update reference to match new output order. --------- Co-authored-by: Qianrui Liu <[email protected]>
1 parent 6b928c9 commit 0aba423

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+1363
-1340
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ OBJS_DEEPKS=LCAO_deepks.o\
200200
deepks_descriptor.o\
201201
deepks_force.o\
202202
deepks_fpre.o\
203+
deepks_iterate.o\
203204
deepks_spre.o\
204205
deepks_orbital.o\
205206
deepks_orbpre.o\

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
211211

212212
#ifdef __DEEPKS
213213
// 10) initialize deepks
214-
LCAO_domain::DeePKS_init(ucell, pv, this->kv.get_nks(), orb_, this->ld);
214+
LCAO_domain::DeePKS_init(ucell, pv, this->kv.get_nks(), orb_, this->ld, GlobalV::ofs_running);
215215
if (PARAM.inp.deepks_scf)
216216
{
217217
// load the DeePKS model from deep neural network
@@ -220,6 +220,7 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
220220
DeePKS_domain::read_pdm((PARAM.inp.init_chg == "file"),
221221
PARAM.inp.deepks_equiv,
222222
ld.init_pdm,
223+
ucell.nat,
223224
orb_.Alpha[0].getTotal_nchi() * ucell.nat,
224225
ld.lmaxd,
225226
ld.inl_l,
@@ -245,8 +246,8 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
245246
"%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%"
246247
"%%%%%%%%%%%%%%%%%%%%%%%%%%"
247248
<< std::endl;
248-
std::cout << " Warning: nks (" << this->kv.get_nks() << ") is not divisible by kpar (" << PARAM.globalv.kpar_lcao
249-
<< ")." << std::endl;
249+
std::cout << " Warning: nks (" << this->kv.get_nks() << ") is not divisible by kpar ("
250+
<< PARAM.globalv.kpar_lcao << ")." << std::endl;
250251
std::cout << " This may lead to poor load balance. It is strongly suggested to" << std::endl;
251252
std::cout << " set nks to be divisible by kpar, but if this is really what" << std::endl;
252253
std::cout << " you want, please ignore this warning." << std::endl;

source/module_esolver/esolver_ks_lcao.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
// for grid integration
66
#include "module_hamilt_lcao/module_gint/gint_gamma.h"
77
#include "module_hamilt_lcao/module_gint/gint_k.h"
8-
#include "module_hamilt_lcao/module_gint/temp_gint/gint_info.h"
98
#include "module_hamilt_lcao/module_gint/temp_gint/gint.h"
9+
#include "module_hamilt_lcao/module_gint/temp_gint/gint_info.h"
1010
#ifdef __DEEPKS
1111
#include "module_hamilt_lcao/module_deepks/LCAO_deepks.h"
1212
#endif
@@ -100,7 +100,7 @@ class ESolver_KS_LCAO : public ESolver_KS<TK>
100100
//---------------------------------------------------------------------
101101

102102
#ifdef __DEEPKS
103-
LCAO_Deepks ld;
103+
LCAO_Deepks<TK> ld;
104104
#endif
105105

106106
#ifdef __EXX

source/module_esolver/lcao_after_scf.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
220220
if (this->psi != nullptr && (istep % PARAM.inp.out_interval == 0))
221221
{
222222
hamilt::HamiltLCAO<TK, TR>* p_ham_deepks = dynamic_cast<hamilt::HamiltLCAO<TK, TR>*>(this->p_hamilt);
223-
std::shared_ptr<LCAO_Deepks> ld_shared_ptr(&ld, [](LCAO_Deepks*) {});
223+
std::shared_ptr<LCAO_Deepks<TK>> ld_shared_ptr(&ld, [](LCAO_Deepks<TK>*) {});
224224
LCAO_Deepks_Interface<TK, TR> deepks_interface(ld_shared_ptr);
225225

226226
deepks_interface.out_deepks_labels(this->pelec->f_en.etot,
@@ -235,7 +235,8 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
235235
&(this->pv),
236236
*(this->psi),
237237
dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(),
238-
p_ham_deepks);
238+
p_ham_deepks,
239+
GlobalV::MY_RANK);
239240
}
240241
#endif
241242

source/module_hamilt_lcao/hamilt_lcaodft/FORCE.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class Force_LCAO
6666
#ifdef __DEEPKS
6767
ModuleBase::matrix& fvnl_dalpha,
6868
ModuleBase::matrix& svnl_dalpha,
69-
LCAO_Deepks& ld,
69+
LCAO_Deepks<T>& ld,
7070
#endif
7171
typename TGint<T>::type& gint,
7272
const TwoCenterBundle& two_center_bundle,

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
5252
ModulePW::PW_Basis* rhopw,
5353
surchem& solvent,
5454
#ifdef __DEEPKS
55-
LCAO_Deepks& ld,
55+
LCAO_Deepks<T>& ld,
5656
#endif
5757
#ifdef __EXX
5858
Exx_LRI<double>& exx_lri_double,
@@ -838,7 +838,7 @@ void Force_Stress_LCAO<double>::integral_part(const bool isGammaOnly,
838838
#if __DEEPKS
839839
ModuleBase::matrix& fvnl_dalpha,
840840
ModuleBase::matrix& svnl_dalpha,
841-
LCAO_Deepks& ld,
841+
LCAO_Deepks<double>& ld,
842842
#endif
843843
Gint_Gamma& gint_gamma, // mohan add 2024-04-01
844844
Gint_k& gint_k, // mohan add 2024-04-01
@@ -895,7 +895,7 @@ void Force_Stress_LCAO<std::complex<double>>::integral_part(const bool isGammaOn
895895
#if __DEEPKS
896896
ModuleBase::matrix& fvnl_dalpha,
897897
ModuleBase::matrix& svnl_dalpha,
898-
LCAO_Deepks& ld,
898+
LCAO_Deepks<std::complex<double>>& ld,
899899
#endif
900900
Gint_Gamma& gint_gamma,
901901
Gint_k& gint_k,

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class Force_Stress_LCAO
5050
ModulePW::PW_Basis* rhopw,
5151
surchem& solvent,
5252
#ifdef __DEEPKS
53-
LCAO_Deepks& ld,
53+
LCAO_Deepks<T>& ld,
5454
#endif
5555
#ifdef __EXX
5656
Exx_LRI<double>& exx_lri_double,
@@ -99,7 +99,7 @@ class Force_Stress_LCAO
9999
#if __DEEPKS
100100
ModuleBase::matrix& fvnl_dalpha,
101101
ModuleBase::matrix& svnl_dalpha,
102-
LCAO_Deepks& ld,
102+
LCAO_Deepks<T>& ld,
103103
#endif
104104
Gint_Gamma& gint_gamma,
105105
Gint_k& gint_k,

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_gamma.cpp

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ void Force_LCAO<double>::ftable(const bool isforce,
190190
#ifdef __DEEPKS
191191
ModuleBase::matrix& fvnl_dalpha,
192192
ModuleBase::matrix& svnl_dalpha,
193-
LCAO_Deepks& ld,
193+
LCAO_Deepks<double>& ld,
194194
#endif
195195
TGint<double>::type& gint,
196196
const TwoCenterBundle& two_center_bundle,
@@ -252,21 +252,7 @@ void Force_LCAO<double>::ftable(const bool isforce,
252252
{
253253
const std::vector<std::vector<double>>& dm_gamma = dm->get_DMK_vector();
254254

255-
// These calculations have been done in LCAO_Deepks_Interface in after_scf
256-
// std::vector<torch::Tensor> descriptor;
257-
// DeePKS_domain::cal_descriptor(ucell.nat, ld.inlmax, ld.inl_l, ld.pdm, descriptor, ld.des_per_atom);
258-
// DeePKS_domain::cal_edelta_gedm(ucell.nat,
259-
// ld.lmaxd,
260-
// ld.nmaxd,
261-
// ld.inlmax,
262-
// ld.des_per_atom,
263-
// ld.inl_l,
264-
// descriptor,
265-
// ld.pdm,
266-
// ld.model_deepks,
267-
// ld.gedm,
268-
// ld.E_delta);
269-
255+
// No need to update E_delta here since it have been done in LCAO_Deepks_Interface in after_scf
270256
const int nks = 1;
271257
DeePKS_domain::cal_f_delta<double>(dm_gamma,
272258
ucell,

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_k.cpp

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
284284
#ifdef __DEEPKS
285285
ModuleBase::matrix& fvnl_dalpha,
286286
ModuleBase::matrix& svnl_dalpha,
287-
LCAO_Deepks& ld,
287+
LCAO_Deepks<std::complex<double>>& ld,
288288
#endif
289289
TGint<std::complex<double>>::type& gint,
290290
const TwoCenterBundle& two_center_bundle,
@@ -347,21 +347,7 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
347347
{
348348
const std::vector<std::vector<std::complex<double>>>& dm_k = dm->get_DMK_vector();
349349

350-
// These calculations have been done in LCAO_Deepks_Interface in after_scf
351-
// std::vector<torch::Tensor> descriptor;
352-
// DeePKS_domain::cal_descriptor(ucell.nat, ld.inlmax, ld.inl_l, ld.pdm, descriptor, ld.des_per_atom);
353-
// DeePKS_domain::cal_edelta_gedm(ucell.nat,
354-
// ld.lmaxd,
355-
// ld.nmaxd,
356-
// ld.inlmax,
357-
// ld.des_per_atom,
358-
// ld.inl_l,
359-
// descriptor,
360-
// ld.pdm,
361-
// ld.model_deepks,
362-
// ld.gedm,
363-
// ld.E_delta);
364-
350+
// No need to update E_delta since it have been done in LCAO_Deepks_Interface in after_scf
365351
DeePKS_domain::cal_f_delta<std::complex<double>>(dm_k,
366352
ucell,
367353
orb,

source/module_hamilt_lcao/hamilt_lcaodft/LCAO_allocate.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@ namespace LCAO_domain
88
{
99
#ifdef __DEEPKS
1010
// It seems it is only related to DeePKS, so maybe we should move it to DeeKS_domain
11+
template <typename T>
1112
void DeePKS_init(const UnitCell& ucell,
1213
Parallel_Orbitals& pv,
1314
const int& nks,
1415
const LCAO_Orbitals& orb,
15-
LCAO_Deepks& ld)
16+
LCAO_Deepks<T>& ld,
17+
std::ofstream& ofs)
1618
{
1719
ModuleBase::TITLE("LCAO_domain", "DeePKS_init");
1820
// preparation for DeePKS
@@ -26,7 +28,7 @@ void DeePKS_init(const UnitCell& ucell,
2628
na[it] = ucell.atoms[it].na;
2729
}
2830

29-
ld.init(orb, ucell.nat, ucell.ntype, nks, pv, na);
31+
ld.init(orb, ucell.nat, ucell.ntype, nks, pv, na, ofs);
3032

3133
if (PARAM.inp.deepks_scf)
3234
{
@@ -35,5 +37,19 @@ void DeePKS_init(const UnitCell& ucell,
3537
}
3638
return;
3739
}
40+
41+
template void DeePKS_init<double>(const UnitCell& ucell,
42+
Parallel_Orbitals& pv,
43+
const int& nks,
44+
const LCAO_Orbitals& orb,
45+
LCAO_Deepks<double>& ld,
46+
std::ofstream& ofs);
47+
48+
template void DeePKS_init<std::complex<double>>(const UnitCell& ucell,
49+
Parallel_Orbitals& pv,
50+
const int& nks,
51+
const LCAO_Orbitals& orb,
52+
LCAO_Deepks<std::complex<double>>& ld,
53+
std::ofstream& ofs);
3854
#endif
3955
} // namespace LCAO_domain

0 commit comments

Comments
 (0)