@@ -93,6 +93,7 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
9393 const int lmaxd,
9494 const std::vector<int >& inl2l,
9595 const ModuleBase::IntArray* inl_index,
96+ const std::vector<ModuleBase::Vector3<double >>& kvec_d,
9697 const elecstate::DensityMatrix<TK, double >* dm,
9798 const std::vector<hamilt::HContainer<double >*> phialpha,
9899 const UnitCell& ucell,
@@ -231,7 +232,7 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
231232 }
232233 }
233234
234- for (int ad2 = 0 ; ad2 < adjs.adj_num + 1 ; ad2++)
235+ for (int ad2 = 0 ; ad2 < adjs.adj_num + 1 ; ad2++)
235236 {
236237 const int T2 = adjs.ntype [ad2];
237238 const int I2 = adjs.natom [ad2];
@@ -274,33 +275,31 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
274275 // prepare DM from DMR
275276 std::vector<double > dm_array (row_size * col_size, 0.0 );
276277 const double * dm_current = nullptr ;
277- for (int is = 0 ; is < dm->get_DMR_vector ().size (); is++)
278+ int dRx = 0 , dRy = 0 , dRz = 0 ;
279+ if constexpr (std::is_same<TK, std::complex <double >>::value)
278280 {
279- int dRx = 0 , dRy = 0 , dRz = 0 ;
280- if constexpr (std::is_same<TK, std::complex <double >>::value)
281- {
282- dRx = dR2.x - dR1.x ;
283- dRy = dR2.y - dR1.y ;
284- dRz = dR2.z - dR1.z ;
285- }
286- // dm_R
287- auto * tmp = dm->get_DMR_vector ()[is]->find_matrix (ibt1, ibt2, dRx, dRy, dRz);
288- if (tmp == nullptr )
289- {
290- // in case of no deepks_scf but out_deepks_label, size of DMR would mismatch with
291- // deepks-orbitals
292- dm_current = nullptr ;
293- break ;
294- }
295- dm_current = tmp->get_pointer ();
296- for (int idm = 0 ; idm < row_size * col_size; idm++)
297- {
298- dm_array[idm] += dm_current[idm];
299- }
281+ dRx = dR2.x - dR1.x ;
282+ dRy = dR2.y - dR1.y ;
283+ dRz = dR2.z - dR1.z ;
300284 }
301- if (dm_current == nullptr )
285+ // dm_k
286+ auto dm_k = dm->get_DMK_vector ();
287+ const int nrow = pv.nrow ;
288+ for (int ir = 0 ; ir < row_size; ir++)
302289 {
303- continue ; // skip the long range DM pair more than nonlocal term
290+ for (int ic = 0 ; ic < col_size; ic++)
291+ {
292+ int iglob = (pv.atom_begin_row [ibt1] + ir) + nrow * (pv.atom_begin_col [ibt2] + ic);
293+ int iloc = ir * col_size + ic;
294+ std::complex <double > tmp = 0.0 ;
295+ for (int ik = 0 ; ik < dm_k.size (); ik++) // dm_k.size() == _nk * _nspin
296+ {
297+ const double arg = (kvec_d[ik] * ModuleBase::Vector3<double >(dR1 - dR2)) * ModuleBase::TWO_PI;
298+ const std::complex <double > kphase = std::complex <double >(cos (arg), sin (arg));
299+ tmp += dm_k[ik][iglob] * kphase;
300+ }
301+ dm_array[iloc] += tmp.real ();
302+ }
304303 }
305304
306305 dm_current = dm_array.data ();
@@ -311,18 +310,18 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
311310 constexpr char transa = ' T' , transb = ' N' ;
312311 const double gemm_alpha = 1.0 , gemm_beta = 1.0 ;
313312 dgemm_ (&transa,
314- &transb,
315- &row_size,
316- &trace_alpha_size,
317- &col_size,
318- &gemm_alpha,
319- dm_current,
320- &col_size,
321- s_2t.data (),
322- &col_size,
323- &gemm_beta,
324- g_1dmt.data (),
325- &row_size);
313+ &transb,
314+ &row_size,
315+ &trace_alpha_size,
316+ &col_size,
317+ &gemm_alpha,
318+ dm_current,
319+ &col_size,
320+ s_2t.data (),
321+ &col_size,
322+ &gemm_beta,
323+ g_1dmt.data (),
324+ &row_size);
326325 } // ad2
327326 if (!PARAM.inp .deepks_equiv )
328327 {
@@ -340,10 +339,10 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
340339 for (int m2 = 0 ; m2 < nm; ++m2) // m1 = 1 for s, 3 for p, 5 for d
341340 {
342341 accessor[m1][m2] += ddot_ (&row_size,
343- g_1dmt.data () + index * row_size,
344- &inc,
345- s_1t.data () + index * row_size,
346- &inc);
342+ g_1dmt.data () + index * row_size,
343+ &inc,
344+ s_1t.data () + index * row_size,
345+ &inc);
347346 index++;
348347 }
349348 }
@@ -366,10 +365,10 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
366365 // ddot_: dot product of two vectors
367366 // inc means the increment of the index
368367 accessor[iproj * nproj + jproj] += ddot_ (&row_size,
369- g_1dmt.data () + index * row_size,
370- &inc,
371- s_1t.data () + index * row_size,
372- &inc);
368+ g_1dmt.data () + index * row_size,
369+ &inc,
370+ s_1t.data () + index * row_size,
371+ &inc);
373372 index++;
374373 }
375374 }
@@ -414,6 +413,7 @@ template void DeePKS_domain::cal_pdm<double>(bool& init_pdm,
414413 const int lmaxd,
415414 const std::vector<int >& inl2l,
416415 const ModuleBase::IntArray* inl_index,
416+ const std::vector<ModuleBase::Vector3<double >>& kvec_d,
417417 const elecstate::DensityMatrix<double , double >* dm,
418418 const std::vector<hamilt::HContainer<double >*> phialpha,
419419 const UnitCell& ucell,
@@ -428,6 +428,7 @@ template void DeePKS_domain::cal_pdm<std::complex<double>>(
428428 const int lmaxd,
429429 const std::vector<int >& inl2l,
430430 const ModuleBase::IntArray* inl_index,
431+ const std::vector<ModuleBase::Vector3<double >>& kvec_d,
431432 const elecstate::DensityMatrix<std::complex <double >, double >* dm,
432433 const std::vector<hamilt::HContainer<double >*> phialpha,
433434 const UnitCell& ucell,
0 commit comments