Skip to content

Commit 8407ee9

Browse files
authored
Move force/stress precalc functions into new files and remove the global dependence. (#5824)
1 parent 39aab7a commit 8407ee9

39 files changed

+767
-501
lines changed

source/Makefile.Objects

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ OBJS_CELL=atom_pseudo.o\
192192

193193
OBJS_DEEPKS=LCAO_deepks.o\
194194
deepks_force.o\
195+
deepks_fpre.o\
196+
deepks_spre.o\
195197
deepks_descriptor.o\
196198
deepks_orbital.o\
197199
deepks_orbpre.o\
@@ -203,10 +205,7 @@ OBJS_DEEPKS=LCAO_deepks.o\
203205
LCAO_deepks_torch.o\
204206
LCAO_deepks_vdelta.o\
205207
LCAO_deepks_interface.o\
206-
cal_gdmx.o\
207-
cal_gdmepsl.o\
208208
cal_gedm.o\
209-
cal_gvx.o\
210209

211210

212211
OBJS_ELECSTAT=elecstate.o\

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp

Lines changed: 70 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -519,8 +519,18 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
519519
const std::vector<std::vector<double>>& dm_gamma
520520
= dynamic_cast<const elecstate::ElecStateLCAO<double>*>(pelec)->get_DM()->get_DMK_vector();
521521

522-
GlobalC::ld
523-
.cal_gdmx(dm_gamma, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, gdmx);
522+
DeePKS_domain::cal_gdmx(GlobalC::ld.lmaxd,
523+
GlobalC::ld.inlmax,
524+
kv.get_nks(),
525+
kv.kvec_d,
526+
GlobalC::ld.phialpha,
527+
GlobalC::ld.inl_index,
528+
dm_gamma,
529+
ucell,
530+
orb,
531+
pv,
532+
gd,
533+
gdmx);
524534
}
525535
else
526536
{
@@ -529,20 +539,34 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
529539
->get_DM()
530540
->get_DMK_vector();
531541

532-
GlobalC::ld.cal_gdmx(dm_k, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, gdmx);
533-
}
534-
if (PARAM.inp.deepks_out_unittest)
535-
{
536-
GlobalC::ld.check_gdmx(ucell.nat, gdmx);
542+
DeePKS_domain::cal_gdmx(GlobalC::ld.lmaxd,
543+
GlobalC::ld.inlmax,
544+
kv.get_nks(),
545+
kv.kvec_d,
546+
GlobalC::ld.phialpha,
547+
GlobalC::ld.inl_index,
548+
dm_k,
549+
ucell,
550+
orb,
551+
pv,
552+
gd,
553+
gdmx);
537554
}
538555
std::vector<torch::Tensor> gevdm;
539556
GlobalC::ld.cal_gevdm(ucell.nat, gevdm);
540557
torch::Tensor gvx;
541-
GlobalC::ld.cal_gvx(ucell.nat, gevdm, gdmx, gvx);
558+
DeePKS_domain::cal_gvx(ucell.nat,
559+
GlobalC::ld.inlmax,
560+
GlobalC::ld.des_per_atom,
561+
GlobalC::ld.inl_l,
562+
gevdm,
563+
gdmx,
564+
gvx);
542565

543566
if (PARAM.inp.deepks_out_unittest)
544567
{
545-
GlobalC::ld.check_gvx(ucell.nat, gvx);
568+
DeePKS_domain::check_gdmx(gdmx);
569+
DeePKS_domain::check_gvx(gvx);
546570
}
547571

548572
LCAO_deepks_io::save_npy_gvx(ucell.nat,
@@ -751,14 +775,18 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
751775
const std::vector<std::vector<double>>& dm_gamma
752776
= dynamic_cast<const elecstate::ElecStateLCAO<double>*>(pelec)->get_DM()->get_DMK_vector();
753777

754-
GlobalC::ld.cal_gdmepsl(dm_gamma,
755-
ucell,
756-
orb,
757-
gd,
758-
kv.get_nks(),
759-
kv.kvec_d,
760-
GlobalC::ld.phialpha,
761-
gdmepsl);
778+
DeePKS_domain::cal_gdmepsl(GlobalC::ld.lmaxd,
779+
GlobalC::ld.inlmax,
780+
kv.get_nks(),
781+
kv.kvec_d,
782+
GlobalC::ld.phialpha,
783+
GlobalC::ld.inl_index,
784+
dm_gamma,
785+
ucell,
786+
orb,
787+
pv,
788+
gd,
789+
gdmepsl);
762790
}
763791
else
764792
{
@@ -767,18 +795,36 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
767795
->get_DM()
768796
->get_DMK_vector();
769797

770-
GlobalC::ld
771-
.cal_gdmepsl(dm_k, ucell, orb, gd, kv.get_nks(), kv.kvec_d, GlobalC::ld.phialpha, gdmepsl);
772-
}
773-
if (PARAM.inp.deepks_out_unittest)
774-
{
775-
GlobalC::ld.check_gdmepsl(gdmepsl);
798+
DeePKS_domain::cal_gdmepsl(GlobalC::ld.lmaxd,
799+
GlobalC::ld.inlmax,
800+
kv.get_nks(),
801+
kv.kvec_d,
802+
GlobalC::ld.phialpha,
803+
GlobalC::ld.inl_index,
804+
dm_k,
805+
ucell,
806+
orb,
807+
pv,
808+
gd,
809+
gdmepsl);
776810
}
777811

778812
std::vector<torch::Tensor> gevdm;
779813
GlobalC::ld.cal_gevdm(ucell.nat, gevdm);
780814
torch::Tensor gvepsl;
781-
GlobalC::ld.cal_gvepsl(ucell.nat, gevdm, gdmepsl, gvepsl);
815+
DeePKS_domain::cal_gvepsl(ucell.nat,
816+
GlobalC::ld.inlmax,
817+
GlobalC::ld.des_per_atom,
818+
GlobalC::ld.inl_l,
819+
gevdm,
820+
gdmepsl,
821+
gvepsl);
822+
823+
if (PARAM.inp.deepks_out_unittest)
824+
{
825+
DeePKS_domain::check_gdmepsl(gdmepsl);
826+
DeePKS_domain::check_gvepsl(gvepsl);
827+
}
782828

783829
LCAO_deepks_io::save_npy_gvepsl(ucell.nat,
784830
GlobalC::ld.des_per_atom,

source/module_hamilt_lcao/module_deepks/CMakeLists.txt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ if(ENABLE_DEEPKS)
33
LCAO_deepks.cpp
44
deepks_descriptor.cpp
55
deepks_force.cpp
6+
deepks_fpre.cpp
7+
deepks_spre.cpp
68
deepks_orbital.cpp
79
deepks_orbpre.cpp
810
deepks_vdpre.cpp
@@ -13,10 +15,7 @@ if(ENABLE_DEEPKS)
1315
LCAO_deepks_torch.cpp
1416
LCAO_deepks_vdelta.cpp
1517
LCAO_deepks_interface.cpp
16-
cal_gdmx.cpp
17-
cal_gdmepsl.cpp
18-
cal_gedm.cpp
19-
cal_gvx.cpp
18+
cal_gedm.cpp
2019
)
2120

2221
add_library(

source/module_hamilt_lcao/module_deepks/LCAO_deepks.h

Lines changed: 6 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55

66
#include "deepks_descriptor.h"
77
#include "deepks_force.h"
8+
#include "deepks_fpre.h"
89
#include "deepks_hmat.h"
910
#include "deepks_orbital.h"
1011
#include "deepks_orbpre.h"
12+
#include "deepks_spre.h"
1113
#include "deepks_vdpre.h"
1214
#include "module_base/complexmatrix.h"
1315
#include "module_base/intarray.h"
@@ -122,9 +124,9 @@ class LCAO_Deepks
122124
// \sum_L{Nchi(L)*(2L+1)}
123125
int des_per_atom;
124126

125-
ModuleBase::IntArray* alpha_index;
126-
ModuleBase::IntArray* inl_index; // caoyu add 2021-05-07
127-
int* inl_l; // inl_l[inl_index] = l of descriptor with inl_index
127+
ModuleBase::IntArray* alpha_index; // seems not used in the code
128+
ModuleBase::IntArray* inl_index; // caoyu add 2021-05-07
129+
int* inl_l; // inl_l[inl_index] = l of descriptor with inl_index
128130

129131
// HR status,
130132
// true : HR should be calculated
@@ -212,13 +214,10 @@ class LCAO_Deepks
212214
// It also contains subroutines for printing pdm and gdmx
213215
// for checking purpose
214216

215-
// There are 4 subroutines in this file:
217+
// There are 2 subroutines in this file:
216218
// 1. cal_projected_DM, which is used for calculating pdm
217219
// 2. check_projected_dm, which prints pdm to descriptor.dat
218220

219-
// 3. cal_gdmx, calculating gdmx (and optionally gdmepsl for stress)
220-
// 4. check_gdmx, which prints gdmx to a series of .dat files
221-
222221
public:
223222
/**
224223
* @brief calculate projected density matrix:
@@ -237,34 +236,6 @@ class LCAO_Deepks
237236

238237
void check_projected_dm();
239238

240-
// calculate the gradient of pdm with regard to atomic positions
241-
// d/dX D_{Inl,mm'}
242-
template <typename TK>
243-
void cal_gdmx( // const ModuleBase::matrix& dm,
244-
const std::vector<std::vector<TK>>& dm,
245-
const UnitCell& ucell,
246-
const LCAO_Orbitals& orb,
247-
const Grid_Driver& GridD,
248-
const int nks,
249-
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
250-
std::vector<hamilt::HContainer<double>*> phialpha,
251-
torch::Tensor& gdmx);
252-
253-
void check_gdmx(const int nat, const torch::Tensor& gdmx);
254-
255-
template <typename TK>
256-
void cal_gdmepsl( // const ModuleBase::matrix& dm,
257-
const std::vector<std::vector<TK>>& dm,
258-
const UnitCell& ucell,
259-
const LCAO_Orbitals& orb,
260-
const Grid_Driver& GridD,
261-
const int nks,
262-
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
263-
std::vector<hamilt::HContainer<double>*> phialpha,
264-
torch::Tensor& gdmepsl);
265-
266-
void check_gdmepsl(const torch::Tensor& gdmepsl);
267-
268239
/**
269240
* @brief set init_pdm to skip the calculation of pdm in SCF iteration
270241
*/
@@ -310,14 +281,6 @@ class LCAO_Deepks
310281
// as well as subroutines that prints the results for checking
311282

312283
// The file contains 8 subroutines:
313-
// 3. cal_gvx : gvx is used for training with force label, which is gradient of descriptors,
314-
// calculated by d(des)/dX = d(pdm)/dX * d(des)/d(pdm) = gdmx * gvdm
315-
// using einsum
316-
// 4. check_gvx : prints gvx into gvx.dat for checking
317-
// 5. cal_gvepsl : gvepsl is used for training with stress label, which is derivative of
318-
// descriptors wrt strain tensor, calculated by
319-
// d(des)/d\epsilon_{ab} = d(pdm)/d\epsilon_{ab} * d(des)/d(pdm) = gdmepsl * gvdm
320-
// using einsum
321284
// 6. cal_gevdm : d(des)/d(pdm)
322285
// calculated using torch::autograd::grad
323286
// 7. load_model : loads model for applying V_delta
@@ -327,24 +290,6 @@ class LCAO_Deepks
327290
// 9. check_gedm : prints gedm for checking
328291

329292
public:
330-
/// calculates gradient of descriptors w.r.t atomic positions
331-
///----------------------------------------------------
332-
/// m, n: 2*l+1
333-
/// v: eigenvalues of dm , 2*l+1
334-
/// a,b: natom
335-
/// - (a: the center of descriptor orbitals
336-
/// - b: the atoms whose force being calculated)
337-
/// gvdm*gdmx->gvx
338-
///----------------------------------------------------
339-
void cal_gvx(const int nat, const std::vector<torch::Tensor>& gevdm, const torch::Tensor& gdmx, torch::Tensor& gvx);
340-
void check_gvx(const int nat, const torch::Tensor& gvx);
341-
342-
// for stress
343-
void cal_gvepsl(const int nat,
344-
const std::vector<torch::Tensor>& gevdm,
345-
const torch::Tensor& gdmepsl,
346-
torch::Tensor& gvepsl);
347-
348293
// load the trained neural network model
349294
void load_model(const std::string& model_file);
350295

source/module_hamilt_lcao/module_deepks/LCAO_deepks_torch.cpp

Lines changed: 1 addition & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,7 @@
33
// as well as subroutines that prints the results for checking
44

55
// The file contains 3 subroutines:
6-
// cal_gvepsl : gvepsl is used for training with stress label, which is derivative of
7-
// descriptors wrt strain tensor, calculated by
8-
// d(des)/d\epsilon_{ab} = d(pdm)/d\epsilon_{ab} * d(des)/d(pdm) = gdmepsl * gvdm
9-
// using einsum
6+
107
// cal_gevdm : d(des)/d(pdm)
118
// calculated using torch::autograd::grad
129
// load_model : loads model for applying V_delta
@@ -22,72 +19,6 @@
2219
#include "module_hamilt_lcao/module_hcontainer/atom_pair.h"
2320
#include "module_parameter/parameter.h"
2421

25-
// calculates stress of descriptors from gradient of projected density matrices
26-
// gv_epsl:d(d)/d\epsilon_{\alpha\beta}, [natom][6][des_per_atom]
27-
void LCAO_Deepks::cal_gvepsl(const int nat,
28-
const std::vector<torch::Tensor>& gevdm,
29-
const torch::Tensor& gdmepsl,
30-
torch::Tensor& gvepsl)
31-
{
32-
ModuleBase::TITLE("LCAO_Deepks", "cal_gvepsl");
33-
// dD/d\epsilon_{\alpha\beta}, tensor vector form of gdmepsl
34-
std::vector<torch::Tensor> gdmepsl_vector;
35-
auto accessor = gdmepsl.accessor<double, 4>();
36-
if (GlobalV::MY_RANK == 0)
37-
{
38-
// make gdmx as tensor
39-
int nlmax = this->inlmax / nat;
40-
for (int nl = 0; nl < nlmax; ++nl)
41-
{
42-
std::vector<torch::Tensor> bmmv;
43-
for (int i = 0; i < 6; ++i)
44-
{
45-
std::vector<torch::Tensor> ammv;
46-
for (int iat = 0; iat < nat; ++iat)
47-
{
48-
int inl = iat * nlmax + nl;
49-
int nm = 2 * this->inl_l[inl] + 1;
50-
std::vector<double> mmv;
51-
for (int m1 = 0; m1 < nm; ++m1)
52-
{
53-
for (int m2 = 0; m2 < nm; ++m2)
54-
{
55-
mmv.push_back(accessor[i][inl][m1][m2]);
56-
}
57-
} // nm^2
58-
torch::Tensor mm
59-
= torch::tensor(mmv, torch::TensorOptions().dtype(torch::kFloat64)).reshape({nm, nm}); // nm*nm
60-
ammv.push_back(mm);
61-
}
62-
torch::Tensor bmm = torch::stack(ammv, 0); // nat*nm*nm
63-
bmmv.push_back(bmm);
64-
}
65-
gdmepsl_vector.push_back(torch::stack(bmmv, 0)); // nbt*3*nat*nm*nm
66-
}
67-
assert(gdmepsl_vector.size() == nlmax);
68-
69-
// einsum for each inl:
70-
// gdmepsl_vector : b:npol * a:inl(projector) * m:nm * n:nm
71-
// gevdm : a:inl * v:nm (descriptor) * m:nm (pdm, dim1) * n:nm
72-
// (pdm, dim2) gvepsl_vector : b:npol * a:inl(projector) *
73-
// m:nm(descriptor)
74-
std::vector<torch::Tensor> gvepsl_vector;
75-
for (int nl = 0; nl < nlmax; ++nl)
76-
{
77-
gvepsl_vector.push_back(at::einsum("bamn, avmn->bav", {gdmepsl_vector[nl], gevdm[nl]}));
78-
}
79-
80-
// cat nv-> \sum_nl(nv) = \sum_nl(nm_nl)=des_per_atom
81-
// concatenate index a(inl) and m(nm)
82-
gvepsl = torch::cat(gvepsl_vector, -1);
83-
assert(gvepsl.size(0) == 6);
84-
assert(gvepsl.size(1) == nat);
85-
assert(gvepsl.size(2) == this->des_per_atom);
86-
}
87-
88-
return;
89-
}
90-
9122
// d(Descriptor) / d(projected density matrix)
9223
// Dimension is different for each inl, so there's a vector of tensors
9324
void LCAO_Deepks::cal_gevdm(const int nat, std::vector<torch::Tensor>& gevdm)

0 commit comments

Comments
 (0)