2727void DeePKS_domain::read_pdm (bool read_pdm_file,
2828 bool is_equiv,
2929 bool & init_pdm,
30+ const int nat,
3031 const int inlmax,
3132 const int lmaxd,
3233 const int * inl_l,
@@ -68,9 +69,9 @@ void DeePKS_domain::read_pdm(bool read_pdm_file,
6869 nproj += (2 * il + 1 ) * alpha.getNchi (il);
6970 }
7071 pdm_size = nproj * nproj;
71- for (int inl = 0 ; inl < inlmax; inl ++)
72+ for (int iat = 0 ; iat < nat; iat ++)
7273 {
73- auto accessor = pdm[inl ].accessor <double , 1 >();
74+ auto accessor = pdm[iat ].accessor <double , 1 >();
7475 for (int ind = 0 ; ind < pdm_size; ind++)
7576 {
7677 double c;
@@ -128,9 +129,9 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
128129 nproj += (2 * il + 1 ) * orb.Alpha [0 ].getNchi (il);
129130 }
130131 pdm_size = nproj * nproj;
131- for (int inl = 0 ; inl < inlmax; inl ++)
132+ for (int iat = 0 ; iat < ucell. nat ; iat ++)
132133 {
133- pdm[inl ] = torch::zeros ({pdm_size}, torch::kFloat64 );
134+ pdm[iat ] = torch::zeros ({pdm_size}, torch::kFloat64 );
134135 }
135136 }
136137
@@ -268,7 +269,7 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
268269 s_2t[i * col_size + icol] = col_ptr->get_value (col_indexes[icol], trace_alpha_col[i]);
269270 }
270271 }
271- // prepare DM_gamma from DMR
272+ // prepare DM from DMR
272273 std::vector<double > dm_array (row_size * col_size, 0.0 );
273274 const double * dm_current = nullptr ;
274275 for (int is = 0 ; is < dm->get_DMR_vector ().size (); is++)
@@ -286,6 +287,7 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
286287 dRy = dR2.y - dR1.y ;
287288 dRz = dR2.z - dR1.z ;
288289 }
290+ // dm_R
289291 auto * tmp = dm->get_DMR_vector ()[is]->find_matrix (ibt1, ibt2, dRx, dRy, dRz);
290292 if (tmp == nullptr )
291293 {
@@ -306,7 +308,10 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
306308 }
307309
308310 dm_current = dm_array.data ();
309- // dgemm for s_2t and dm_current to get g_1dmt
311+ // use s_2t and dm_current to get g_1dmt
312+ // dgemm_: C = alpha * A * B + beta * C
313+ // C = g_1dmt, A = dm_current, B = s_2t
314+ // all the input should be data pointer
310315 constexpr char transa = ' T' , transb = ' N' ;
311316 const double gemm_alpha = 1.0 , gemm_beta = 1.0 ;
312317 dgemm_ (&transa,
@@ -338,6 +343,8 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
338343 {
339344 for (int m2 = 0 ; m2 < nm; ++m2) // m1 = 1 for s, 3 for p, 5 for d
340345 {
346+ // ddot_: dot product of two vectors
347+ // inc means the increment of the index
341348 accessor[m1][m2] += ddot_ (&row_size,
342349 g_1dmt.data () + index * row_size,
343350 &inc,
0 commit comments