@@ -33,10 +33,11 @@ Sto_EleCond<FPTYPE, Device>::Sto_EleCond(UnitCell* p_ucell_in,
3333 this ->nbands_ks = p_psi_in->get_nbands ();
3434 this ->nbands_sto = p_stowf_in->nchi ;
3535 this ->stofunc .set_E_range (&stoche.emin_sto , &stoche.emax_sto );
36+ this ->cond_dtbatch = PARAM.inp .cond_dtbatch ;
3637#ifdef __ENABLE_FLOAT_FFTW
3738 if (!std::is_same<FPTYPE, lowTYPE>::value)
3839 {
39- this ->hamilt_sto_ = new hamilt::HamiltSdftPW<std::complex <lowTYPE>, Device>(p_elec_in->pot , p_wfcpw_in, p_kv_in, p_ppcell_in, p_ucell_in, 1 , &this ->emin_sto_ , &this ->emax_sto_ );
40+ this ->hamilt_sto_ = new hamilt::HamiltSdftPW<std::complex <lowTYPE>, Device>(p_elec_in->pot , p_wfcpw_in, p_kv_in, p_ppcell_in, p_ucell_in, 1 , &this ->low_emin_ , &this ->low_emax_ );
4041 }
4142#endif
4243}
@@ -149,6 +150,7 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
149150 psi::Psi<std::complex <lowTYPE>, Device>& leftchi,
150151 psi::Psi<std::complex <lowTYPE>, Device>& rightchi,
151152 psi::Psi<std::complex <lowTYPE>, Device>& left_hchi,
153+ psi::Psi<std::complex <lowTYPE>, Device>& right_hchi,
152154 psi::Psi<std::complex <lowTYPE>, Device>& batch_vchi,
153155 psi::Psi<std::complex <lowTYPE>, Device>& batch_vhchi,
154156#ifdef __MPI
@@ -160,6 +162,7 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
160162 const int & bsize_psi,
161163 std::complex <lowTYPE>* j1,
162164 std::complex <lowTYPE>* j2,
165+ std::complex <lowTYPE>* tmpj,
163166 hamilt::Velocity<lowTYPE, Device>& velop,
164167 const int & ik,
165168 const std::complex <lowTYPE>& factor,
@@ -181,8 +184,6 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
181184 const int allbands = bandinfo[5 ];
182185 const int dim_jmatrix = perbands_ks * allbands_sto + perbands_sto * allbands;
183186
184- psi::Psi<std::complex <lowTYPE>, Device> right_hchi (1 , perbands_sto, npwx, npw, true );
185-
186187 hamilt->hPsi (leftchi.get_pointer (), left_hchi.get_pointer (), perbands_sto);
187188 hamilt->hPsi (rightchi.get_pointer (), right_hchi.get_pointer (), perbands_sto);
188189
@@ -261,8 +262,6 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
261262 }
262263 }
263264
264- std::complex <lowTYPE>* tmpj = nullptr ;
265- resmem_lcomplex_op ()(tmpj, allbands_sto * perbands_sto);
266265 int remain = perbands_sto;
267266 int startnb = 0 ;
268267 while (remain > 0 )
@@ -289,7 +288,7 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
289288 allbands_ks,
290289 npw,
291290 &float_factor,
292- batch_vchi. get_pointer ( ),
291+ & batch_vchi (idnb, 0 ),
293292 npwx,
294293 kspsi_all.get_pointer (),
295294 npwx,
@@ -316,7 +315,7 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
316315 allbands_ks,
317316 npw,
318317 &float_factor,
319- batch_vhchi. get_pointer ( ),
318+ & batch_vhchi (idnb, 0 ),
320319 npwx,
321320 kspsi_all.get_pointer (),
322321 npwx,
@@ -342,7 +341,7 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
342341 allbands_sto,
343342 npw,
344343 &float_factor,
345- batch_vchi. get_pointer ( ),
344+ & batch_vchi (idnb, 0 ),
346345 npwx,
347346 rightchi_all->get_pointer (),
348347 npwx,
@@ -357,9 +356,9 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
357356 allbands_sto,
358357 npw,
359358 &float_factor,
360- batch_vhchi. get_pointer ( ),
359+ & batch_vhchi (idnb, 0 ),
361360 npwx,
362- righthchi_all ->get_pointer (),
361+ rightchi_all ->get_pointer (),
363362 npwx,
364363 &zero,
365364 j2mat,
@@ -372,9 +371,9 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
372371 allbands_sto,
373372 npw,
374373 &float_factor,
375- batch_vchi. get_pointer ( ),
374+ & batch_vchi (idnb, 0 ),
376375 npwx,
377- rightchi_all ->get_pointer (),
376+ righthchi_all ->get_pointer (),
378377 npwx,
379378 &zero,
380379 tmpjmat,
@@ -470,7 +469,6 @@ void Sto_EleCond<FPTYPE, Device>::cal_jmatrix(hamilt::HamiltSdftPW<std::complex<
470469 Parallel_Common::reduce_data (j2, ndim * dim_jmatrix, POOL_WORLD);
471470 }
472471#endif
473-
474472 ModuleBase::timer::tick (" Sto_EleCond" , " cal_jmatrix" );
475473
476474 return ;
@@ -550,9 +548,9 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
550548 std::complex <lowTYPE> zero = static_cast <std::complex <lowTYPE>>(0.0 );
551549 std::complex <lowTYPE> imag_one = static_cast <std::complex <lowTYPE>>(ModuleBase::IMAG_UNIT);
552550 Sto_Func<lowTYPE> lowfunc;
553- lowTYPE low_emin = static_cast <lowTYPE>(*this ->stofunc .Emin );
554- lowTYPE low_emax = static_cast <lowTYPE>(*this ->stofunc .Emax );
555- lowfunc.set_E_range (&low_emin , &low_emax );
551+ this -> low_emin_ = static_cast <lowTYPE>(*this ->stofunc .Emin );
552+ this -> low_emax_ = static_cast <lowTYPE>(*this ->stofunc .Emax );
553+ lowfunc.set_E_range (&low_emin_ , &low_emax_ );
556554 hamilt::HamiltSdftPW<lcomplex, Device>* p_low_hamilt = nullptr ;
557555 if (hamilt_sto_ != nullptr )
558556 {
@@ -593,9 +591,9 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
593591 // std::complex<lowTYPE>* tmpcoef = batchcoef_ + (nbatch - 1) * cond_nche;
594592 // resmem_lcomplex_op()(batchmcoef_, cond_nche * nbatch);
595593 // std::complex<lowTYPE>* tmpmcoef = batchmcoef_ + (nbatch - 1) * cond_nche;
596- batchcoef.reshape ({nbatch, cond_nche});
594+ batchcoef.resize ({nbatch, cond_nche});
597595 lcomplex* tmpcoef = batchcoef[nbatch-1 ].data <lcomplex>();
598- batchmcoef.reshape ({nbatch, cond_nche});
596+ batchmcoef.resize ({nbatch, cond_nche});
599597 lcomplex* tmpmcoef = batchmcoef[nbatch-1 ].data <lcomplex>();
600598
601599 cpymem_lcomplex_op ()(tmpcoef, chet.coef_complex , cond_nche);
@@ -635,8 +633,8 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
635633
636634 // get allbands_ks
637635 int cutib0 = 0 ;
638- double emin = static_cast <double >(*this ->stofunc .Emin );
639- double emax = static_cast <double >(*this ->stofunc .Emax );
636+ const double emin = static_cast <double >(*this ->stofunc .Emin );
637+ const double emax = static_cast <double >(*this ->stofunc .Emax );
640638 if (this ->nbands_ks > 0 )
641639 {
642640 double Emax_KS = std::max (emin, this ->p_elec ->ekb (ik, this ->nbands_ks - 1 ));
@@ -697,7 +695,7 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
697695 // -----------------------------------------------------------
698696 if (GlobalV::MY_BNDGROUP == 0 && allbands_ks > 0 )
699697 {
700- jjresponse_ks (ik, nt, dt, dEcut, this ->p_elec ->wg , velop, ct11.data (), ct12.data (), ct22.data ());
698+ this -> jjresponse_ks (ik, nt, dt, dEcut, this ->p_elec ->wg , velop, ct11.data (), ct12.data (), ct22.data ());
701699 }
702700
703701 // -----------------------------------------------------------
@@ -823,11 +821,14 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
823821 ct::Tensor j2l (t_type, device_type, {ndim, dim_jmatrix});
824822 ct::Tensor j1r (t_type, device_type, {ndim, dim_jmatrix});
825823 ct::Tensor j2r (t_type, device_type, {ndim, dim_jmatrix});
824+ ct::Tensor tmpj (t_type, device_type, {ndim, allbands_sto * perbands_sto});
826825 ModuleBase::Memory::record (" SDFT::j1l" , sizeof (lcomplex) * ndim * dim_jmatrix);
827826 ModuleBase::Memory::record (" SDFT::j2l" , sizeof (lcomplex) * ndim * dim_jmatrix);
828827 ModuleBase::Memory::record (" SDFT::j1r" , sizeof (lcomplex) * ndim * dim_jmatrix);
829828 ModuleBase::Memory::record (" SDFT::j2r" , sizeof (lcomplex) * ndim * dim_jmatrix);
829+ ModuleBase::Memory::record (" SDFT::tmpj" , sizeof (lcomplex) * ndim * allbands_sto * perbands_sto);
830830 psi::Psi<lcomplex, Device> tmphchil (1 , perbands_sto, npwx, npw, true );
831+ psi::Psi<lcomplex, Device> tmphchir (1 , perbands_sto, npwx, npw, true );
831832 ModuleBase::Memory::record (" SDFT::tmphchil/r" , sto_memory_cost * 2 );
832833
833834 // ------------------------ t loop --------------------------
@@ -978,6 +979,7 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
978979 exptsmfchi,
979980 exptsfchi,
980981 tmphchil,
982+ tmphchir,
981983 batch_vchi,
982984 batch_vhchi,
983985#ifdef __MPI
@@ -989,6 +991,7 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
989991 bsize_psi,
990992 j1l.data <lcomplex>(),
991993 j2l.data <lcomplex>(),
994+ tmpj.data <lcomplex>(),
992995 low_velop,
993996 ik,
994997 imag_one,
@@ -1005,6 +1008,7 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
10051008 expmtsmfchi,
10061009 expmtsfchi,
10071010 tmphchil,
1011+ tmphchir,
10081012 batch_vchi,
10091013 batch_vhchi,
10101014#ifdef __MPI
@@ -1016,6 +1020,7 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
10161020 bsize_psi,
10171021 j1r.data <lcomplex>(),
10181022 j2r.data <lcomplex>(),
1023+ tmpj.data <lcomplex>(),
10191024 low_velop,
10201025 ik,
10211026 one,
@@ -1028,21 +1033,30 @@ void Sto_EleCond<FPTYPE, Device>::sKG(const int& smear_type,
10281033 // Im(l_ij*r_ji) = Re(-il_ij * r_ji) = Re( ((il)^+_ji)^* * r_ji)=Re(((il)^+_i)^* * r^+_i)
10291034 // ddot_real = real(A_i^* * B_i)
10301035 ModuleBase::timer::tick (" Sto_EleCond" , " ddot_real" );
1031- ct11[it] += static_cast <double >(
1032- ModuleBase::GlobalFunc::ddot_real (num_per, j1l.data <lcomplex>() + st_per, j1r.data <lcomplex>() + st_per, false )
1033- * this ->p_kv ->wk [ik] / 2.0 );
1034- double tmp12 = static_cast <double >(
1035- ModuleBase::GlobalFunc::ddot_real (num_per, j1l.data <lcomplex>() + st_per, j2r.data <lcomplex>() + st_per, false ));
1036-
1037- double tmp21 = static_cast <double >(
1038- ModuleBase::GlobalFunc::ddot_real (num_per, j2l.data <lcomplex>() + st_per, j1r.data <lcomplex>() + st_per, false ));
1036+ ct11[it] += static_cast <double >(ModuleBase::dot_real_op<lcomplex, Device>()(num_per,
1037+ j1l.data <lcomplex>() + st_per,
1038+ j1r.data <lcomplex>() + st_per,
1039+ false )
1040+ * this ->p_kv ->wk [ik] / 2.0 );
1041+ double tmp12
1042+ = static_cast <double >(ModuleBase::dot_real_op<lcomplex, Device>()(num_per,
1043+ j1l.data <lcomplex>() + st_per,
1044+ j2r.data <lcomplex>() + st_per,
1045+ false ));
1046+
1047+ double tmp21
1048+ = static_cast <double >(ModuleBase::dot_real_op<lcomplex, Device>()(num_per,
1049+ j2l.data <lcomplex>() + st_per,
1050+ j1r.data <lcomplex>() + st_per,
1051+ false ));
10391052
10401053 ct12[it] -= 0.5 * (tmp12 + tmp21) * this ->p_kv ->wk [ik] / 2.0 ;
10411054
1042- ct22[it] += static_cast <double >(
1043- ModuleBase::GlobalFunc::ddot_real (num_per, j2l.data <lcomplex>() + st_per, j2r.data <lcomplex>() + st_per, false )
1044- * this ->p_kv ->wk [ik] / 2.0 );
1045-
1055+ ct22[it] += static_cast <double >(ModuleBase::dot_real_op<lcomplex, Device>()(num_per,
1056+ j2l.data <lcomplex>() + st_per,
1057+ j2r.data <lcomplex>() + st_per,
1058+ false )
1059+ * this ->p_kv ->wk [ik] / 2.0 );
10461060 ModuleBase::timer::tick (" Sto_EleCond" , " ddot_real" );
10471061 }
10481062 std::cout << std::endl;
0 commit comments