Skip to content

Commit 77c253d

Browse files
committed
Simplify spre.
1 parent 327707b commit 77c253d

File tree

4 files changed

+24
-40
lines changed

4 files changed

+24
-40
lines changed

source/source_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,10 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
241241
&& !PARAM.inp.deepks_equiv) // training with stress label not supported by equivariant version now
242242
{
243243
torch::Tensor gdmepsl;
244-
DeePKS_domain::cal_gdmepsl<
245-
TK>(lmaxd, inlmax, nks, kvec_d, phialpha, inl_index, dmr, ucell, orb, *ParaV, GridD, gdmepsl);
244+
DeePKS_domain::cal_gdmepsl<TK>(nks, deepks_param, kvec_d, phialpha, dmr, ucell, orb, *ParaV, GridD, gdmepsl);
246245

247246
torch::Tensor gvepsl;
248-
DeePKS_domain::cal_gvepsl(ucell.nat, inlmax, des_per_atom, inl2l, gevdm, gdmepsl, gvepsl, rank);
247+
DeePKS_domain::cal_gvepsl(ucell.nat, deepks_param, gevdm, gdmepsl, gvepsl, rank);
249248
const std::string file_gvepsl = get_filename("gvepsl", PARAM.inp.deepks_out_labels, iter);
250249
LCAO_deepks_io::save_tensor2npy<double>(file_gvepsl, gvepsl, rank);
251250

source/source_lcao/module_deepks/deepks_spre.cpp

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,10 @@
1515
/// gdmepsl = d/d\epsilon_{ab} *
1616
/// sum_{mu,nu} rho_{mu,nu} <chi_mu|alpha_m><alpha_m'|chi_nu>
1717
template <typename TK>
18-
void DeePKS_domain::cal_gdmepsl(const int lmaxd,
19-
const int inlmax,
20-
const int nks,
18+
void DeePKS_domain::cal_gdmepsl(const int nks,
19+
const DeePKS_Param& deepks_param,
2120
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
2221
std::vector<hamilt::HContainer<double>*> phialpha,
23-
const ModuleBase::IntArray* inl_index,
2422
const hamilt::HContainer<double>* dmr,
2523
const UnitCell& ucell,
2624
const LCAO_Orbitals& orb,
@@ -33,10 +31,10 @@ void DeePKS_domain::cal_gdmepsl(const int lmaxd,
3331
// get DS_alpha_mu and S_nu_beta
3432

3533
int nrow = pv.nrow;
36-
const int nm = 2 * lmaxd + 1;
34+
const int nm = 2 * deepks_param.lmaxd + 1;
3735
// gdmepsl: dD/d\epsilon_{\alpha\beta}
3836
// size: [6][tot_Inl][2l+1][2l+1]
39-
gdmepsl = torch::zeros({6, inlmax, nm, nm}, torch::dtype(torch::kFloat64));
37+
gdmepsl = torch::zeros({6, deepks_param.inlmax, nm, nm}, torch::dtype(torch::kFloat64));
4038
auto accessor = gdmepsl.accessor<double, 4>();
4139

4240
DeePKS_domain::iterate_ad2(
@@ -111,7 +109,7 @@ void DeePKS_domain::cal_gdmepsl(const int lmaxd,
111109
{
112110
for (int N0 = 0; N0 < orb.Alpha[0].getNchi(L0); ++N0)
113111
{
114-
const int inl = inl_index[ucell.iat2it[iat]](ucell.iat2ia[iat], L0, N0);
112+
const int inl = deepks_param.inl_index[ucell.iat2it[iat]](ucell.iat2ia[iat], L0, N0);
115113
const int nm = 2 * L0 + 1;
116114
for (int m1 = 0; m1 < nm; ++m1)
117115
{
@@ -147,7 +145,7 @@ void DeePKS_domain::cal_gdmepsl(const int lmaxd,
147145
);
148146

149147
#ifdef __MPI
150-
Parallel_Reduce::reduce_all(gdmepsl.data_ptr<double>(), 6 * inlmax * nm * nm);
148+
Parallel_Reduce::reduce_all(gdmepsl.data_ptr<double>(), 6 * deepks_param.inlmax * nm * nm);
151149
#endif
152150
ModuleBase::timer::tick("DeePKS_domain", "cal_gdmepsl");
153151
return;
@@ -156,9 +154,7 @@ void DeePKS_domain::cal_gdmepsl(const int lmaxd,
156154
// calculates stress of descriptors from gradient of projected density matrices
157155
// gv_epsl:d(d)/d\epsilon_{\alpha\beta}, [natom][6][des_per_atom]
158156
void DeePKS_domain::cal_gvepsl(const int nat,
159-
const int inlmax,
160-
const int des_per_atom,
161-
const std::vector<int>& inl2l,
157+
const DeePKS_Param& deepks_param,
162158
const std::vector<torch::Tensor>& gevdm,
163159
const torch::Tensor& gdmepsl,
164160
torch::Tensor& gvepsl,
@@ -172,11 +168,11 @@ void DeePKS_domain::cal_gvepsl(const int nat,
172168
if (rank == 0)
173169
{
174170
// make gdmepsl as tensor
175-
int nlmax = inlmax / nat;
171+
int nlmax = deepks_param.inlmax / nat;
176172
for (int nl = 0; nl < nlmax; ++nl)
177173
{
178-
int nm = 2 * inl2l[nl] + 1;
179-
torch::Tensor gdmepsl_sliced = gdmepsl.slice(1, nl, inlmax, nlmax).slice(2, 0, nm, 1).slice(3, 0, nm, 1);
174+
int nm = 2 * deepks_param.inl2l[nl] + 1;
175+
torch::Tensor gdmepsl_sliced = gdmepsl.slice(1, nl, deepks_param.inlmax, nlmax).slice(2, 0, nm, 1).slice(3, 0, nm, 1);
180176
gdmepsl_vector.push_back(gdmepsl_sliced);
181177
}
182178
assert(gdmepsl_vector.size() == nlmax);
@@ -197,32 +193,28 @@ void DeePKS_domain::cal_gvepsl(const int nat,
197193
gvepsl = torch::cat(gvepsl_vector, -1);
198194
assert(gvepsl.size(0) == 6);
199195
assert(gvepsl.size(1) == nat);
200-
assert(gvepsl.size(2) == des_per_atom);
196+
assert(gvepsl.size(2) == deepks_param.des_per_atom);
201197
}
202198

203199
ModuleBase::timer::tick("DeePKS_domain", "cal_gvepsl");
204200
return;
205201
}
206202

207-
template void DeePKS_domain::cal_gdmepsl<double>(const int lmaxd,
208-
const int inlmax,
209-
const int nks,
203+
template void DeePKS_domain::cal_gdmepsl<double>(const int nks,
204+
const DeePKS_Param& deepks_param,
210205
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
211206
std::vector<hamilt::HContainer<double>*> phialpha,
212-
const ModuleBase::IntArray* inl_index,
213207
const hamilt::HContainer<double>* dmr,
214208
const UnitCell& ucell,
215209
const LCAO_Orbitals& orb,
216210
const Parallel_Orbitals& pv,
217211
const Grid_Driver& GridD,
218212
torch::Tensor& gdmepsl);
219213

220-
template void DeePKS_domain::cal_gdmepsl<std::complex<double>>(const int lmaxd,
221-
const int inlmax,
222-
const int nks,
214+
template void DeePKS_domain::cal_gdmepsl<std::complex<double>>(const int nks,
215+
const DeePKS_Param& deepks_param,
223216
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
224217
std::vector<hamilt::HContainer<double>*> phialpha,
225-
const ModuleBase::IntArray* inl_index,
226218
const hamilt::HContainer<double>* dmr,
227219
const UnitCell& ucell,
228220
const LCAO_Orbitals& orb,

source/source_lcao/module_deepks/deepks_spre.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#ifdef __MLALGO
55

6+
#include "deepks_param.h"
67
#include "source_base/complexmatrix.h"
78
#include "source_base/intarray.h"
89
#include "source_base/matrix.h"
@@ -31,13 +32,11 @@ namespace DeePKS_domain
3132
// calculate the gradient of pdm with regard to atomic virial stress tensor
3233
// d/d\epsilon D_{Inl,mm'}
3334
template <typename TK>
34-
void cal_gdmepsl( // const ModuleBase::matrix& dm,
35-
const int lmaxd,
36-
const int inlmax,
35+
void cal_gdmepsl(
3736
const int nks,
37+
const DeePKS_Param& deepks_param,
3838
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
3939
std::vector<hamilt::HContainer<double>*> phialpha,
40-
const ModuleBase::IntArray* inl_index,
4140
const hamilt::HContainer<double>* dmr,
4241
const UnitCell& ucell,
4342
const LCAO_Orbitals& orb,
@@ -46,9 +45,7 @@ void cal_gdmepsl( // const ModuleBase::matrix& dm,
4645
torch::Tensor& gdmepsl);
4746

4847
void cal_gvepsl(const int nat,
49-
const int inlmax,
50-
const int des_per_atom,
51-
const std::vector<int>& inl2l,
48+
const DeePKS_Param& deepks_param,
5249
const std::vector<torch::Tensor>& gevdm,
5350
const torch::Tensor& gdmepsl,
5451
torch::Tensor& gvepsl,

source/source_lcao/module_deepks/test/LCAO_deepks_test.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,10 @@ void test_deepks<T>::check_gvx(torch::Tensor& gdmx)
201201
template <typename T>
202202
void test_deepks<T>::check_gdmepsl(torch::Tensor& gdmepsl)
203203
{
204-
DeePKS_domain::cal_gdmepsl<T>(this->ld.lmaxd,
205-
this->ld.inlmax,
206-
kv.nkstot,
204+
DeePKS_domain::cal_gdmepsl<T>(kv.nkstot,
205+
this->ld.deepks_param,
207206
kv.kvec_d,
208207
this->ld.phialpha,
209-
this->ld.inl_index,
210208
this->ld.dm_r,
211209
ucell,
212210
ORB,
@@ -224,9 +222,7 @@ void test_deepks<T>::check_gvepsl(torch::Tensor& gdmepsl)
224222
DeePKS_domain::cal_gevdm(ucell.nat, this->ld.deepks_param, this->ld.pdm, gevdm);
225223
torch::Tensor gvepsl;
226224
DeePKS_domain::cal_gvepsl(ucell.nat,
227-
this->ld.inlmax,
228-
this->ld.des_per_atom,
229-
this->ld.inl2l,
225+
this->ld.deepks_param,
230226
gevdm,
231227
gdmepsl,
232228
gvepsl,

0 commit comments

Comments
 (0)