11// / 1. cal_descriptor : obtains descriptors which are eigenvalues of pdm
22// / by calling torch::linalg::eigh
33// / 2. check_descriptor : prints descriptor for checking
4+ // / 3. cal_descriptor_equiv : calculates descriptor in equivalent version
45
56#ifdef __DEEPKS
67
7- #include " LCAO_deepks.h"
8+ #include " deepks_descriptor.h"
9+
810#include " LCAO_deepks_io.h" // mohan add 2024-07-22
911#include " module_base/blas_connector.h"
1012#include " module_base/constants.h"
1315#include " module_hamilt_lcao/module_hcontainer/atom_pair.h"
1416#include " module_parameter/parameter.h"
1517
16- void LCAO_Deepks::cal_descriptor_equiv (const int nat, std::vector<torch::Tensor>& descriptor)
18+ void DeePKS_domain::cal_descriptor_equiv (const int nat,
19+ const int des_per_atom,
20+ const std::vector<torch::Tensor>& pdm,
21+ std::vector<torch::Tensor>& descriptor)
1722{
18- ModuleBase::TITLE (" LCAO_Deepks " , " cal_descriptor_equiv" );
19- ModuleBase::timer::tick (" LCAO_Deepks " , " cal_descriptor_equiv" );
23+ ModuleBase::TITLE (" DeePKS_domain " , " cal_descriptor_equiv" );
24+ ModuleBase::timer::tick (" DeePKS_domain " , " cal_descriptor_equiv" );
2025
26+ assert (des_per_atom > 0 );
2127 for (int iat = 0 ; iat < nat; iat++)
2228 {
2329 auto tmp = torch::zeros (des_per_atom, torch::kFloat64 );
24- std::memcpy (tmp.data_ptr (), this -> pdm [iat].data_ptr <double >(), sizeof (double ) * tmp.numel ());
30+ std::memcpy (tmp.data_ptr (), pdm[iat].data_ptr <double >(), sizeof (double ) * tmp.numel ());
2531 descriptor.push_back (tmp);
2632 }
2733
28- ModuleBase::timer::tick (" LCAO_Deepks " , " cal_descriptor_equiv" );
34+ ModuleBase::timer::tick (" DeePKS_domain " , " cal_descriptor_equiv" );
2935}
3036
3137// calculates descriptors from projected density matrices
32- void LCAO_Deepks::cal_descriptor (const int nat, std::vector<torch::Tensor>& descriptor)
38+ void DeePKS_domain::cal_descriptor (const int nat,
39+ const int inlmax,
40+ const int * inl_l,
41+ const std::vector<torch::Tensor>& pdm,
42+ std::vector<torch::Tensor>& descriptor,
43+ const int des_per_atom = -1 )
3344{
34- ModuleBase::TITLE (" LCAO_Deepks" , " cal_descriptor" );
35- ModuleBase::timer::tick (" LCAO_Deepks" , " cal_descriptor" );
36-
37- // init descriptor
38- // if descriptor is not empty, clear it !!
39- if (!descriptor.empty ())
40- {
41- descriptor.erase (descriptor.begin (), descriptor.end ());
42- }
45+ ModuleBase::TITLE (" DeePKS_domain" , " cal_descriptor" );
46+ ModuleBase::timer::tick (" DeePKS_domain" , " cal_descriptor" );
4347
4448 if (PARAM.inp .deepks_equiv )
4549 {
46- this -> cal_descriptor_equiv (nat, descriptor);
50+ DeePKS_domain:: cal_descriptor_equiv (nat, des_per_atom, pdm , descriptor);
4751 return ;
4852 }
4953
50- for (int inl = 0 ; inl < this -> inlmax ; ++inl)
54+ for (int inl = 0 ; inl < inlmax; ++inl)
5155 {
5256 const int nm = 2 * inl_l[inl] + 1 ;
53- this -> pdm [inl].requires_grad_ (true );
57+ pdm[inl].requires_grad_ (true );
5458 descriptor.push_back (torch::ones ({nm}, torch::requires_grad (true )));
5559 }
5660
@@ -64,15 +68,18 @@ void LCAO_Deepks::cal_descriptor(const int nat, std::vector<torch::Tensor>& desc
6468 d_v = torch::linalg::eigh (pdm[inl], /* uplo*/ " U" );
6569 descriptor[inl] = std::get<0 >(d_v);
6670 }
67- ModuleBase::timer::tick (" LCAO_Deepks " , " cal_descriptor" );
71+ ModuleBase::timer::tick (" DeePKS_domain " , " cal_descriptor" );
6872 return ;
6973}
7074
71- void LCAO_Deepks::check_descriptor (const UnitCell& ucell,
72- const std::string& out_dir,
73- const std::vector<torch::Tensor>& descriptor)
75+ void DeePKS_domain::check_descriptor (const int inlmax,
76+ const int des_per_atom,
77+ const int * inl_l,
78+ const UnitCell& ucell,
79+ const std::string& out_dir,
80+ const std::vector<torch::Tensor>& descriptor)
7481{
75- ModuleBase::TITLE (" LCAO_Deepks " , " check_descriptor" );
82+ ModuleBase::TITLE (" DeePKS_domain " , " check_descriptor" );
7683
7784 if (GlobalV::MY_RANK != 0 )
7885 {
@@ -91,7 +98,7 @@ void LCAO_Deepks::check_descriptor(const UnitCell& ucell,
9198 for (int ia = 0 ; ia < ucell.atoms [it].na ; ia++)
9299 {
93100 int iat = ucell.itia2iat (it, ia);
94- ofs << ucell.atoms [it].label << " atom_index " << ia + 1 << " n_descriptor " << this -> des_per_atom
101+ ofs << ucell.atoms [it].label << " atom_index " << ia + 1 << " n_descriptor " << des_per_atom
95102 << std::endl;
96103 int id = 0 ;
97104 for (int inl = 0 ; inl < inlmax / ucell.nat ; inl++)
@@ -118,10 +125,9 @@ void LCAO_Deepks::check_descriptor(const UnitCell& ucell,
118125 for (int iat = 0 ; iat < ucell.nat ; iat++)
119126 {
120127 const int it = ucell.iat2it [iat];
121- ofs << ucell.atoms [it].label << " atom_index " << iat + 1 << " n_descriptor " << this ->des_per_atom
122- << std::endl;
128+ ofs << ucell.atoms [it].label << " atom_index " << iat + 1 << " n_descriptor " << des_per_atom << std::endl;
123129 auto accessor = descriptor[iat].accessor <double , 1 >();
124- for (int i = 0 ; i < this -> des_per_atom ; i++)
130+ for (int i = 0 ; i < des_per_atom; i++)
125131 {
126132 ofs << accessor[i] << " " ;
127133 if (i % 8 == 7 )
0 commit comments