Skip to content

Commit ab2e27a

Browse files
authored
Refactor: Simplify some functions in DeePKS. (#6102)
* Simplify some functions in DeePKS.
* clang-format change.
1 parent b930866 commit ab2e27a

File tree

10 files changed

+92
-240
lines changed

10 files changed

+92
-240
lines changed

source/module_hamilt_lcao/module_deepks/LCAO_deepks.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class LCAO_Deepks
7373
int inlmax = 0; // tot. number {i,n,l} - atom, n, l
7474
int n_descriptor; // natoms * des_per_atom, size of descriptor(projector) basis set
7575
int des_per_atom; // \sum_L{Nchi(L)*(2L+1)}
76-
std::vector<int> inl2l; // inl2l[inl] = l of descriptor with inl_index
76+
std::vector<int> inl2l; // inl2l[inl] = inl2l[nl] = l (not related to iat) of descriptor with inl_index
7777
ModuleBase::IntArray* inl_index; // caoyu add 2021-05-07
7878

7979
bool init_pdm = false; // for DeePKS NSCF calculation, set init_pdm to skip the calculation of pdm in SCF iteration

source/module_hamilt_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
6464
{
6565
// this part is for integrated test of deepks
6666
// so it is printed even if deepks_out_labels is not used
67-
DeePKS_domain::update_dmr(kvec_d, dm->get_DMK_vector(), ucell, orb, *ParaV, GridD, dmr);
68-
6967
DeePKS_domain::cal_pdm<
7068
TK>(init_pdm, inlmax, lmaxd, inl2l, inl_index, kvec_d, dmr, phialpha, ucell, orb, GridD, *ParaV, pdm);
7169

source/module_hamilt_lcao/module_deepks/deepks_force.cpp

Lines changed: 13 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,13 @@ void DeePKS_domain::cal_f_delta(const hamilt::HContainer<double>* dmr,
151151
+= gedm[inl][m1 * nm + m2]
152152
* overlap_1->get_value(row_indexes[iw1], ib + m1)
153153
* grad_overlap_2[dim]->get_value(col_indexes[iw2], ib + m2);
154+
if (isstress)
155+
{
156+
nlm_t[dim] += gedm[inl][m1 * nm + m2]
157+
* overlap_2->get_value(col_indexes[iw2], ib + m1)
158+
* grad_overlap_1[dim]->get_value(row_indexes[iw1],
159+
ib + m2);
160+
}
154161
}
155162
}
156163
}
@@ -175,6 +182,12 @@ void DeePKS_domain::cal_f_delta(const hamilt::HContainer<double>* dmr,
175182
nlm[dim] += gedm[iat][iproj * nproj + jproj]
176183
* overlap_1->get_value(row_indexes[iw1], iproj)
177184
* grad_overlap_2[dim]->get_value(col_indexes[iw2], jproj);
185+
if (isstress)
186+
{
187+
nlm_t[dim] += gedm[iat][iproj * nproj + jproj]
188+
* overlap_2->get_value(col_indexes[iw2], iproj)
189+
* grad_overlap_1[dim]->get_value(row_indexes[iw1], jproj);
190+
}
178191
}
179192
}
180193
}
@@ -192,54 +205,6 @@ void DeePKS_domain::cal_f_delta(const hamilt::HContainer<double>* dmr,
192205

193206
if (isstress)
194207
{
195-
if (!PARAM.inp.deepks_equiv)
196-
{
197-
int ib = 0;
198-
for (int L0 = 0; L0 <= orb.Alpha[0].getLmax(); ++L0)
199-
{
200-
for (int N0 = 0; N0 < orb.Alpha[0].getNchi(L0); ++N0)
201-
{
202-
const int inl = inl_index[T0](I0, L0, N0);
203-
const int nm = 2 * L0 + 1;
204-
for (int m1 = 0; m1 < nm; ++m1)
205-
{
206-
for (int m2 = 0; m2 < nm; ++m2)
207-
{
208-
for (int dim = 0; dim < 3; ++dim)
209-
{
210-
nlm_t[dim] += gedm[inl][m1 * nm + m2]
211-
* overlap_2->get_value(col_indexes[iw2], ib + m1)
212-
* grad_overlap_1[dim]->get_value(row_indexes[iw1],
213-
ib + m2);
214-
}
215-
}
216-
}
217-
ib += nm;
218-
}
219-
}
220-
assert(ib == overlap_2->get_col_size());
221-
}
222-
else
223-
{
224-
int nproj = 0;
225-
for (int il = 0; il < lmaxd + 1; il++)
226-
{
227-
nproj += (2 * il + 1) * orb.Alpha[0].getNchi(il);
228-
}
229-
for (int iproj = 0; iproj < nproj; iproj++)
230-
{
231-
for (int jproj = 0; jproj < nproj; jproj++)
232-
{
233-
for (int dim = 0; dim < 3; dim++)
234-
{
235-
nlm_t[dim] += gedm[iat][iproj * nproj + jproj]
236-
* overlap_2->get_value(col_indexes[iw2], iproj)
237-
* grad_overlap_1[dim]->get_value(row_indexes[iw1], jproj);
238-
}
239-
}
240-
}
241-
}
242-
243208
for (int ipol = 0; ipol < 3; ipol++)
244209
{
245210
for (int jpol = ipol; jpol < 3; jpol++)

source/module_hamilt_lcao/module_deepks/deepks_fpre.cpp

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -211,36 +211,10 @@ void DeePKS_domain::cal_gvx(const int nat,
211211
int nlmax = inlmax / nat;
212212
for (int nl = 0; nl < nlmax; ++nl)
213213
{
214-
std::vector<torch::Tensor> bmmv;
215-
for (int ibt = 0; ibt < nat; ++ibt)
216-
{
217-
std::vector<torch::Tensor> xmmv;
218-
for (int i = 0; i < 3; ++i)
219-
{
220-
std::vector<torch::Tensor> ammv;
221-
for (int iat = 0; iat < nat; ++iat)
222-
{
223-
int inl = iat * nlmax + nl;
224-
int nm = 2 * inl2l[inl] + 1;
225-
std::vector<double> mmv;
226-
for (int m1 = 0; m1 < nm; ++m1)
227-
{
228-
for (int m2 = 0; m2 < nm; ++m2)
229-
{
230-
mmv.push_back(accessor[i][ibt][inl][m1][m2]);
231-
}
232-
} // nm^2
233-
torch::Tensor mm = torch::tensor(mmv, torch::TensorOptions().dtype(torch::kFloat64))
234-
.reshape({nm, nm}); // nm*nm
235-
ammv.push_back(mm);
236-
}
237-
torch::Tensor amm = torch::stack(ammv, 0); // nat*nm*nm
238-
xmmv.push_back(amm);
239-
}
240-
torch::Tensor bmm = torch::stack(xmmv, 0); // 3*nat*nm*nm
241-
bmmv.push_back(bmm);
242-
}
243-
gdmr.push_back(torch::stack(bmmv, 0)); // nbt*3*nat*nm*nm
214+
int nm = 2 * inl2l[nl] + 1;
215+
torch::Tensor gdmx_sliced
216+
= gdmx.slice(2, nl, inlmax, nlmax).slice(3, 0, nm, 1).slice(4, 0, nm, 1).permute({1, 0, 2, 3, 4});
217+
gdmr.push_back(gdmx_sliced);
244218
}
245219

246220
assert(gdmr.size() == nlmax);

source/module_hamilt_lcao/module_deepks/deepks_iterate.cpp

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,36 +16,34 @@ void DeePKS_domain::iterate_ad1(const UnitCell& ucell,
1616
ModuleBase::Vector3<int> /*dR*/)> callback)
1717
{
1818
const double Rcut_Alpha = orb.Alpha[0].getRcut();
19-
for (int T0 = 0; T0 < ucell.ntype; T0++)
19+
for (int iat = 0; iat < ucell.nat; iat++)
2020
{
21+
const int T0 = ucell.iat2it[iat];
22+
const int I0 = ucell.iat2ia[iat];
2123
Atom* atom0 = &ucell.atoms[T0];
22-
for (int I0 = 0; I0 < atom0->na; I0++)
24+
const ModuleBase::Vector3<double> tau0 = atom0->tau[I0];
25+
GridD.Find_atom(ucell, tau0, T0, I0);
26+
for (int ad = 0; ad < GridD.getAdjacentNum() + 1; ++ad)
2327
{
24-
const int iat = ucell.itia2iat(T0, I0);
25-
const ModuleBase::Vector3<double> tau0 = atom0->tau[I0];
26-
GridD.Find_atom(ucell, tau0, T0, I0);
27-
for (int ad = 0; ad < GridD.getAdjacentNum() + 1; ++ad)
28+
const int T1 = GridD.getType(ad);
29+
const int I1 = GridD.getNatom(ad);
30+
const int ibt = ucell.itia2iat(T1, I1);
31+
const int start = ucell.itiaiw2iwt(T1, I1, 0);
32+
33+
const ModuleBase::Vector3<double> tau1 = GridD.getAdjacentTau(ad);
34+
const Atom* atom1 = &ucell.atoms[T1];
35+
const int nw1_tot = atom1->nw * PARAM.globalv.npol;
36+
const double Rcut_AO1 = orb.Phi[T1].getRcut();
37+
const double dist1 = (tau1 - tau0).norm() * ucell.lat0;
38+
39+
if (dist1 > Rcut_Alpha + Rcut_AO1)
2840
{
29-
const int T1 = GridD.getType(ad);
30-
const int I1 = GridD.getNatom(ad);
31-
const int ibt = ucell.itia2iat(T1, I1); // on which chi_mu is located
32-
const int start = ucell.itiaiw2iwt(T1, I1, 0);
33-
34-
const ModuleBase::Vector3<double> tau1 = GridD.getAdjacentTau(ad);
35-
const Atom* atom1 = &ucell.atoms[T1];
36-
const int nw1_tot = atom1->nw * PARAM.globalv.npol;
37-
const double Rcut_AO1 = orb.Phi[T1].getRcut();
38-
const double dist1 = (tau1 - tau0).norm() * ucell.lat0;
39-
40-
if (dist1 > Rcut_Alpha + Rcut_AO1)
41-
{
42-
continue;
43-
}
41+
continue;
42+
}
4443

45-
ModuleBase::Vector3<int> dR(GridD.getBox(ad).x, GridD.getBox(ad).y, GridD.getBox(ad).z);
44+
ModuleBase::Vector3<int> dR(GridD.getBox(ad).x, GridD.getBox(ad).y, GridD.getBox(ad).z);
4645

47-
callback(iat, tau0, ibt, tau1, start, nw1_tot, dR);
48-
}
46+
callback(iat, tau0, ibt, tau1, start, nw1_tot, dR);
4947
}
5048
}
5149
}
@@ -174,7 +172,8 @@ void DeePKS_domain::iterate_ad2(const UnitCell& ucell,
174172

175173
callback(iat, tau0, ibt1, tau1, start1, nw1_tot, dR1, ibt2, tau2, start2, nw2_tot, dR2);
176174
}
177-
});
175+
}
176+
);
178177
}
179178

180179
#endif

source/module_hamilt_lcao/module_deepks/deepks_orbpre.cpp

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,12 @@ void DeePKS_domain::cal_orbital_precalc(const std::vector<TH>& dm_hl,
187187
for (int ik = 0; ik < dm_hl.size(); ik++)
188188
{
189189
dm_pair.allocate(&dm_array[ik * row_size * col_size], 0);
190-
190+
191191
std::complex<double> kphase = std::complex<double>(1, 0);
192192
if (std::is_same<TK, std::complex<double>>::value)
193193
{
194-
const double arg = -(kvec_d[ik] * ModuleBase::Vector3<double>(dR1 - dR2)) * ModuleBase::TWO_PI;
194+
const double arg
195+
= -(kvec_d[ik] * ModuleBase::Vector3<double>(dR1 - dR2)) * ModuleBase::TWO_PI;
195196
kphase = std::complex<double>(cos(arg), sin(arg));
196197
}
197198
TK* kphase_ptr = reinterpret_cast<TK*>(&kphase);
@@ -274,33 +275,10 @@ void DeePKS_domain::cal_orbital_precalc(const std::vector<TH>& dm_hl,
274275
std::vector<torch::Tensor> orbital_pdm_vector;
275276
for (int nl = 0; nl < nlmax; ++nl)
276277
{
277-
std::vector<torch::Tensor> kammv;
278-
for (int iks = 0; iks < nks; ++iks)
279-
{
280-
std::vector<torch::Tensor> ammv;
281-
for (int iat = 0; iat < nat; ++iat)
282-
{
283-
int inl = iat * nlmax + nl;
284-
int nm = 2 * inl2l[inl] + 1;
285-
std::vector<double> mmv;
286-
287-
for (int m1 = 0; m1 < nm; ++m1) // m1 = 1 for s, 3 for p, 5 for d
288-
{
289-
for (int m2 = 0; m2 < nm; ++m2) // m1 = 1 for s, 3 for p, 5 for d
290-
{
291-
mmv.push_back(accessor[iks][inl][m1][m2]);
292-
}
293-
}
294-
torch::Tensor mm
295-
= torch::tensor(mmv, torch::TensorOptions().dtype(torch::kFloat64)).reshape({nm, nm}); // nm*nm
296-
297-
ammv.push_back(mm);
298-
}
299-
torch::Tensor amm = torch::stack(ammv, 0);
300-
kammv.push_back(amm);
301-
}
302-
torch::Tensor kamm = torch::stack(kammv, 0);
303-
orbital_pdm_vector.push_back(kamm);
278+
int nm = 2 * inl2l[nl] + 1;
279+
torch::Tensor orbital_pdm_sliced
280+
= orbital_pdm.slice(1, nl, inlmax, nlmax).slice(2, 0, nm, 1).slice(3, 0, nm, 1);
281+
orbital_pdm_vector.push_back(orbital_pdm_sliced);
304282
}
305283

306284
assert(orbital_pdm_vector.size() == nlmax);

source/module_hamilt_lcao/module_deepks/deepks_pdm.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ void DeePKS_domain::update_dmr(const std::vector<ModuleBase::Vector3<double>>& k
9696
hamilt::HContainer<double>* dmr_deepks)
9797
{
9898
dmr_deepks->set_zero();
99+
// save whether the pair with R has been calculated
100+
std::vector<std::tuple<int, int, int, int, int>> calculated_pairs(0);
101+
99102
DeePKS_domain::iterate_ad2(
100103
ucell,
101104
GridD,
@@ -134,6 +137,15 @@ void DeePKS_domain::update_dmr(const std::vector<ModuleBase::Vector3<double>>& k
134137
}
135138
ModuleBase::Vector3<int> dR(dRx, dRy, dRz);
136139

140+
// avoid duplicate calculation
141+
if (std::find(calculated_pairs.begin(), calculated_pairs.end(),
142+
std::make_tuple(ibt1, ibt2, dR.x, dR.y, dR.z))
143+
!= calculated_pairs.end())
144+
{
145+
return;
146+
}
147+
calculated_pairs.push_back(std::make_tuple(ibt1, ibt2, dR.x, dR.y, dR.z));
148+
137149
dm_pair.find_R(dR);
138150
hamilt::BaseMatrix<double>* dmr_ptr = dm_pair.find_matrix(dR);
139151
dmr_ptr->set_zero(); // must reset to zero to avoid accumulation!
@@ -222,7 +234,7 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
222234
Atom* atom0 = &ucell.atoms[T0];
223235
const ModuleBase::Vector3<double> tau0 = atom0->tau[I0];
224236
AdjacentAtomInfo adjs;
225-
GridD.Find_atom(ucell, atom0->tau[I0], T0, I0, &adjs);
237+
GridD.Find_atom(ucell, tau0, T0, I0, &adjs);
226238

227239
// trace alpha orbital
228240
std::vector<int> trace_alpha_row;

source/module_hamilt_lcao/module_deepks/deepks_spre.cpp

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -202,34 +202,13 @@ void DeePKS_domain::cal_gvepsl(const int nat,
202202
auto accessor = gdmepsl.accessor<double, 4>();
203203
if (rank == 0)
204204
{
205-
// make gdmx as tensor
205+
// make gdmepsl as tensor
206206
int nlmax = inlmax / nat;
207207
for (int nl = 0; nl < nlmax; ++nl)
208208
{
209-
std::vector<torch::Tensor> bmmv;
210-
for (int i = 0; i < 6; ++i)
211-
{
212-
std::vector<torch::Tensor> ammv;
213-
for (int iat = 0; iat < nat; ++iat)
214-
{
215-
int inl = iat * nlmax + nl;
216-
int nm = 2 * inl2l[inl] + 1;
217-
std::vector<double> mmv;
218-
for (int m1 = 0; m1 < nm; ++m1)
219-
{
220-
for (int m2 = 0; m2 < nm; ++m2)
221-
{
222-
mmv.push_back(accessor[i][inl][m1][m2]);
223-
}
224-
} // nm^2
225-
torch::Tensor mm
226-
= torch::tensor(mmv, torch::TensorOptions().dtype(torch::kFloat64)).reshape({nm, nm}); // nm*nm
227-
ammv.push_back(mm);
228-
}
229-
torch::Tensor bmm = torch::stack(ammv, 0); // nat*nm*nm
230-
bmmv.push_back(bmm);
231-
}
232-
gdmepsl_vector.push_back(torch::stack(bmmv, 0)); // nbt*3*nat*nm*nm
209+
int nm = 2 * inl2l[nl] + 1;
210+
torch::Tensor gdmepsl_sliced = gdmepsl.slice(1, nl, inlmax, nlmax).slice(2, 0, nm, 1).slice(3, 0, nm, 1);
211+
gdmepsl_vector.push_back(gdmepsl_sliced);
233212
}
234213
assert(gdmepsl_vector.size() == nlmax);
235214

source/module_hamilt_lcao/module_deepks/deepks_vdelta.cpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,6 @@ void DeePKS_domain::collect_h_mat(const Parallel_Orbitals& pv,
106106
}
107107
}
108108
}
109-
else
110-
{
111-
// do nothing
112-
}
113109

114110
Parallel_Reduce::reduce_all(lineH.data(), nlocal - i);
115111

@@ -146,12 +142,11 @@ template void DeePKS_domain::cal_e_delta_band<std::complex<double>>(
146142
const Parallel_Orbitals* pv,
147143
double& e_delta_band);
148144

149-
template void DeePKS_domain::collect_h_mat<double, ModuleBase::matrix>(
150-
const Parallel_Orbitals& pv,
151-
const std::vector<std::vector<double>>& h_in,
152-
std::vector<ModuleBase::matrix>& h_out,
153-
const int nlocal,
154-
const int nks);
145+
template void DeePKS_domain::collect_h_mat<double, ModuleBase::matrix>(const Parallel_Orbitals& pv,
146+
const std::vector<std::vector<double>>& h_in,
147+
std::vector<ModuleBase::matrix>& h_out,
148+
const int nlocal,
149+
const int nks);
155150

156151
template void DeePKS_domain::collect_h_mat<std::complex<double>, ModuleBase::ComplexMatrix>(
157152
const Parallel_Orbitals& pv,

0 commit comments

Comments
 (0)