Skip to content

Commit 86e9314

Browse files
authored
Remove the use of std::system in DeePKS_equiv. (#7122)
* Remove the use of std::system in DeePKS_equiv. * Fix a dimension match problem.
1 parent 27bff48 commit 86e9314

7 files changed

Lines changed: 119 additions & 63 deletions

File tree

source/source_lcao/module_deepks/LCAO_deepks.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ void LCAO_Deepks<T>::init(const LCAO_Orbitals& orb,
7676

7777
this->deepks_param.lmaxd = lm;
7878
this->deepks_param.nmaxd = nm;
79+
this->deepks_param.nchi_d_l.assign(lm + 1, 0);
80+
for (int l = 0; l <= lm; ++l)
81+
{
82+
this->deepks_param.nchi_d_l[l] = orb.Alpha[0].getNchi(l);
83+
}
7984

8085
ofs << " lmax of descriptor = " << deepks_param.lmaxd << std::endl;
8186
ofs << " nmax of descriptor = " << deepks_param.nmaxd << std::endl;

source/source_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
135135
// new gedm is also useful in cal_f_delta, so it should be ld->gedm
136136
if (PARAM.inp.deepks_equiv)
137137
{
138-
DeePKS_domain::cal_edelta_gedm_equiv(nat, deepks_param, descriptor, ld->gedm, E_delta, rank);
138+
DeePKS_domain::cal_edelta_gedm_equiv(nat, deepks_param, descriptor, ld->model_deepks, ld->gedm, E_delta, rank);
139139
}
140140
else
141141
{

source/source_lcao/module_deepks/deepks_basic.cpp

Lines changed: 109 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
#include "source_base/timer.h"
1010
#include "source_io/module_parameter/parameter.h"
1111

12-
#include <cstdlib> // use system command
12+
#include <cmath>
13+
14+
#ifdef __MPI
15+
#include <mpi.h>
16+
#endif
1317

1418
// d(Descriptor) / d(projected density matrix)
1519
// Dimension is different for each inl, so there's a vector of tensors
@@ -85,82 +89,125 @@ void DeePKS_domain::load_model(const std::string& model_file, torch::jit::script
8589
return;
8690
}
8791

88-
inline void generate_py_files(const DeePKS_Param& deepks_param, const std::string& out_dir)
89-
{
90-
std::ofstream ofs("cal_edelta_gedm.py");
91-
ofs << "import torch" << std::endl;
92-
ofs << "import numpy as np" << std::endl << std::endl;
93-
ofs << "import sys" << std::endl;
94-
95-
ofs << "from deepks.scf.enn.scf import BasisInfo" << std::endl;
96-
ofs << "from deepks.iterate.template_abacus import t_make_pdm" << std::endl;
97-
ofs << "from deepks.utils import load_yaml" << std::endl << std::endl;
98-
99-
ofs << "basis = load_yaml('basis.yaml')['proj_basis']" << std::endl;
100-
ofs << "model = torch.jit.load(sys.argv[1])" << std::endl;
101-
ofs << "dm_eig = np.expand_dims(np.load('" << out_dir << "dm_eig.npy'),0)" << std::endl;
102-
ofs << "dm_eig = torch.tensor(dm_eig, "
103-
"dtype=torch.float64,requires_grad=True)"
104-
<< std::endl
105-
<< std::endl;
106-
107-
ofs << "dm_flat,basis_info = t_make_pdm(dm_eig,basis)" << std::endl;
108-
ofs << "ec = model(dm_flat.double())" << std::endl;
109-
ofs << "gedm = "
110-
"torch.autograd.grad(ec,dm_eig,grad_outputs=torch.ones_like(ec))[0]"
111-
<< std::endl
112-
<< std::endl;
113-
114-
ofs << "np.save('ec.npy',ec.double().detach().numpy())" << std::endl;
115-
ofs << "np.save('gedm.npy',gedm.double().numpy())" << std::endl;
116-
ofs.close();
117-
118-
ofs.open("basis.yaml");
119-
ofs << "proj_basis:" << std::endl;
120-
for (int l = 0; l < deepks_param.lmaxd + 1; l++)
121-
{
122-
ofs << " - - " << l << std::endl;
123-
ofs << " - [";
124-
for (int i = 0; i < deepks_param.nmaxd + 1; i++)
125-
{
126-
ofs << "0";
127-
if (i != deepks_param.nmaxd)
128-
{
129-
ofs << ", ";
130-
}
131-
}
132-
ofs << "]" << std::endl;
133-
}
134-
}
135-
13692
void DeePKS_domain::cal_edelta_gedm_equiv(const int nat,
13793
const DeePKS_Param& deepks_param,
13894
const std::vector<torch::Tensor>& descriptor,
95+
torch::jit::script::Module& model_deepks,
13996
double** gedm,
14097
double& E_delta,
14198
const int rank)
14299
{
143100
ModuleBase::TITLE("DeePKS_domain", "cal_edelta_gedm_equiv");
144101
ModuleBase::timer::start("DeePKS_domain", "cal_edelta_gedm_equiv");
145102

146-
const std::string file_d = PARAM.globalv.global_out_dir + "deepks_dm_eig.npy";
147-
LCAO_deepks_io::save_npy_d(nat, PARAM.inp.deepks_equiv, deepks_param, descriptor, file_d,
148-
rank); // libnpy needed
149-
150103
if (rank == 0)
151104
{
152-
generate_py_files(deepks_param, PARAM.globalv.global_out_dir);
153-
std::string cmd = "python cal_edelta_gedm.py " + PARAM.inp.deepks_model;
154-
int stat = std::system(cmd.c_str());
155-
assert(stat == 0);
156-
}
105+
const int basis_size
106+
= static_cast<int>(std::llround(std::sqrt(static_cast<double>(deepks_param.des_per_atom))));
107+
if (basis_size * basis_size != deepks_param.des_per_atom)
108+
{
109+
ModuleBase::WARNING_QUIT("DeePKS_domain::cal_edelta_gedm_equiv",
110+
"Invalid des_per_atom for equivariant DeePKS: it must be a perfect square.");
111+
}
112+
113+
torch::Tensor dm_eig = torch::cat(descriptor, 0).reshape({1, nat, deepks_param.des_per_atom});
114+
dm_eig = dm_eig.to(torch::kFloat64).requires_grad_(true);
115+
torch::Tensor dm = dm_eig.reshape({1, nat, basis_size, basis_size});
116+
117+
if (static_cast<int>(deepks_param.nchi_d_l.size()) != deepks_param.lmaxd + 1)
118+
{
119+
ModuleBase::WARNING_QUIT(
120+
"DeePKS_domain::cal_edelta_gedm_equiv",
121+
"Invalid nchi_d_l in DeePKS parameters: expected size lmaxd + 1 for equivariant shell construction.");
122+
}
123+
124+
std::vector<torch::Tensor> ovlp_shells;
125+
int total_shells = 0;
126+
for (int l = 0; l <= deepks_param.lmaxd; ++l)
127+
{
128+
total_shells += deepks_param.nchi_d_l[l];
129+
}
130+
ovlp_shells.reserve(total_shells);
131+
int offset = 0;
132+
for (int l = 0; l <= deepks_param.lmaxd; ++l)
133+
{
134+
const int nm = 2 * l + 1;
135+
for (int n = 0; n < deepks_param.nchi_d_l[l]; ++n)
136+
{
137+
torch::Tensor po = torch::zeros({basis_size, 1, nm}, torch::TensorOptions().dtype(torch::kFloat64));
138+
auto accessor = po.accessor<double, 3>();
139+
for (int m = 0; m < nm; ++m)
140+
{
141+
accessor[offset + m][0][m] = 1.0;
142+
}
143+
ovlp_shells.push_back(po);
144+
offset += nm;
145+
}
146+
}
147+
if (offset != basis_size)
148+
{
149+
ModuleBase::WARNING_QUIT("DeePKS_domain::cal_edelta_gedm_equiv",
150+
"Invalid shell layout: accumulated shell offset does not match basis size.");
151+
}
157152

158-
MPI_Barrier(MPI_COMM_WORLD);
153+
std::vector<torch::Tensor> dm_flat;
154+
dm_flat.reserve(ovlp_shells.size());
155+
for (const auto& po : ovlp_shells)
156+
{
157+
// Equivalent to python:
158+
// torch.einsum('rap,...rs,saq->...apq', po, dm, po)
159+
torch::Tensor pdm_shell = torch::einsum("rap,...rs,saq->...apq", {po, dm, po});
160+
dm_flat.push_back(pdm_shell.squeeze(-3));
161+
}
159162

160-
LCAO_deepks_io::load_npy_gedm(nat, deepks_param.des_per_atom, gedm, E_delta, rank);
163+
c10::List<torch::Tensor> model_input;
164+
for (const auto& pdm_shell : dm_flat)
165+
{
166+
model_input.push_back(pdm_shell);
167+
}
168+
169+
std::vector<torch::jit::IValue> inputs;
170+
inputs.emplace_back(model_input);
171+
172+
torch::Tensor ec;
173+
try
174+
{
175+
ec = model_deepks.forward(inputs).toTensor(); // Hartree
176+
}
177+
catch (const c10::Error& e)
178+
{
179+
ModuleBase::WARNING_QUIT("DeePKS_domain::cal_edelta_gedm_equiv",
180+
"Failed to evaluate equivariant DeePKS model in C++.");
181+
throw;
182+
}
161183

162-
std::string cmd = "rm -f cal_edelta_gedm.py basis.yaml ec.npy gedm.npy";
163-
std::system(cmd.c_str());
184+
E_delta = ec.item<double>() * 2.0; // Hartree to Ry
185+
186+
std::vector<torch::Tensor> grad_outputs{torch::ones_like(ec)};
187+
std::vector<torch::Tensor> grad_inputs{dm_eig};
188+
torch::Tensor gedm_tensor = torch::autograd::grad({ec}, grad_inputs, grad_outputs,
189+
/*retain_graph=*/false,
190+
/*create_graph=*/false,
191+
/*allow_unused=*/false)[0];
192+
193+
torch::Tensor gedm_nat = gedm_tensor.reshape({nat, deepks_param.des_per_atom});
194+
auto accessor = gedm_nat.accessor<double, 2>();
195+
for (int iat = 0; iat < nat; ++iat)
196+
{
197+
for (int ides = 0; ides < deepks_param.des_per_atom; ++ides)
198+
{
199+
gedm[iat][ides] = accessor[iat][ides] * 2.0; // Hartree to Ry
200+
}
201+
}
202+
}
203+
204+
#ifdef __MPI
205+
for (int iat = 0; iat < nat; ++iat)
206+
{
207+
MPI_Bcast(gedm[iat], deepks_param.des_per_atom, MPI_DOUBLE, 0, MPI_COMM_WORLD);
208+
}
209+
MPI_Bcast(&E_delta, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD);
210+
#endif
164211

165212
ModuleBase::timer::end("DeePKS_domain", "cal_edelta_gedm_equiv");
166213
return;

source/source_lcao/module_deepks/deepks_basic.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ void check_gedm(const DeePKS_Param& deepks_param, double** gedm);
4949
void cal_edelta_gedm_equiv(const int nat,
5050
const DeePKS_Param& deepks_param,
5151
const std::vector<torch::Tensor>& descriptor,
52+
torch::jit::script::Module& model_deepks,
5253
double** gedm,
5354
double& E_delta,
5455
const int rank);

source/source_lcao/module_deepks/deepks_param.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ struct DeePKS_Param
1414
int inlmax = 0;
1515
int n_descriptor = 0;
1616
int des_per_atom = 0;
17+
std::vector<int> nchi_d_l;
1718
std::vector<int> inl2l;
1819
ModuleBase::IntArray* inl_index = nullptr;
1920
};

source/source_lcao/module_deepks/test/LCAO_deepks_test.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ void test_deepks<T>::check_edelta(std::vector<torch::Tensor>& descriptor)
320320
DeePKS_domain::cal_edelta_gedm_equiv(ucell.nat,
321321
this->ld.deepks_param,
322322
descriptor,
323+
this->ld.model_deepks,
323324
this->ld.gedm,
324325
this->ld.E_delta,
325326
0); // 0 for rank

source/source_lcao/module_operator_lcao/deepks_lcao.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::contributeHR()
173173
DeePKS_domain::cal_edelta_gedm_equiv(this->ucell->nat,
174174
this->ld->deepks_param,
175175
descriptor,
176+
this->ld->model_deepks,
176177
this->ld->gedm,
177178
this->ld->E_delta,
178179
GlobalV::MY_RANK);

0 commit comments

Comments
 (0)