|
9 | 9 | #include "source_base/timer.h" |
10 | 10 | #include "source_io/module_parameter/parameter.h" |
11 | 11 |
|
12 | | -#include <cstdlib> // use system command |
| 12 | +#include <cmath> |
| 13 | + |
| 14 | +#ifdef __MPI |
| 15 | +#include <mpi.h> |
| 16 | +#endif |
13 | 17 |
|
14 | 18 | // d(Descriptor) / d(projected density matrix) |
15 | 19 | // Dimension is different for each inl, so there's a vector of tensors |
@@ -85,82 +89,125 @@ void DeePKS_domain::load_model(const std::string& model_file, torch::jit::script |
85 | 89 | return; |
86 | 90 | } |
87 | 91 |
|
88 | | -inline void generate_py_files(const DeePKS_Param& deepks_param, const std::string& out_dir) |
89 | | -{ |
90 | | - std::ofstream ofs("cal_edelta_gedm.py"); |
91 | | - ofs << "import torch" << std::endl; |
92 | | - ofs << "import numpy as np" << std::endl << std::endl; |
93 | | - ofs << "import sys" << std::endl; |
94 | | - |
95 | | - ofs << "from deepks.scf.enn.scf import BasisInfo" << std::endl; |
96 | | - ofs << "from deepks.iterate.template_abacus import t_make_pdm" << std::endl; |
97 | | - ofs << "from deepks.utils import load_yaml" << std::endl << std::endl; |
98 | | - |
99 | | - ofs << "basis = load_yaml('basis.yaml')['proj_basis']" << std::endl; |
100 | | - ofs << "model = torch.jit.load(sys.argv[1])" << std::endl; |
101 | | - ofs << "dm_eig = np.expand_dims(np.load('" << out_dir << "dm_eig.npy'),0)" << std::endl; |
102 | | - ofs << "dm_eig = torch.tensor(dm_eig, " |
103 | | - "dtype=torch.float64,requires_grad=True)" |
104 | | - << std::endl |
105 | | - << std::endl; |
106 | | - |
107 | | - ofs << "dm_flat,basis_info = t_make_pdm(dm_eig,basis)" << std::endl; |
108 | | - ofs << "ec = model(dm_flat.double())" << std::endl; |
109 | | - ofs << "gedm = " |
110 | | - "torch.autograd.grad(ec,dm_eig,grad_outputs=torch.ones_like(ec))[0]" |
111 | | - << std::endl |
112 | | - << std::endl; |
113 | | - |
114 | | - ofs << "np.save('ec.npy',ec.double().detach().numpy())" << std::endl; |
115 | | - ofs << "np.save('gedm.npy',gedm.double().numpy())" << std::endl; |
116 | | - ofs.close(); |
117 | | - |
118 | | - ofs.open("basis.yaml"); |
119 | | - ofs << "proj_basis:" << std::endl; |
120 | | - for (int l = 0; l < deepks_param.lmaxd + 1; l++) |
121 | | - { |
122 | | - ofs << " - - " << l << std::endl; |
123 | | - ofs << " - ["; |
124 | | - for (int i = 0; i < deepks_param.nmaxd + 1; i++) |
125 | | - { |
126 | | - ofs << "0"; |
127 | | - if (i != deepks_param.nmaxd) |
128 | | - { |
129 | | - ofs << ", "; |
130 | | - } |
131 | | - } |
132 | | - ofs << "]" << std::endl; |
133 | | - } |
134 | | -} |
135 | | - |
136 | 92 | void DeePKS_domain::cal_edelta_gedm_equiv(const int nat, |
137 | 93 | const DeePKS_Param& deepks_param, |
138 | 94 | const std::vector<torch::Tensor>& descriptor, |
| 95 | + torch::jit::script::Module& model_deepks, |
139 | 96 | double** gedm, |
140 | 97 | double& E_delta, |
141 | 98 | const int rank) |
142 | 99 | { |
143 | 100 | ModuleBase::TITLE("DeePKS_domain", "cal_edelta_gedm_equiv"); |
144 | 101 | ModuleBase::timer::start("DeePKS_domain", "cal_edelta_gedm_equiv"); |
145 | 102 |
|
146 | | - const std::string file_d = PARAM.globalv.global_out_dir + "deepks_dm_eig.npy"; |
147 | | - LCAO_deepks_io::save_npy_d(nat, PARAM.inp.deepks_equiv, deepks_param, descriptor, file_d, |
148 | | - rank); // libnpy needed |
149 | | - |
150 | 103 | if (rank == 0) |
151 | 104 | { |
152 | | - generate_py_files(deepks_param, PARAM.globalv.global_out_dir); |
153 | | - std::string cmd = "python cal_edelta_gedm.py " + PARAM.inp.deepks_model; |
154 | | - int stat = std::system(cmd.c_str()); |
155 | | - assert(stat == 0); |
156 | | - } |
| 105 | + const int basis_size |
| 106 | + = static_cast<int>(std::llround(std::sqrt(static_cast<double>(deepks_param.des_per_atom)))); |
| 107 | + if (basis_size * basis_size != deepks_param.des_per_atom) |
| 108 | + { |
| 109 | + ModuleBase::WARNING_QUIT("DeePKS_domain::cal_edelta_gedm_equiv", |
| 110 | + "Invalid des_per_atom for equivariant DeePKS: it must be a perfect square."); |
| 111 | + } |
| 112 | + |
| 113 | + torch::Tensor dm_eig = torch::cat(descriptor, 0).reshape({1, nat, deepks_param.des_per_atom}); |
| 114 | + dm_eig = dm_eig.to(torch::kFloat64).requires_grad_(true); |
| 115 | + torch::Tensor dm = dm_eig.reshape({1, nat, basis_size, basis_size}); |
| 116 | + |
| 117 | + if (static_cast<int>(deepks_param.nchi_d_l.size()) != deepks_param.lmaxd + 1) |
| 118 | + { |
| 119 | + ModuleBase::WARNING_QUIT( |
| 120 | + "DeePKS_domain::cal_edelta_gedm_equiv", |
| 121 | + "Invalid nchi_d_l in DeePKS parameters: expected size lmaxd + 1 for equivariant shell construction."); |
| 122 | + } |
| 123 | + |
| 124 | + std::vector<torch::Tensor> ovlp_shells; |
| 125 | + int total_shells = 0; |
| 126 | + for (int l = 0; l <= deepks_param.lmaxd; ++l) |
| 127 | + { |
| 128 | + total_shells += deepks_param.nchi_d_l[l]; |
| 129 | + } |
| 130 | + ovlp_shells.reserve(total_shells); |
| 131 | + int offset = 0; |
| 132 | + for (int l = 0; l <= deepks_param.lmaxd; ++l) |
| 133 | + { |
| 134 | + const int nm = 2 * l + 1; |
| 135 | + for (int n = 0; n < deepks_param.nchi_d_l[l]; ++n) |
| 136 | + { |
| 137 | + torch::Tensor po = torch::zeros({basis_size, 1, nm}, torch::TensorOptions().dtype(torch::kFloat64)); |
| 138 | + auto accessor = po.accessor<double, 3>(); |
| 139 | + for (int m = 0; m < nm; ++m) |
| 140 | + { |
| 141 | + accessor[offset + m][0][m] = 1.0; |
| 142 | + } |
| 143 | + ovlp_shells.push_back(po); |
| 144 | + offset += nm; |
| 145 | + } |
| 146 | + } |
| 147 | + if (offset != basis_size) |
| 148 | + { |
| 149 | + ModuleBase::WARNING_QUIT("DeePKS_domain::cal_edelta_gedm_equiv", |
| 150 | + "Invalid shell layout: accumulated shell offset does not match basis size."); |
| 151 | + } |
157 | 152 |
|
158 | | - MPI_Barrier(MPI_COMM_WORLD); |
| 153 | + std::vector<torch::Tensor> dm_flat; |
| 154 | + dm_flat.reserve(ovlp_shells.size()); |
| 155 | + for (const auto& po : ovlp_shells) |
| 156 | + { |
| 157 | + // Equivalent to python: |
| 158 | + // torch.einsum('rap,...rs,saq->...apq', po, dm, po) |
| 159 | + torch::Tensor pdm_shell = torch::einsum("rap,...rs,saq->...apq", {po, dm, po}); |
| 160 | + dm_flat.push_back(pdm_shell.squeeze(-3)); |
| 161 | + } |
159 | 162 |
|
160 | | - LCAO_deepks_io::load_npy_gedm(nat, deepks_param.des_per_atom, gedm, E_delta, rank); |
| 163 | + c10::List<torch::Tensor> model_input; |
| 164 | + for (const auto& pdm_shell : dm_flat) |
| 165 | + { |
| 166 | + model_input.push_back(pdm_shell); |
| 167 | + } |
| 168 | + |
| 169 | + std::vector<torch::jit::IValue> inputs; |
| 170 | + inputs.emplace_back(model_input); |
| 171 | + |
| 172 | + torch::Tensor ec; |
| 173 | + try |
| 174 | + { |
| 175 | + ec = model_deepks.forward(inputs).toTensor(); // Hartree |
| 176 | + } |
| 177 | + catch (const c10::Error& e) |
| 178 | + { |
| 179 | + ModuleBase::WARNING_QUIT("DeePKS_domain::cal_edelta_gedm_equiv", |
| 180 | + "Failed to evaluate equivariant DeePKS model in C++."); |
| 181 | + throw; |
| 182 | + } |
161 | 183 |
|
162 | | - std::string cmd = "rm -f cal_edelta_gedm.py basis.yaml ec.npy gedm.npy"; |
163 | | - std::system(cmd.c_str()); |
| 184 | + E_delta = ec.item<double>() * 2.0; // Hartree to Ry |
| 185 | + |
| 186 | + std::vector<torch::Tensor> grad_outputs{torch::ones_like(ec)}; |
| 187 | + std::vector<torch::Tensor> grad_inputs{dm_eig}; |
| 188 | + torch::Tensor gedm_tensor = torch::autograd::grad({ec}, grad_inputs, grad_outputs, |
| 189 | + /*retain_graph=*/false, |
| 190 | + /*create_graph=*/false, |
| 191 | + /*allow_unused=*/false)[0]; |
| 192 | + |
| 193 | + torch::Tensor gedm_nat = gedm_tensor.reshape({nat, deepks_param.des_per_atom}); |
| 194 | + auto accessor = gedm_nat.accessor<double, 2>(); |
| 195 | + for (int iat = 0; iat < nat; ++iat) |
| 196 | + { |
| 197 | + for (int ides = 0; ides < deepks_param.des_per_atom; ++ides) |
| 198 | + { |
| 199 | + gedm[iat][ides] = accessor[iat][ides] * 2.0; // Hartree to Ry |
| 200 | + } |
| 201 | + } |
| 202 | + } |
| 203 | + |
| 204 | +#ifdef __MPI |
| 205 | + for (int iat = 0; iat < nat; ++iat) |
| 206 | + { |
| 207 | + MPI_Bcast(gedm[iat], deepks_param.des_per_atom, MPI_DOUBLE, 0, MPI_COMM_WORLD); |
| 208 | + } |
| 209 | + MPI_Bcast(&E_delta, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD); |
| 210 | +#endif |
164 | 211 |
|
165 | 212 | ModuleBase::timer::end("DeePKS_domain", "cal_edelta_gedm_equiv"); |
166 | 213 | return; |
|
0 commit comments