Skip to content

Commit db026fe

Browse files
committed
Fix size mismatch for equivariant version DeePKS.
1 parent ef9f331 commit db026fe

File tree

4 files changed

+24
-16
lines changed

4 files changed

+24
-16
lines changed

source/module_hamilt_lcao/module_deepks/deepks_force.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,16 @@ void DeePKS_domain::cal_f_delta(const std::vector<std::vector<TK>>& dm,
4949
const int nw2_tot,
5050
ModuleBase::Vector3<int> dR2)
5151
{
52-
double r0[3] = {0, 0, 0};
5352
double r1[3] = {0, 0, 0};
53+
double r2[3] = {0, 0, 0};
5454
if (isstress)
5555
{
5656
r1[0] = (tau1.x - tau0.x);
5757
r1[1] = (tau1.y - tau0.y);
5858
r1[2] = (tau1.z - tau0.z);
59-
r0[0] = (tau2.x - tau0.x);
60-
r0[1] = (tau2.y - tau0.y);
61-
r0[2] = (tau2.z - tau0.z);
59+
r2[0] = (tau2.x - tau0.x);
60+
r2[1] = (tau2.y - tau0.y);
61+
r2[2] = (tau2.z - tau0.z);
6262
}
6363

6464
auto row_indexes = pv.get_indexes_row(ibt1);
@@ -255,7 +255,7 @@ void DeePKS_domain::cal_f_delta(const std::vector<std::vector<TK>>& dm,
255255
for (int jpol = ipol; jpol < 3; jpol++)
256256
{
257257
svnl_dalpha(ipol, jpol)
258-
+= *dm_current * (nlm[ipol] * r0[jpol] + nlm_t[ipol] * r1[jpol]);
258+
+= *dm_current * (nlm[ipol] * r2[jpol] + nlm_t[ipol] * r1[jpol]);
259259
}
260260
}
261261
}

source/module_hamilt_lcao/module_deepks/deepks_pdm.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
void DeePKS_domain::read_pdm(bool read_pdm_file,
2828
bool is_equiv,
2929
bool& init_pdm,
30+
const int nat,
3031
const int inlmax,
3132
const int lmaxd,
3233
const int* inl_l,
@@ -68,9 +69,9 @@ void DeePKS_domain::read_pdm(bool read_pdm_file,
6869
nproj += (2 * il + 1) * alpha.getNchi(il);
6970
}
7071
pdm_size = nproj * nproj;
71-
for (int inl = 0; inl < inlmax; inl++)
72+
for (int iat = 0; iat < nat; iat++)
7273
{
73-
auto accessor = pdm[inl].accessor<double, 1>();
74+
auto accessor = pdm[iat].accessor<double, 1>();
7475
for (int ind = 0; ind < pdm_size; ind++)
7576
{
7677
double c;
@@ -128,9 +129,9 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
128129
nproj += (2 * il + 1) * orb.Alpha[0].getNchi(il);
129130
}
130131
pdm_size = nproj * nproj;
131-
for (int inl = 0; inl < inlmax; inl++)
132+
for (int iat = 0; iat < ucell.nat; iat++)
132133
{
133-
pdm[inl] = torch::zeros({pdm_size}, torch::kFloat64);
134+
pdm[iat] = torch::zeros({pdm_size}, torch::kFloat64);
134135
}
135136
}
136137

@@ -268,7 +269,7 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
268269
s_2t[i * col_size + icol] = col_ptr->get_value(col_indexes[icol], trace_alpha_col[i]);
269270
}
270271
}
271-
// prepare DM_gamma from DMR
272+
// prepare DM from DMR
272273
std::vector<double> dm_array(row_size * col_size, 0.0);
273274
const double* dm_current = nullptr;
274275
for (int is = 0; is < dm->get_DMR_vector().size(); is++)
@@ -286,6 +287,7 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
286287
dRy = dR2.y - dR1.y;
287288
dRz = dR2.z - dR1.z;
288289
}
290+
// dm_R
289291
auto* tmp = dm->get_DMR_vector()[is]->find_matrix(ibt1, ibt2, dRx, dRy, dRz);
290292
if (tmp == nullptr)
291293
{
@@ -306,7 +308,10 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
306308
}
307309

308310
dm_current = dm_array.data();
309-
// dgemm for s_2t and dm_current to get g_1dmt
311+
// use s_2t and dm_current to get g_1dmt
312+
// dgemm_: C = alpha * A * B + beta * C
313+
// C = g_1dmt, A = dm_current, B = s_2t
314+
// all the input should be data pointer
310315
constexpr char transa = 'T', transb = 'N';
311316
const double gemm_alpha = 1.0, gemm_beta = 1.0;
312317
dgemm_(&transa,
@@ -338,6 +343,8 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
338343
{
339344
for (int m2 = 0; m2 < nm; ++m2) // m1 = 1 for s, 3 for p, 5 for d
340345
{
346+
// ddot_: dot product of two vectors
347+
// inc means the increment of the index
341348
accessor[m1][m2] += ddot_(&row_size,
342349
g_1dmt.data() + index * row_size,
343350
&inc,

source/module_hamilt_lcao/module_deepks/deepks_pdm.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ namespace DeePKS_domain
3232
void read_pdm(bool read_pdm_file,
3333
bool is_equiv,
3434
bool& init_pdm,
35+
const int nat,
3536
const int inlmax,
3637
const int lmaxd,
3738
const int* inl_l,

source/module_hamilt_lcao/module_deepks/deepks_spre.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,14 @@ void DeePKS_domain::cal_gdmepsl(const int lmaxd,
5757
const int nw2_tot,
5858
ModuleBase::Vector3<int> dR2)
5959
{
60-
double r0[3] = {0, 0, 0};
6160
double r1[3] = {0, 0, 0};
61+
double r2[3] = {0, 0, 0};
6262
r1[0] = (tau1.x - tau0.x);
6363
r1[1] = (tau1.y - tau0.y);
6464
r1[2] = (tau1.z - tau0.z);
65-
r0[0] = (tau2.x - tau0.x);
66-
r0[1] = (tau2.y - tau0.y);
67-
r0[2] = (tau2.z - tau0.z);
65+
r2[0] = (tau2.x - tau0.x);
66+
r2[1] = (tau2.y - tau0.y);
67+
r2[2] = (tau2.z - tau0.z);
6868
auto row_indexes = pv.get_indexes_row(ibt1);
6969
auto col_indexes = pv.get_indexes_col(ibt2);
7070
if (row_indexes.size() * col_indexes.size() == 0)
@@ -152,7 +152,7 @@ void DeePKS_domain::cal_gdmepsl(const int lmaxd,
152152
accessor[mm][inl][m2][m1]
153153
+= ucell.lat0 * *dm_current
154154
* (grad_overlap_2[jpol]->get_value(col_indexes[iw2], ib + m2)
155-
* overlap_1->get_value(row_indexes[iw1], ib + m1) * r0[ipol]);
155+
* overlap_1->get_value(row_indexes[iw1], ib + m1) * r2[ipol]);
156156
accessor[mm][inl][m2][m1]
157157
+= ucell.lat0 * *dm_current
158158
* (overlap_2->get_value(col_indexes[iw2], ib + m1)

0 commit comments

Comments
 (0)