Skip to content

Commit d7b76fc

Browse files
authored
Refactor: Remove some redundant variables and global dependence in DeePKS. (#5791)
* Remove cal_orbital_precalc() and corresponding temporary variables from LCAO_Deepks; Partially change cal_gevdm() for future refactor; Remove the double pointer pdm and use tensor vector for replacement. * clang-format adjustment.
1 parent cd48c62 commit d7b76fc

17 files changed

+749
-763
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ OBJS_DEEPKS=LCAO_deepks.o\
200200
LCAO_deepks_vdelta.o\
201201
deepks_hmat.o\
202202
LCAO_deepks_interface.o\
203-
orbital_precalc.o\
203+
deepks_orbpre.o\
204204
cal_gdmx.o\
205205
cal_gedm.o\
206206
cal_gvx.o\

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
// new
88
#include "module_base/timer.h"
99
#include "module_cell/module_neighbor/sltk_grid_driver.h"
10+
#include "module_elecstate/elecstate_lcao.h"
1011
#include "module_elecstate/potentials/efield.h" // liuyu add 2022-05-18
1112
#include "module_elecstate/potentials/gatefield.h" // liuyu add 2022-09-13
1213
#include "module_hamilt_general/module_surchem/surchem.h" //sunml add 2022-08-10
1314
#include "module_hamilt_general/module_vdw/vdw.h"
1415
#include "module_parameter/parameter.h"
15-
#include "module_elecstate/elecstate_lcao.h"
1616
#ifdef __DEEPKS
1717
#include "module_hamilt_lcao/module_deepks/LCAO_deepks.h" //caoyu add for deepks 2021-06-03
1818
#include "module_hamilt_lcao/module_deepks/LCAO_deepks_io.h" // mohan add 2024-07-22
@@ -540,7 +540,9 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
540540
{
541541
GlobalC::ld.check_gdmx(ucell.nat);
542542
}
543-
GlobalC::ld.cal_gvx(ucell.nat);
543+
std::vector<torch::Tensor> gevdm;
544+
GlobalC::ld.cal_gevdm(ucell.nat, gevdm);
545+
GlobalC::ld.cal_gvx(ucell.nat, gevdm);
544546

545547
if (PARAM.inp.deepks_out_unittest)
546548
{
@@ -758,7 +760,9 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
758760

759761
if (!PARAM.inp.deepks_equiv) // training with stress label not supported by equivariant version now
760762
{
761-
GlobalC::ld.cal_gvepsl(ucell.nat);
763+
std::vector<torch::Tensor> gevdm;
764+
GlobalC::ld.cal_gevdm(ucell.nat, gevdm);
765+
GlobalC::ld.cal_gvepsl(ucell.nat, gevdm);
762766

763767
LCAO_deepks_io::save_npy_gvepsl(ucell.nat,
764768
GlobalC::ld.des_per_atom,

source/module_hamilt_lcao/module_deepks/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ if(ENABLE_DEEPKS)
1111
LCAO_deepks_vdelta.cpp
1212
deepks_hmat.cpp
1313
LCAO_deepks_interface.cpp
14-
orbital_precalc.cpp
14+
deepks_orbpre.cpp
1515
cal_gdmx.cpp
1616
cal_gedm.cpp
1717
cal_gvx.cpp

source/module_hamilt_lcao/module_deepks/LCAO_deepks.cpp

Lines changed: 26 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,7 @@ LCAO_Deepks::~LCAO_Deepks()
4545
delete[] inl_l;
4646

4747
//=======1. to use deepks, pdm is required==========
48-
// delete pdm**
49-
for (int inl = 0; inl < this->inlmax; inl++)
50-
{
51-
delete[] pdm[inl];
52-
}
53-
delete[] pdm;
48+
pdm.clear();
5449
//=======2. "deepks_scf" part==========
5550
// if (PARAM.inp.deepks_scf)
5651
if (gedm)
@@ -100,12 +95,33 @@ void LCAO_Deepks::init(const LCAO_Orbitals& orb,
10095

10196
int pdm_size = 0;
10297
this->inlmax = tot_inl;
98+
this->pdm.resize(this->inlmax);
99+
100+
// cal n(descriptor) per atom , related to Lmax, nchi(L) and m. (not total_nchi!)
101+
if (!PARAM.inp.deepks_equiv)
102+
{
103+
this->des_per_atom = 0; // mohan add 2021-04-21
104+
for (int l = 0; l <= this->lmaxd; l++)
105+
{
106+
this->des_per_atom += orb.Alpha[0].getNchi(l) * (2 * l + 1);
107+
}
108+
this->n_descriptor = nat * this->des_per_atom;
109+
110+
this->init_index(ntype, nat, na, tot_inl, orb);
111+
}
112+
103113
if (!PARAM.inp.deepks_equiv)
104114
{
105115
GlobalV::ofs_running << " total basis (all atoms) for descriptor = " << std::endl;
106116

107-
// init pdm**
108-
pdm_size = (this->lmaxd * 2 + 1) * (this->lmaxd * 2 + 1);
117+
// init pdm
118+
for (int inl = 0; inl < this->inlmax; inl++)
119+
{
120+
int nm = 2 * inl_l[inl] + 1;
121+
pdm_size += nm * nm;
122+
this->pdm[inl] = torch::zeros({nm, nm}, torch::kFloat64);
123+
// this->pdm[inl].requires_grad_(true);
124+
}
109125
}
110126
else
111127
{
@@ -116,26 +132,10 @@ void LCAO_Deepks::init(const LCAO_Orbitals& orb,
116132
pdm_size = pdm_size * pdm_size;
117133
this->des_per_atom = pdm_size;
118134
GlobalV::ofs_running << " Equivariant version, size of pdm matrices : " << pdm_size << std::endl;
119-
}
120-
121-
this->pdm = new double*[this->inlmax];
122-
for (int inl = 0; inl < this->inlmax; inl++)
123-
{
124-
this->pdm[inl] = new double[pdm_size];
125-
ModuleBase::GlobalFunc::ZEROS(this->pdm[inl], pdm_size);
126-
}
127-
128-
// cal n(descriptor) per atom , related to Lmax, nchi(L) and m. (not total_nchi!)
129-
if (!PARAM.inp.deepks_equiv)
130-
{
131-
this->des_per_atom = 0; // mohan add 2021-04-21
132-
for (int l = 0; l <= this->lmaxd; l++)
135+
for (int inl = 0; inl < this->inlmax; inl++)
133136
{
134-
this->des_per_atom += orb.Alpha[0].getNchi(l) * (2 * l + 1);
137+
this->pdm[inl] = torch::zeros({pdm_size}, torch::kFloat64);
135138
}
136-
this->n_descriptor = nat * this->des_per_atom;
137-
138-
this->init_index(ntype, nat, na, tot_inl, orb);
139139
}
140140

141141
this->pv = &pv_in;
@@ -343,41 +343,6 @@ void LCAO_Deepks::allocate_V_delta(const int nat, const int nks)
343343
return;
344344
}
345345

346-
void LCAO_Deepks::init_orbital_pdm_shell(const int nks)
347-
{
348-
349-
this->orbital_pdm_shell = new double**[nks];
350-
351-
for (int iks = 0; iks < nks; iks++)
352-
{
353-
this->orbital_pdm_shell[iks] = new double*[this->inlmax];
354-
for (int inl = 0; inl < this->inlmax; inl++)
355-
{
356-
this->orbital_pdm_shell[iks][inl] = new double[(2 * this->lmaxd + 1) * (2 * this->lmaxd + 1)];
357-
ModuleBase::GlobalFunc::ZEROS(orbital_pdm_shell[iks][inl],
358-
(2 * this->lmaxd + 1) * (2 * this->lmaxd + 1));
359-
}
360-
361-
}
362-
363-
return;
364-
}
365-
366-
void LCAO_Deepks::del_orbital_pdm_shell(const int nks)
367-
{
368-
for (int iks = 0; iks < nks; iks++)
369-
{
370-
for (int inl = 0; inl < this->inlmax; inl++)
371-
{
372-
delete[] this->orbital_pdm_shell[iks][inl];
373-
}
374-
delete[] this->orbital_pdm_shell[iks];
375-
}
376-
delete[] this->orbital_pdm_shell;
377-
378-
return;
379-
}
380-
381346
void LCAO_Deepks::init_v_delta_pdm_shell(const int nks, const int nlocal)
382347
{
383348
const int mn_size = (2 * this->lmaxd + 1) * (2 * this->lmaxd + 1);

source/module_hamilt_lcao/module_deepks/LCAO_deepks.h

Lines changed: 26 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "deepks_force.h"
77
#include "deepks_hmat.h"
88
#include "deepks_orbital.h"
9+
#include "deepks_orbpre.h"
910
#include "module_base/complexmatrix.h"
1011
#include "module_base/intarray.h"
1112
#include "module_base/matrix.h"
@@ -57,12 +58,6 @@ class LCAO_Deepks
5758
/// Correction term to Hamiltonian, for multi-k
5859
std::vector<std::vector<std::complex<double>>> H_V_delta_k;
5960

60-
// k index of HOMO for multi-k bandgap label. QO added 2022-01-24
61-
int h_ind = 0;
62-
63-
// k index of LUMO for multi-k bandgap label. QO added 2022-01-24
64-
int l_ind = 0;
65-
6661
// functions for hr status: 1. get value; 2. set value;
6762
int get_hr_cal()
6863
{
@@ -109,8 +104,9 @@ class LCAO_Deepks
109104
std::vector<hamilt::HContainer<double>*> phialpha;
110105

111106
// projected density matrix
112-
double** pdm; //[tot_Inl][2l+1][2l+1] caoyu modified 2021-05-07; if equivariant version: [nat][nlm*nlm]
113-
std::vector<torch::Tensor> pdm_tensor;
107+
// [tot_Inl][2l+1][2l+1], here l is corresponding to inl;
108+
// [nat][nlm*nlm] for equivariant version
109+
std::vector<torch::Tensor> pdm;
114110

115111
// descriptors
116112
std::vector<torch::Tensor> d_tensor;
@@ -138,17 +134,9 @@ class LCAO_Deepks
138134
// gvx:d(d)/dX, [natom][3][natom][des_per_atom]
139135
torch::Tensor gvx_tensor;
140136

141-
// d(d)/dD, autograd from torch::linalg::eigh
142-
std::vector<torch::Tensor> gevdm_vector;
143-
144137
// dD/dX, tensor form of gdmx
145138
std::vector<torch::Tensor> gdmr_vector;
146139

147-
// orbital_pdm_shell:[Inl,nm*nm]; \langle \phi_\mu|\alpha\rangle\langle\alpha|\phi_\nu\ranlge
148-
double*** orbital_pdm_shell;
149-
// orbital_precalc:[1,NAt,NDscrpt]; gvdm*orbital_pdm_shell
150-
torch::Tensor orbital_precalc_tensor;
151-
152140
// v_delta_pdm_shell[nks,nlocal,nlocal,Inl,nm*nm] = overlap * overlap
153141
double***** v_delta_pdm_shell;
154142
std::complex<double>***** v_delta_pdm_shell_complex; // for multi-k
@@ -223,12 +211,6 @@ class LCAO_Deepks
223211
private:
224212
// arrange index of descriptor in all atoms
225213
void init_index(const int ntype, const int nat, std::vector<int> na, const int tot_inl, const LCAO_Orbitals& orb);
226-
// data structure that saves <phi|alpha>
227-
void allocate_nlm(const int nat);
228-
229-
// for bandgap label calculation; QO added on 2022-1-7
230-
void init_orbital_pdm_shell(const int nks);
231-
void del_orbital_pdm_shell(const int nks);
232214

233215
// for v_delta label calculation; xinyuan added on 2023-2-22
234216
void init_v_delta_pdm_shell(const int nks, const int nlocal);
@@ -373,16 +355,16 @@ class LCAO_Deepks
373355
// descriptors wrt strain tensor, calculated by
374356
// d(des)/d\epsilon_{ab} = d(pdm)/d\epsilon_{ab} * d(des)/d(pdm) = gdm_epsl * gvdm
375357
// using einsum
376-
// 6. cal_gvdm : d(des)/d(pdm)
358+
// 6. cal_gevdm : d(des)/d(pdm)
377359
// calculated using torch::autograd::grad
378360
// 7. load_model : loads model for applying V_delta
379361
// 8. cal_gedm : calculates d(E_delta)/d(pdm)
380362
// this is the term V(D) that enters the expression H_V_delta = |alpha>V(D)<alpha|
381363
// caculated using torch::autograd::grad
382364
// 9. check_gedm : prints gedm for checking
383365
// 10. cal_orbital_precalc : orbital_precalc is usted for training with orbital label,
384-
// which equals gvdm * orbital_pdm_shell,
385-
// orbital_pdm_shell[Inl,nm*nm] = dm_hl * overlap * overlap
366+
// which equals gvdm * orbital_pdm,
367+
// orbital_pdm[nks,Inl,nm,nm] = dm_hl * overlap * overlap
386368
// 11. cal_v_delta_precalc : v_delta_precalc is used for training with v_delta label,
387369
// which equals gvdm * v_delta_pdm_shell,
388370
// v_delta_pdm_shell = overlap * overlap
@@ -408,11 +390,11 @@ class LCAO_Deepks
408390
/// - b: the atoms whose force being calculated)
409391
/// gvdm*gdmx->gvx
410392
///----------------------------------------------------
411-
void cal_gvx(const int nat);
393+
void cal_gvx(const int nat, const std::vector<torch::Tensor>& gevdm);
412394
void check_gvx(const int nat);
413395

414396
// for stress
415-
void cal_gvepsl(const int nat);
397+
void cal_gvepsl(const int nat, const std::vector<torch::Tensor>& gevdm);
416398

417399
// load the trained neural network model
418400
void load_model(const std::string& model_file);
@@ -423,20 +405,22 @@ class LCAO_Deepks
423405
void cal_gedm_equiv(const int nat);
424406

425407
// calculates orbital_precalc
426-
template <typename TK, typename TH>
427-
void cal_orbital_precalc(const std::vector<TH>& dm_hl,
428-
const int lmaxd,
429-
const int inlmax,
430-
const int nat,
431-
const int nks,
432-
const int* inl_l,
433-
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
434-
const std::vector<hamilt::HContainer<double>*> phialpha,
435-
const ModuleBase::IntArray* inl_index,
436-
const UnitCell& ucell,
437-
const LCAO_Orbitals& orb,
438-
const Parallel_Orbitals& pv,
439-
const Grid_Driver& GridD);
408+
// template <typename TK, typename TH>
409+
// void cal_orbital_precalc(const std::vector<TH>& dm_hl,
410+
// const int lmaxd,
411+
// const int inlmax,
412+
// const int nat,
413+
// const int nks,
414+
// const int* inl_l,
415+
// const std::vector<ModuleBase::Vector3<double>>& kvec_d,
416+
// const std::vector<hamilt::HContainer<double>*> phialpha,
417+
// const std::vector<torch::Tensor> gevdm,
418+
// const ModuleBase::IntArray* inl_index,
419+
// const UnitCell& ucell,
420+
// const LCAO_Orbitals& orb,
421+
// const Parallel_Orbitals& pv,
422+
// const Grid_Driver& GridD,
423+
// torch::Tensor& orbital_precalc);
440424

441425
// calculates v_delta_precalc
442426
template <typename TK>
@@ -466,11 +450,11 @@ class LCAO_Deepks
466450

467451
// prepare gevdm for outputting npy file
468452
void prepare_gevdm(const int nat, const LCAO_Orbitals& orb);
453+
void cal_gevdm(const int nat, std::vector<torch::Tensor>& gevdm);
469454
void check_vdp_gevdm(const int nat);
470455

471456
private:
472457
const Parallel_Orbitals* pv;
473-
void cal_gvdm(const int nat);
474458

475459
#ifdef __MPI
476460

source/module_hamilt_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,32 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
113113

114114
std::vector<double> o_delta(nks, 0.0);
115115

116-
ld->cal_orbital_precalc<TK, TH>(dm_bandgap, ld->lmaxd, ld->inlmax, nat, nks, ld->inl_l, kvec_d, ld->phialpha, ld->inl_index, ucell, orb, *ParaV, GridD);
116+
// calculate and save orbital_precalc: [nks,NAt,NDscrpt]
117+
torch::Tensor orbital_precalc;
118+
std::vector<torch::Tensor> gevdm;
119+
ld->cal_gevdm(nat, gevdm);
120+
DeePKS_domain::cal_orbital_precalc<TK, TH>(dm_bandgap,
121+
ld->lmaxd,
122+
ld->inlmax,
123+
nat,
124+
nks,
125+
ld->inl_l,
126+
kvec_d,
127+
ld->phialpha,
128+
gevdm,
129+
ld->inl_index,
130+
ucell,
131+
orb,
132+
*ParaV,
133+
GridD,
134+
orbital_precalc);
117135
DeePKS_domain::cal_o_delta<TK, TH>(dm_bandgap, *h_delta, o_delta, *ParaV, nks);
118136

119137
// save obase and orbital_precalc
120138
LCAO_deepks_io::save_npy_orbital_precalc(nat,
121139
nks,
122140
ld->des_per_atom,
123-
ld->orbital_precalc_tensor,
141+
orbital_precalc,
124142
PARAM.globalv.global_out_dir,
125143
my_rank);
126144
const std::string file_obase = PARAM.globalv.global_out_dir + "deepks_obase.npy";
@@ -135,8 +153,8 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
135153
{
136154
const std::string file_obase = PARAM.globalv.global_out_dir + "deepks_obase.npy";
137155
LCAO_deepks_io::save_npy_o(o_tot, file_obase, nks, my_rank); // no scf, o_tot=o_base
138-
} // end deepks_scf == 0
139-
} // end bandgap label
156+
} // end deepks_scf == 0
157+
} // end bandgap label
140158

141159
// save H(R) matrix
142160
if (true) // should be modified later!

0 commit comments

Comments
 (0)