Skip to content

Commit b809ce6

Browse files
authored
Update pdm before outputting DeePKS labels (#5857)
* rename cal_gedm as cal_edelta_gedm * Update pdm and other properties before outputting deeepks labels * fix bug of include * change test ref because of updating pdm
1 parent df96394 commit b809ce6

File tree

11 files changed

+118
-97
lines changed

11 files changed

+118
-97
lines changed

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_gamma.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -251,20 +251,21 @@ void Force_LCAO<double>::ftable(const bool isforce,
251251
if (PARAM.inp.deepks_scf)
252252
{
253253
const std::vector<std::vector<double>>& dm_gamma = dm->get_DMK_vector();
254-
std::vector<torch::Tensor> descriptor;
255-
// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
256-
DeePKS_domain::cal_descriptor(ucell.nat, ld.inlmax, ld.inl_l, ld.pdm, descriptor, ld.des_per_atom);
257-
DeePKS_domain::cal_gedm(ucell.nat,
258-
ld.lmaxd,
259-
ld.nmaxd,
260-
ld.inlmax,
261-
ld.des_per_atom,
262-
ld.inl_l,
263-
descriptor,
264-
ld.pdm,
265-
ld.model_deepks,
266-
ld.gedm,
267-
ld.E_delta);
254+
255+
// These calculations have been done in LCAO_Deepks_Interface in after_scf
256+
// std::vector<torch::Tensor> descriptor;
257+
// DeePKS_domain::cal_descriptor(ucell.nat, ld.inlmax, ld.inl_l, ld.pdm, descriptor, ld.des_per_atom);
258+
// DeePKS_domain::cal_edelta_gedm(ucell.nat,
259+
// ld.lmaxd,
260+
// ld.nmaxd,
261+
// ld.inlmax,
262+
// ld.des_per_atom,
263+
// ld.inl_l,
264+
// descriptor,
265+
// ld.pdm,
266+
// ld.model_deepks,
267+
// ld.gedm,
268+
// ld.E_delta);
268269

269270
const int nks = 1;
270271
DeePKS_domain::cal_f_delta<double>(dm_gamma,

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_k.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -346,20 +346,21 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
346346
if (PARAM.inp.deepks_scf)
347347
{
348348
const std::vector<std::vector<std::complex<double>>>& dm_k = dm->get_DMK_vector();
349-
std::vector<torch::Tensor> descriptor;
350-
// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
351-
DeePKS_domain::cal_descriptor(ucell.nat, ld.inlmax, ld.inl_l, ld.pdm, descriptor, ld.des_per_atom);
352-
DeePKS_domain::cal_gedm(ucell.nat,
353-
ld.lmaxd,
354-
ld.nmaxd,
355-
ld.inlmax,
356-
ld.des_per_atom,
357-
ld.inl_l,
358-
descriptor,
359-
ld.pdm,
360-
ld.model_deepks,
361-
ld.gedm,
362-
ld.E_delta);
349+
350+
// These calculations have been done in LCAO_Deepks_Interface in after_scf
351+
// std::vector<torch::Tensor> descriptor;
352+
// DeePKS_domain::cal_descriptor(ucell.nat, ld.inlmax, ld.inl_l, ld.pdm, descriptor, ld.des_per_atom);
353+
// DeePKS_domain::cal_edelta_gedm(ucell.nat,
354+
// ld.lmaxd,
355+
// ld.nmaxd,
356+
// ld.inlmax,
357+
// ld.des_per_atom,
358+
// ld.inl_l,
359+
// descriptor,
360+
// ld.pdm,
361+
// ld.model_deepks,
362+
// ld.gedm,
363+
// ld.E_delta);
363364

364365
DeePKS_domain::cal_f_delta<std::complex<double>>(dm_k,
365366
ucell,

source/module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
#include "module_hamilt_lcao/module_hcontainer/hcontainer.h"
1414

1515
#include <vector>
16+
17+
#ifdef __DEEPKS
18+
#include "module_hamilt_lcao/module_deepks/LCAO_deepks.h"
19+
#endif
20+
1621
#ifdef __EXX
1722
#include "module_ri/Exx_LRI.h"
1823
#endif

source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/deepks_lcao.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::contributeHR()
186186
this->ld->pdm,
187187
descriptor,
188188
this->ld->des_per_atom);
189-
DeePKS_domain::cal_gedm(this->ucell->nat,
189+
DeePKS_domain::cal_edelta_gedm(this->ucell->nat,
190190
this->ld->lmaxd,
191191
this->ld->nmaxd,
192192
inlmax,

source/module_hamilt_lcao/module_deepks/LCAO_deepks.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class LCAO_Deepks
8080
bool init_pdm = false; // for DeePKS NSCF calculation, set init_pdm to skip the calculation of pdm in SCF iteration
8181

8282
// deep neural network module that provides corrected Hamiltonian term and
83-
// related derivatives. Used in cal_gedm.
83+
// related derivatives. Used in cal_edelta_gedm.
8484
torch::jit::script::Module model_deepks;
8585

8686
// saves <phi(0)|alpha(R)> and its derivatives

source/module_hamilt_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
3939
// These variables are frequently used in the following code
4040
const int inlmax = orb.Alpha[0].getTotal_nchi() * nat;
4141
const int lmaxd = orb.get_lmax_d();
42+
const int nmaxd = ld->nmaxd;
4243

4344
const int des_per_atom = ld->des_per_atom;
4445
const int* inl_l = ld->inl_l;
@@ -49,11 +50,56 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
4950
bool init_pdm = ld->init_pdm;
5051
double E_delta = ld->E_delta;
5152
double e_delta_band = ld->e_delta_band;
52-
double** gedm = ld->gedm;
5353

5454
const int my_rank = GlobalV::MY_RANK;
5555
const int nspin = PARAM.inp.nspin;
5656

57+
// Note : update PDM and all other quantities with the current dm
58+
// DeePKS PDM and descriptor
59+
if (PARAM.inp.deepks_out_labels || PARAM.inp.deepks_scf)
60+
{
61+
// this part is for integrated test of deepks
62+
// so it is printed no matter even if deepks_out_labels is not used
63+
DeePKS_domain::cal_pdm<TK>
64+
(init_pdm, inlmax, lmaxd, inl_l, inl_index, dm, phialpha, ucell, orb, GridD, *ParaV, pdm);
65+
66+
DeePKS_domain::check_pdm(inlmax, inl_l, pdm); // print out the projected dm for NSCF calculaiton
67+
68+
std::vector<torch::Tensor> descriptor;
69+
DeePKS_domain::cal_descriptor(nat, inlmax, inl_l, pdm, descriptor,
70+
des_per_atom); // final descriptor
71+
DeePKS_domain::check_descriptor(inlmax, des_per_atom, inl_l, ucell, PARAM.globalv.global_out_dir, descriptor);
72+
73+
if (PARAM.inp.deepks_out_labels)
74+
{
75+
LCAO_deepks_io::save_npy_d(nat,
76+
des_per_atom,
77+
inlmax,
78+
inl_l,
79+
PARAM.inp.deepks_equiv,
80+
descriptor,
81+
PARAM.globalv.global_out_dir,
82+
GlobalV::MY_RANK); // libnpy needed
83+
}
84+
85+
if (PARAM.inp.deepks_scf)
86+
{
87+
// update E_delta and gedm
88+
// new gedm is also useful in cal_f_delta, so it should be ld->gedm
89+
DeePKS_domain::cal_edelta_gedm(nat,
90+
lmaxd,
91+
nmaxd,
92+
inlmax,
93+
des_per_atom,
94+
inl_l,
95+
descriptor,
96+
pdm,
97+
ld->model_deepks,
98+
ld->gedm,
99+
E_delta);
100+
}
101+
}
102+
57103
// Used for deepks_bandgap == 1 and deepks_v_delta > 0
58104
std::vector<std::vector<TK>>* h_delta = nullptr;
59105
if constexpr (std::is_same<TK, double>::value)
@@ -65,7 +111,7 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
65111
h_delta = &ld->H_V_delta_k;
66112
}
67113

68-
// calculating deepks correction to bandgap and save the results
114+
// calculating deepks correction and save the results
69115
if (PARAM.inp.deepks_out_labels)
70116
{
71117
// Used for deepks_scf == 1
@@ -317,38 +363,6 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
317363

318364
} // end deepks_out_labels
319365

320-
// DeePKS PDM and descriptor
321-
if (PARAM.inp.deepks_out_labels || PARAM.inp.deepks_scf)
322-
{
323-
// this part is for integrated test of deepks
324-
// so it is printed no matter even if deepks_out_labels is not used
325-
// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
326-
if (!PARAM.inp.deepks_scf)
327-
{
328-
DeePKS_domain::cal_pdm<
329-
TK>(init_pdm, inlmax, lmaxd, inl_l, inl_index, dm, phialpha, ucell, orb, GridD, *ParaV, pdm);
330-
}
331-
332-
DeePKS_domain::check_pdm(inlmax, inl_l, pdm); // print out the projected dm for NSCF calculaiton
333-
334-
std::vector<torch::Tensor> descriptor;
335-
DeePKS_domain::cal_descriptor(nat, inlmax, inl_l, pdm, descriptor,
336-
des_per_atom); // final descriptor
337-
DeePKS_domain::check_descriptor(inlmax, des_per_atom, inl_l, ucell, PARAM.globalv.global_out_dir, descriptor);
338-
339-
if (PARAM.inp.deepks_out_labels)
340-
{
341-
LCAO_deepks_io::save_npy_d(nat,
342-
des_per_atom,
343-
inlmax,
344-
inl_l,
345-
PARAM.inp.deepks_equiv,
346-
descriptor,
347-
PARAM.globalv.global_out_dir,
348-
GlobalV::MY_RANK); // libnpy needed
349-
}
350-
}
351-
352366
/// print out deepks information to the screen
353367
if (PARAM.inp.deepks_scf)
354368
{
@@ -361,7 +375,7 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
361375
{
362376
LCAO_deepks_io::print_dm(nks, PARAM.globalv.nlocal, ParaV->nrow, dm->get_DMK_vector());
363377

364-
DeePKS_domain::check_gedm(inlmax, inl_l, gedm);
378+
DeePKS_domain::check_gedm(inlmax, inl_l, ld->gedm);
365379

366380
std::ofstream ofs("E_delta_bands.dat");
367381
ofs << std::setprecision(10) << e_delta_band;

source/module_hamilt_lcao/module_deepks/deepks_basic.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ inline void generate_py_files(const int lmaxd, const int nmaxd, const std::strin
7575
return;
7676
}
7777

78-
std::ofstream ofs("cal_gedm.py");
78+
std::ofstream ofs("cal_edelta_gedm.py");
7979
ofs << "import torch" << std::endl;
8080
ofs << "import numpy as np" << std::endl << std::endl;
8181
ofs << "import sys" << std::endl;
@@ -121,7 +121,7 @@ inline void generate_py_files(const int lmaxd, const int nmaxd, const std::strin
121121
}
122122
}
123123

124-
void DeePKS_domain::cal_gedm_equiv(const int nat,
124+
void DeePKS_domain::cal_edelta_gedm_equiv(const int nat,
125125
const int lmaxd,
126126
const int nmaxd,
127127
const int inlmax,
@@ -131,7 +131,7 @@ void DeePKS_domain::cal_gedm_equiv(const int nat,
131131
double** gedm,
132132
double& E_delta)
133133
{
134-
ModuleBase::TITLE("DeePKS_domain", "cal_gedm_equiv");
134+
ModuleBase::TITLE("DeePKS_domain", "cal_edelta_gedm_equiv");
135135

136136
LCAO_deepks_io::save_npy_d(nat,
137137
des_per_atom,
@@ -146,7 +146,7 @@ void DeePKS_domain::cal_gedm_equiv(const int nat,
146146

147147
if (GlobalV::MY_RANK == 0)
148148
{
149-
std::string cmd = "python cal_gedm.py " + PARAM.inp.deepks_model;
149+
std::string cmd = "python cal_edelta_gedm.py " + PARAM.inp.deepks_model;
150150
int stat = std::system(cmd.c_str());
151151
assert(stat == 0);
152152
}
@@ -155,13 +155,13 @@ void DeePKS_domain::cal_gedm_equiv(const int nat,
155155

156156
LCAO_deepks_io::load_npy_gedm(nat, des_per_atom, gedm, E_delta, GlobalV::MY_RANK);
157157

158-
std::string cmd = "rm -f cal_gedm.py basis.yaml ec.npy gedm.npy";
158+
std::string cmd = "rm -f cal_edelta_gedm.py basis.yaml ec.npy gedm.npy";
159159
std::system(cmd.c_str());
160160
}
161161

162162
// obtain from the machine learning model dE_delta/dDescriptor
163163
// E_delta is also calculated here
164-
void DeePKS_domain::cal_gedm(const int nat,
164+
void DeePKS_domain::cal_edelta_gedm(const int nat,
165165
const int lmaxd,
166166
const int nmaxd,
167167
const int inlmax,
@@ -175,10 +175,10 @@ void DeePKS_domain::cal_gedm(const int nat,
175175
{
176176
if (PARAM.inp.deepks_equiv)
177177
{
178-
DeePKS_domain::cal_gedm_equiv(nat, lmaxd, nmaxd, inlmax, des_per_atom, inl_l, descriptor, gedm, E_delta);
178+
DeePKS_domain::cal_edelta_gedm_equiv(nat, lmaxd, nmaxd, inlmax, des_per_atom, inl_l, descriptor, gedm, E_delta);
179179
return;
180180
}
181-
ModuleBase::TITLE("DeePKS_domain", "cal_gedm");
181+
ModuleBase::TITLE("DeePKS_domain", "cal_edelta_gedm");
182182

183183
// forward
184184
std::vector<torch::jit::IValue> inputs;

source/module_hamilt_lcao/module_deepks/deepks_basic.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ namespace DeePKS_domain
1818
// The file contains 2 subroutines:
1919
// 1. load_model : loads model for applying V_delta
2020
// 2. cal_gevdm : d(des)/d(pdm), calculated using torch::autograd::grad
21-
// 3. cal_gedm : calculates d(E_delta)/d(pdm)
21+
// 3. cal_edelta_gedm : calculates E_delta and d(E_delta)/d(pdm)
2222
// this is the term V(D) that enters the expression H_V_delta = |alpha>V(D)<alpha|
2323
// caculated using torch::autograd::grad
2424
// 4. check_gedm : prints gedm for checking
25-
// 5. cal_gedm_equiv : calculates d(E_delta)/d(pdm) for equivariant version
25+
// 5. cal_edelta_gedm_equiv : calculates E_delta and d(E_delta)/d(pdm) for equivariant version
2626

2727
// load the trained neural network models
2828
void load_model(const std::string& model_file, torch::jit::script::Module& model);
@@ -35,7 +35,7 @@ void cal_gevdm(const int nat,
3535
std::vector<torch::Tensor>& gevdm);
3636

3737
/// calculate partial of energy correction to descriptors
38-
void cal_gedm(const int nat,
38+
void cal_edelta_gedm(const int nat,
3939
const int lmaxd,
4040
const int nmaxd,
4141
const int inlmax,
@@ -47,7 +47,7 @@ void cal_gedm(const int nat,
4747
double** gedm,
4848
double& E_delta);
4949
void check_gedm(const int inlmax, const int* inl_l, double** gedm);
50-
void cal_gedm_equiv(const int nat,
50+
void cal_edelta_gedm_equiv(const int nat,
5151
const int lmaxd,
5252
const int nmaxd,
5353
const int inlmax,

source/module_hamilt_lcao/module_deepks/test/LCAO_deepks_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ void test_deepks::check_edelta(std::vector<torch::Tensor>& descriptor)
372372
{
373373
ld.allocate_V_delta(ucell.nat, kv.nkstot);
374374
}
375-
DeePKS_domain::cal_gedm(ucell.nat,
375+
DeePKS_domain::cal_edelta_gedm(ucell.nat,
376376
this->ld.lmaxd,
377377
this->ld.nmaxd,
378378
this->ld.inlmax,
Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
etotref -465.9986234576679
2-
etotperatomref -155.3328744859
3-
totalforceref 5.535106
4-
totalstressref 1.522353
5-
totaldes 2.163703
6-
deepks_e_dm -57.8857180593137
7-
deepks_f_label 19.09559844689178
8-
deepks_s_label 19.250590727951906
9-
totaltimeref 12.58
1+
etotref -465.9986234579913
2+
etotperatomref -155.3328744860
3+
totalforceref 5.535112
4+
totalstressref 1.522354
5+
totaldes 2.163682
6+
deepks_e_dm -57.88576364957592
7+
deepks_f_label 19.095631983991726
8+
deepks_s_label 19.250613228828858
9+
totaltimeref 22.06

0 commit comments

Comments
 (0)