@@ -273,36 +273,51 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
273273 }
274274 }
275275 // prepare DM from DMR
276- std::vector<double > dm_array (row_size * col_size, 0.0 );
277- const double * dm_current = nullptr ;
278276 int dRx = 0 , dRy = 0 , dRz = 0 ;
279277 if constexpr (std::is_same<TK, std::complex <double >>::value)
280278 {
281- dRx = dR2 .x - dR1 .x ;
282- dRy = dR2 .y - dR1 .y ;
283- dRz = dR2 .z - dR1 .z ;
279+ dRx = dR1 .x - dR2 .x ;
280+ dRy = dR1 .y - dR2 .y ;
281+ dRz = dR1 .z - dR2 .z ;
284282 }
285- // dm_k
283+ ModuleBase::Vector3<double > dR (dRx, dRy, dRz);
284+
285+ hamilt::AtomPair<double > dm_pair (ibt1, ibt2, dRx, dRy, dRz, &pv);
286+ dm_pair.allocate (nullptr , true );
286287 auto dm_k = dm->get_DMK_vector ();
287- const int nrow = pv.nrow ;
288- for (int ir = 0 ; ir < row_size; ir++)
288+
289+ if constexpr (std::is_same<TK, double >::value) // for gamma-only
290+ {
291+ for (int is = 0 ; is < dm_k.size (); is++)
292+ {
293+ if (ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER (PARAM.inp .ks_solver ))
294+ {
295+ dm_pair.add_from_matrix (dm_k[is].data (), pv.get_row_size (), 1.0 , 1 );
296+ }
297+ else
298+ {
299+ dm_pair.add_from_matrix (dm_k[is].data (), pv.get_col_size (), 1.0 , 0 );
300+ }
301+ }
302+ }
303+ else // for multi-k
289304 {
290- for (int ic = 0 ; ic < col_size; ic ++)
305+ for (int ik = 0 ; ik < dm_k. size (); ik ++)
291306 {
292- int iglob = (pv.atom_begin_row [ibt1] + ir) + nrow * (pv.atom_begin_col [ibt2] + ic);
293- int iloc = ir * col_size + ic;
294- std::complex <double > tmp = 0.0 ;
295- for (int ik = 0 ; ik < dm_k.size (); ik++) // dm_k.size() == _nk * _nspin
307+ const double arg = -(kvec_d[ik] * dR) * ModuleBase::TWO_PI;
308+ const std::complex <double > kphase = std::complex <double >(cos (arg), sin (arg));
309+ if (ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER (PARAM.inp .ks_solver ))
310+ {
311+ dm_pair.add_from_matrix (dm_k[ik].data (), pv.get_row_size (), kphase, 1 );
312+ }
313+ else
296314 {
297- const double arg = (kvec_d[ik] * ModuleBase::Vector3<double >(dR1 - dR2)) * ModuleBase::TWO_PI;
298- const std::complex <double > kphase = std::complex <double >(cos (arg), sin (arg));
299- tmp += dm_k[ik][iglob] * kphase;
315+ dm_pair.add_from_matrix (dm_k[ik].data (), pv.get_col_size (), kphase, 0 );
300316 }
301- dm_array[iloc] += tmp.real ();
302317 }
303318 }
304319
305- dm_current = dm_array. data ();
320+ const double * dm_current = dm_pair. get_pointer ();
306321 // use s_2t and dm_current to get g_1dmt
307322 // dgemm_: C = alpha * A * B + beta * C
308323 // C = g_1dmt, A = dm_current, B = s_2t
0 commit comments