Skip to content

Commit 865f3da

Browse files
committed
Move functions for calculating descriptor from LCAO_deepks to DeePKS_domain.
1 parent e1c9ae3 commit 865f3da

File tree

10 files changed

+136
-52
lines changed

10 files changed

+136
-52
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ OBJS_CELL=atom_pseudo.o\
192192

193193
OBJS_DEEPKS=LCAO_deepks.o\
194194
deepks_force.o\
195+
deepks_descriptor.o\
195196
deepks_orbital.o\
196197
deepks_orbpre.o\
197198
deepks_vdpre.o\
@@ -206,7 +207,6 @@ OBJS_DEEPKS=LCAO_deepks.o\
206207
cal_gdmepsl.o\
207208
cal_gedm.o\
208209
cal_gvx.o\
209-
cal_descriptor.o\
210210

211211

212212
OBJS_ELECSTAT=elecstate.o\

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_gamma.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,12 @@ void Force_LCAO<double>::ftable(const bool isforce,
254254
// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
255255
// GlobalC::ld.cal_projected_DM(dm, ucell, orb, gd);
256256

257-
GlobalC::ld.cal_descriptor(ucell.nat, descriptor);
257+
DeePKS_domain::cal_descriptor(ucell.nat,
258+
GlobalC::ld.inlmax,
259+
GlobalC::ld.inl_l,
260+
GlobalC::ld.pdm,
261+
descriptor,
262+
GlobalC::ld.des_per_atom);
258263
GlobalC::ld.cal_gedm(ucell.nat, descriptor);
259264

260265
const int nks = 1;
@@ -306,7 +311,12 @@ void Force_LCAO<double>::ftable(const bool isforce,
306311

307312
GlobalC::ld.check_projected_dm();
308313

309-
GlobalC::ld.check_descriptor(ucell, PARAM.globalv.global_out_dir, descriptor);
314+
DeePKS_domain::check_descriptor(GlobalC::ld.inlmax,
315+
GlobalC::ld.des_per_atom,
316+
GlobalC::ld.inl_l,
317+
ucell,
318+
PARAM.globalv.global_out_dir,
319+
descriptor);
310320

311321
GlobalC::ld.check_gedm();
312322

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_k.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,12 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
350350
// GlobalC::ld.cal_projected_DM(dm, ucell, orb, gd);
351351

352352
std::vector<torch::Tensor> descriptor;
353-
GlobalC::ld.cal_descriptor(ucell.nat, descriptor);
353+
DeePKS_domain::cal_descriptor(ucell.nat,
354+
GlobalC::ld.inlmax,
355+
GlobalC::ld.inl_l,
356+
GlobalC::ld.pdm,
357+
descriptor,
358+
GlobalC::ld.des_per_atom);
354359
GlobalC::ld.cal_gedm(ucell.nat, descriptor);
355360

356361
DeePKS_domain::cal_f_delta<std::complex<double>>(dm_k,

source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/deepks_lcao.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,12 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::contributeHR()
163163
GlobalC::ld.cal_projected_DM<TK>(this->DM, *this->ucell, *ptr_orb_, *(this->gd));
164164

165165
std::vector<torch::Tensor> descriptor;
166-
GlobalC::ld.cal_descriptor(this->ucell->nat, descriptor);
166+
DeePKS_domain::cal_descriptor(this->ucell->nat,
167+
GlobalC::ld.inlmax,
168+
GlobalC::ld.inl_l,
169+
GlobalC::ld.pdm,
170+
descriptor,
171+
GlobalC::ld.des_per_atom);
167172
GlobalC::ld.cal_gedm(this->ucell->nat, descriptor);
168173

169174
// // recalculate the H_V_delta

source/module_hamilt_lcao/module_deepks/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
if(ENABLE_DEEPKS)
22
list(APPEND objects
33
LCAO_deepks.cpp
4+
deepks_descriptor.cpp
45
deepks_force.cpp
56
deepks_orbital.cpp
67
deepks_orbpre.cpp
@@ -16,8 +17,6 @@ if(ENABLE_DEEPKS)
1617
cal_gdmepsl.cpp
1718
cal_gedm.cpp
1819
cal_gvx.cpp
19-
cal_descriptor.cpp
20-
2120
)
2221

2322
add_library(

source/module_hamilt_lcao/module_deepks/LCAO_deepks.h

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#ifdef __DEEPKS
55

6+
#include "deepks_descriptor.h"
67
#include "deepks_force.h"
78
#include "deepks_hmat.h"
89
#include "deepks_orbital.h"
@@ -309,9 +310,6 @@ class LCAO_Deepks
309310
// as well as subroutines that prints the results for checking
310311

311312
// The file contains 8 subroutines:
312-
// 1. cal_descriptor : obtains descriptors which are eigenvalues of pdm
313-
// by calling torch::linalg::eigh
314-
// 2. check_descriptor : prints descriptor for checking
315313
// 3. cal_gvx : gvx is used for training with force label, which is gradient of descriptors,
316314
// calculated by d(des)/dX = d(pdm)/dX * d(des)/d(pdm) = gdmx * gvdm
317315
// using einsum
@@ -329,16 +327,6 @@ class LCAO_Deepks
329327
// 9. check_gedm : prints gedm for checking
330328

331329
public:
332-
/// Calculates descriptors
333-
/// which are eigenvalues of pdm in blocks of I_n_l
334-
void cal_descriptor(const int nat, std::vector<torch::Tensor>& descriptor);
335-
/// print descriptors based on LCAO basis
336-
void check_descriptor(const UnitCell& ucell,
337-
const std::string& out_dir,
338-
const std::vector<torch::Tensor>& descriptor);
339-
340-
void cal_descriptor_equiv(const int nat, std::vector<torch::Tensor>& descriptor);
341-
342330
/// calculates gradient of descriptors w.r.t atomic positions
343331
///----------------------------------------------------
344332
/// m, n: 2*l+1

source/module_hamilt_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,18 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
298298
ld->check_projected_dm(); // print out the projected dm for NSCF calculaiton
299299

300300
std::vector<torch::Tensor> descriptor;
301-
ld->cal_descriptor(nat, descriptor); // final descriptor
302-
ld->check_descriptor(ucell, PARAM.globalv.global_out_dir, descriptor);
301+
DeePKS_domain::cal_descriptor(nat,
302+
ld->inlmax,
303+
ld->inl_l,
304+
ld->pdm,
305+
descriptor,
306+
ld->des_per_atom); // final descriptor
307+
DeePKS_domain::check_descriptor(ld->inlmax,
308+
ld->des_per_atom,
309+
ld->inl_l,
310+
ucell,
311+
PARAM.globalv.global_out_dir,
312+
descriptor);
303313

304314
if (PARAM.inp.deepks_out_labels)
305315
{

source/module_hamilt_lcao/module_deepks/cal_descriptor.cpp renamed to source/module_hamilt_lcao/module_deepks/deepks_descriptor.cpp

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
/// 1. cal_descriptor : obtains descriptors which are eigenvalues of pdm
22
/// by calling torch::linalg::eigh
33
/// 2. check_descriptor : prints descriptor for checking
4+
/// 3. cal_descriptor_equiv : calculates descriptor in equivalent version
45

56
#ifdef __DEEPKS
67

7-
#include "LCAO_deepks.h"
8+
#include "deepks_descriptor.h"
9+
810
#include "LCAO_deepks_io.h" // mohan add 2024-07-22
911
#include "module_base/blas_connector.h"
1012
#include "module_base/constants.h"
@@ -13,44 +15,46 @@
1315
#include "module_hamilt_lcao/module_hcontainer/atom_pair.h"
1416
#include "module_parameter/parameter.h"
1517

16-
void LCAO_Deepks::cal_descriptor_equiv(const int nat, std::vector<torch::Tensor>& descriptor)
18+
void DeePKS_domain::cal_descriptor_equiv(const int nat,
19+
const int des_per_atom,
20+
const std::vector<torch::Tensor>& pdm,
21+
std::vector<torch::Tensor>& descriptor)
1722
{
18-
ModuleBase::TITLE("LCAO_Deepks", "cal_descriptor_equiv");
19-
ModuleBase::timer::tick("LCAO_Deepks", "cal_descriptor_equiv");
23+
ModuleBase::TITLE("DeePKS_domain", "cal_descriptor_equiv");
24+
ModuleBase::timer::tick("DeePKS_domain", "cal_descriptor_equiv");
2025

26+
assert(des_per_atom > 0);
2127
for (int iat = 0; iat < nat; iat++)
2228
{
2329
auto tmp = torch::zeros(des_per_atom, torch::kFloat64);
24-
std::memcpy(tmp.data_ptr(), this->pdm[iat].data_ptr<double>(), sizeof(double) * tmp.numel());
30+
std::memcpy(tmp.data_ptr(), pdm[iat].data_ptr<double>(), sizeof(double) * tmp.numel());
2531
descriptor.push_back(tmp);
2632
}
2733

28-
ModuleBase::timer::tick("LCAO_Deepks", "cal_descriptor_equiv");
34+
ModuleBase::timer::tick("DeePKS_domain", "cal_descriptor_equiv");
2935
}
3036

3137
// calculates descriptors from projected density matrices
32-
void LCAO_Deepks::cal_descriptor(const int nat, std::vector<torch::Tensor>& descriptor)
38+
void DeePKS_domain::cal_descriptor(const int nat,
39+
const int inlmax,
40+
const int* inl_l,
41+
const std::vector<torch::Tensor>& pdm,
42+
std::vector<torch::Tensor>& descriptor,
43+
const int des_per_atom = -1)
3344
{
34-
ModuleBase::TITLE("LCAO_Deepks", "cal_descriptor");
35-
ModuleBase::timer::tick("LCAO_Deepks", "cal_descriptor");
36-
37-
// init descriptor
38-
// if descriptor is not empty, clear it !!
39-
if (!descriptor.empty())
40-
{
41-
descriptor.erase(descriptor.begin(), descriptor.end());
42-
}
45+
ModuleBase::TITLE("DeePKS_domain", "cal_descriptor");
46+
ModuleBase::timer::tick("DeePKS_domain", "cal_descriptor");
4347

4448
if (PARAM.inp.deepks_equiv)
4549
{
46-
this->cal_descriptor_equiv(nat, descriptor);
50+
DeePKS_domain::cal_descriptor_equiv(nat, des_per_atom, pdm, descriptor);
4751
return;
4852
}
4953

50-
for (int inl = 0; inl < this->inlmax; ++inl)
54+
for (int inl = 0; inl < inlmax; ++inl)
5155
{
5256
const int nm = 2 * inl_l[inl] + 1;
53-
this->pdm[inl].requires_grad_(true);
57+
pdm[inl].requires_grad_(true);
5458
descriptor.push_back(torch::ones({nm}, torch::requires_grad(true)));
5559
}
5660

@@ -64,15 +68,18 @@ void LCAO_Deepks::cal_descriptor(const int nat, std::vector<torch::Tensor>& desc
6468
d_v = torch::linalg::eigh(pdm[inl], /*uplo*/ "U");
6569
descriptor[inl] = std::get<0>(d_v);
6670
}
67-
ModuleBase::timer::tick("LCAO_Deepks", "cal_descriptor");
71+
ModuleBase::timer::tick("DeePKS_domain", "cal_descriptor");
6872
return;
6973
}
7074

71-
void LCAO_Deepks::check_descriptor(const UnitCell& ucell,
72-
const std::string& out_dir,
73-
const std::vector<torch::Tensor>& descriptor)
75+
void DeePKS_domain::check_descriptor(const int inlmax,
76+
const int des_per_atom,
77+
const int* inl_l,
78+
const UnitCell& ucell,
79+
const std::string& out_dir,
80+
const std::vector<torch::Tensor>& descriptor)
7481
{
75-
ModuleBase::TITLE("LCAO_Deepks", "check_descriptor");
82+
ModuleBase::TITLE("DeePKS_domain", "check_descriptor");
7683

7784
if (GlobalV::MY_RANK != 0)
7885
{
@@ -91,7 +98,7 @@ void LCAO_Deepks::check_descriptor(const UnitCell& ucell,
9198
for (int ia = 0; ia < ucell.atoms[it].na; ia++)
9299
{
93100
int iat = ucell.itia2iat(it, ia);
94-
ofs << ucell.atoms[it].label << " atom_index " << ia + 1 << " n_descriptor " << this->des_per_atom
101+
ofs << ucell.atoms[it].label << " atom_index " << ia + 1 << " n_descriptor " << des_per_atom
95102
<< std::endl;
96103
int id = 0;
97104
for (int inl = 0; inl < inlmax / ucell.nat; inl++)
@@ -118,10 +125,9 @@ void LCAO_Deepks::check_descriptor(const UnitCell& ucell,
118125
for (int iat = 0; iat < ucell.nat; iat++)
119126
{
120127
const int it = ucell.iat2it[iat];
121-
ofs << ucell.atoms[it].label << " atom_index " << iat + 1 << " n_descriptor " << this->des_per_atom
122-
<< std::endl;
128+
ofs << ucell.atoms[it].label << " atom_index " << iat + 1 << " n_descriptor " << des_per_atom << std::endl;
123129
auto accessor = descriptor[iat].accessor<double, 1>();
124-
for (int i = 0; i < this->des_per_atom; i++)
130+
for (int i = 0; i < des_per_atom; i++)
125131
{
126132
ofs << accessor[i] << " ";
127133
if (i % 8 == 7)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#ifndef DEEPKS_DESCRIPTOR_H
2+
#define DEEPKS_DESCRIPTOR_H
3+
4+
#ifdef __DEEPKS
5+
6+
#include "module_base/intarray.h"
7+
#include "module_base/timer.h"
8+
#include "module_cell/unitcell.h"
9+
10+
#include <torch/script.h>
11+
#include <torch/torch.h>
12+
13+
namespace DeePKS_domain
14+
{
15+
//------------------------
16+
// deepks_descriptor.cpp
17+
//------------------------
18+
19+
// This file contains interfaces with libtorch,
20+
// including loading of model and calculating gradients
21+
// as well as subroutines that prints the results for checking
22+
23+
// The file contains 8 subroutines:
24+
// 1. cal_descriptor : obtains descriptors which are eigenvalues of pdm
25+
// by calling torch::linalg::eigh
26+
// 2. check_descriptor : prints descriptor for checking
27+
// 3. cal_descriptor_equiv : calculates descriptor in equivalent version
28+
29+
/// Calculates descriptors
30+
/// which are eigenvalues of pdm in blocks of I_n_l
31+
void cal_descriptor(const int nat,
32+
const int inlmax,
33+
const int* inl_l,
34+
const std::vector<torch::Tensor>& pdm,
35+
std::vector<torch::Tensor>& descriptor,
36+
const int des_per_atom);
37+
/// print descriptors based on LCAO basis
38+
void check_descriptor(const int inlmax,
39+
const int des_per_atom,
40+
const int* inl_l,
41+
const UnitCell& ucell,
42+
const std::string& out_dir,
43+
const std::vector<torch::Tensor>& descriptor);
44+
45+
void cal_descriptor_equiv(const int nat,
46+
const int des_per_atom,
47+
const std::vector<torch::Tensor>& pdm,
48+
std::vector<torch::Tensor>& descriptor);
49+
} // namespace DeePKS_domain
50+
#endif
51+
#endif

source/module_hamilt_lcao/module_deepks/test/LCAO_deepks_test.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,18 @@ void test_deepks::check_gdmx(torch::Tensor& gdmx)
198198

199199
void test_deepks::check_descriptor(std::vector<torch::Tensor>& descriptor)
200200
{
201-
GlobalC::ld.cal_descriptor(ucell.nat, descriptor);
202-
GlobalC::ld.check_descriptor(ucell, "./", descriptor);
201+
DeePKS_domain::cal_descriptor(ucell.nat,
202+
GlobalC::ld.inlmax,
203+
GlobalC::ld.inl_l,
204+
GlobalC::ld.pdm,
205+
descriptor,
206+
GlobalC::ld.des_per_atom);
207+
DeePKS_domain::check_descriptor(GlobalC::ld.inlmax,
208+
GlobalC::ld.des_per_atom,
209+
GlobalC::ld.inl_l,
210+
ucell,
211+
"./",
212+
descriptor);
203213
this->compare_with_ref("deepks_desc.dat", "descriptor_ref.dat");
204214
}
205215

0 commit comments

Comments
 (0)