66#include " deepks_force.h"
77#include " deepks_hmat.h"
88#include " deepks_orbital.h"
9+ #include " deepks_orbpre.h"
910#include " module_base/complexmatrix.h"
1011#include " module_base/intarray.h"
1112#include " module_base/matrix.h"
@@ -57,12 +58,6 @@ class LCAO_Deepks
5758 // / Correction term to Hamiltonian, for multi-k
5859 std::vector<std::vector<std::complex <double >>> H_V_delta_k;
5960
60- // k index of HOMO for multi-k bandgap label. QO added 2022-01-24
61- int h_ind = 0 ;
62-
63- // k index of LUMO for multi-k bandgap label. QO added 2022-01-24
64- int l_ind = 0 ;
65-
6661 // functions for hr status: 1. get value; 2. set value;
6762 int get_hr_cal ()
6863 {
@@ -109,8 +104,9 @@ class LCAO_Deepks
109104 std::vector<hamilt::HContainer<double >*> phialpha;
110105
111106 // projected density matrix
112- double ** pdm; // [tot_Inl][2l+1][2l+1] caoyu modified 2021-05-07; if equivariant version: [nat][nlm*nlm]
113- std::vector<torch::Tensor> pdm_tensor;
107+ // [tot_Inl][2l+1][2l+1], here l is corresponding to inl;
108+ // [nat][nlm*nlm] for equivariant version
109+ std::vector<torch::Tensor> pdm;
114110
115111 // descriptors
116112 std::vector<torch::Tensor> d_tensor;
@@ -138,17 +134,9 @@ class LCAO_Deepks
138134 // gvx:d(d)/dX, [natom][3][natom][des_per_atom]
139135 torch::Tensor gvx_tensor;
140136
141- // d(d)/dD, autograd from torch::linalg::eigh
142- std::vector<torch::Tensor> gevdm_vector;
143-
144137 // dD/dX, tensor form of gdmx
145138 std::vector<torch::Tensor> gdmr_vector;
146139
147- // orbital_pdm_shell:[Inl,nm*nm]; \langle \phi_\mu|\alpha\rangle\langle\alpha|\phi_\nu\ranlge
148- double *** orbital_pdm_shell;
149- // orbital_precalc:[1,NAt,NDscrpt]; gvdm*orbital_pdm_shell
150- torch::Tensor orbital_precalc_tensor;
151-
152140 // v_delta_pdm_shell[nks,nlocal,nlocal,Inl,nm*nm] = overlap * overlap
153141 double ***** v_delta_pdm_shell;
154142 std::complex <double >***** v_delta_pdm_shell_complex; // for multi-k
@@ -223,12 +211,6 @@ class LCAO_Deepks
223211 private:
224212 // arrange index of descriptor in all atoms
225213 void init_index (const int ntype, const int nat, std::vector<int > na, const int tot_inl, const LCAO_Orbitals& orb);
226- // data structure that saves <phi|alpha>
227- void allocate_nlm (const int nat);
228-
229- // for bandgap label calculation; QO added on 2022-1-7
230- void init_orbital_pdm_shell (const int nks);
231- void del_orbital_pdm_shell (const int nks);
232214
233215 // for v_delta label calculation; xinyuan added on 2023-2-22
234216 void init_v_delta_pdm_shell (const int nks, const int nlocal);
@@ -373,16 +355,16 @@ class LCAO_Deepks
373355 // descriptors wrt strain tensor, calculated by
374356 // d(des)/d\epsilon_{ab} = d(pdm)/d\epsilon_{ab} * d(des)/d(pdm) = gdm_epsl * gvdm
375357 // using einsum
376- // 6. cal_gvdm : d(des)/d(pdm)
358+ // 6. cal_gevdm : d(des)/d(pdm)
377359 // calculated using torch::autograd::grad
378360 // 7. load_model : loads model for applying V_delta
379361 // 8. cal_gedm : calculates d(E_delta)/d(pdm)
380362 // this is the term V(D) that enters the expression H_V_delta = |alpha>V(D)<alpha|
381363 // caculated using torch::autograd::grad
382364 // 9. check_gedm : prints gedm for checking
383365 // 10. cal_orbital_precalc : orbital_precalc is usted for training with orbital label,
384- // which equals gvdm * orbital_pdm_shell ,
385- // orbital_pdm_shell[ Inl,nm* nm] = dm_hl * overlap * overlap
366+ // which equals gvdm * orbital_pdm ,
367+ // orbital_pdm[nks, Inl,nm, nm] = dm_hl * overlap * overlap
386368 // 11. cal_v_delta_precalc : v_delta_precalc is used for training with v_delta label,
387369 // which equals gvdm * v_delta_pdm_shell,
388370 // v_delta_pdm_shell = overlap * overlap
@@ -408,11 +390,11 @@ class LCAO_Deepks
408390 // / - b: the atoms whose force being calculated)
409391 // / gvdm*gdmx->gvx
410392 // /----------------------------------------------------
411- void cal_gvx (const int nat);
393+ void cal_gvx (const int nat, const std::vector<torch::Tensor>& gevdm );
412394 void check_gvx (const int nat);
413395
414396 // for stress
415- void cal_gvepsl (const int nat);
397+ void cal_gvepsl (const int nat, const std::vector<torch::Tensor>& gevdm );
416398
417399 // load the trained neural network model
418400 void load_model (const std::string& model_file);
@@ -423,20 +405,22 @@ class LCAO_Deepks
423405 void cal_gedm_equiv (const int nat);
424406
425407 // calculates orbital_precalc
426- template <typename TK, typename TH>
427- void cal_orbital_precalc (const std::vector<TH>& dm_hl,
428- const int lmaxd,
429- const int inlmax,
430- const int nat,
431- const int nks,
432- const int * inl_l,
433- const std::vector<ModuleBase::Vector3<double >>& kvec_d,
434- const std::vector<hamilt::HContainer<double >*> phialpha,
435- const ModuleBase::IntArray* inl_index,
436- const UnitCell& ucell,
437- const LCAO_Orbitals& orb,
438- const Parallel_Orbitals& pv,
439- const Grid_Driver& GridD);
408+ // template <typename TK, typename TH>
409+ // void cal_orbital_precalc(const std::vector<TH>& dm_hl,
410+ // const int lmaxd,
411+ // const int inlmax,
412+ // const int nat,
413+ // const int nks,
414+ // const int* inl_l,
415+ // const std::vector<ModuleBase::Vector3<double>>& kvec_d,
416+ // const std::vector<hamilt::HContainer<double>*> phialpha,
417+ // const std::vector<torch::Tensor> gevdm,
418+ // const ModuleBase::IntArray* inl_index,
419+ // const UnitCell& ucell,
420+ // const LCAO_Orbitals& orb,
421+ // const Parallel_Orbitals& pv,
422+ // const Grid_Driver& GridD,
423+ // torch::Tensor& orbital_precalc);
440424
441425 // calculates v_delta_precalc
442426 template <typename TK>
@@ -466,11 +450,11 @@ class LCAO_Deepks
466450
467451 // prepare gevdm for outputting npy file
468452 void prepare_gevdm (const int nat, const LCAO_Orbitals& orb);
453+ void cal_gevdm (const int nat, std::vector<torch::Tensor>& gevdm);
469454 void check_vdp_gevdm (const int nat);
470455
471456 private:
472457 const Parallel_Orbitals* pv;
473- void cal_gvdm (const int nat);
474458
475459#ifdef __MPI
476460
0 commit comments