44#include " module_lr/Grad/force/lr_force.h"
55#include " module_elecstate/module_dm/cal_dm_psi.h"
66#include " module_io/output_log.h"
7+ #ifdef __EXX
8+ #include < RI/ri/LRI_Cal_Aux.h>
9+ #endif
710
811template <typename Tstream>
912inline void print_force (const std::vector<ModuleBase::matrix>& force, Tstream& ofs)
@@ -59,6 +62,41 @@ inline void test_edm_H2(const T* const edm, const double* const eig_ks, const ps
5962 }
6063}
6164
65+ #ifdef __EXX
66+ // convert DensityMatrix to maps of RI::Tensors
67+ template <typename T>
68+ inline auto get_exx_Ds_spin1 (const elecstate::DensityMatrix<T, T>& dm,
69+ const UnitCell& ucell, const K_Vectors& kv, const Parallel_Orbitals& pmat)
70+ -> std::map<int, std::map<std::pair<int, std::array<int, 3>>, RI::Tensor<T>>>
71+ {
72+ const int & nk = dm.get_DMK_nks (); // nks/nspin
73+ std::vector<const std::vector<T>*> DMk_trans_pointer (nk);
74+ for (int ik = 0 ;ik < nk;++ik) { DMk_trans_pointer[ik] = &dm.get_DMK_vector ()[ik]; }
75+ return RI_2D_Comm::split_m2D_ktoR<T>(ucell, kv, DMk_trans_pointer, pmat, /* nspin=*/ 1 )[0 ];
76+ }
77+ template <typename T>
78+ inline auto get_exx_Ds_gs (elecstate::DensityMatrix<T, double >& dm,
79+ const UnitCell& ucell, const K_Vectors& kv, const Parallel_Orbitals& pmat, const int nspin)
80+ -> std::map<int, std::map<std::pair<int, std::array<int, 3>>, RI::Tensor<T>>>
81+ {
82+ const int & nk = dm.get_DMK_nks () / nspin; // nks/nspin
83+ std::vector<const std::vector<T>*> DMk_trans_pointer (nk);
84+
85+ std::map<int , std::map<std::pair<int , std::array<int , 3 >>, RI::Tensor<T>>> Ds_allspin;
86+ for (int is = 0 ;is < nspin;++is)
87+ {
88+ for (int ik = 0 ;ik < nk;++ik)
89+ {
90+ DMk_trans_pointer[ik] = &dm.get_DMK_vector ()[ik + is * nk];
91+ }
92+ std::map<int , std::map<std::pair<int , std::array<int , 3 >>, RI::Tensor<T>>> Ds_tmp
93+ = RI_2D_Comm::split_m2D_ktoR<T>(ucell, kv, DMk_trans_pointer, pmat, /* nspin=*/ 1 )[0 ];
94+ RI::LRI_Cal_Aux::add_Ds<int , std::map<std::pair<int , std::array<int , 3 >>, RI::Tensor<T>>>(std::move (Ds_tmp), Ds_allspin);
95+ }
96+ return Ds_allspin;
97+ }
98+ #endif
99+
62100template <typename T, typename TR>
63101void LR::ESolver_LR<T, TR>::init_pot_groundstate(const Charge& chg_gs)
64102{
@@ -129,7 +167,11 @@ std::vector<ModuleBase::matrix> LR::ESolver_LR<T, TR>::cal_force(const int ispin
129167 const auto & c = LR_Util::get_psi_spin (*this ->psi_ks , ispin, this ->nk ); // wavefunction coefficients of ground state
130168
131169 // calculate the force (the partial gradient of Lagrangian)
132- LR_Force<T> lr_force (this ->ucell , this ->kv .kvec_d , this ->paraMat_ , *this ->pw_rho , this ->locpp , this ->sf , this ->gd , this ->gint_ , this ->two_center_bundle_ );
170+ LR_Force<T> lr_force (this ->ucell , this ->kv .kvec_d , this ->paraMat_ , *this ->pw_rho , this ->locpp , this ->sf , this ->gd , this ->gint_ , this ->two_center_bundle_
171+ #ifdef __EXX
172+ , std::weak_ptr<Exx_LRI<T>>(this ->exx_lri ), this ->exx_info .info_global .hybrid_alpha
173+ #endif
174+ );
133175 GlobalV::ofs_running << " Start to calculate excited-state force of " << this ->spin_types [ispin] << std::endl;
134176 // ground state dm for currrent spin (only for test the correctness of the force)
135177 // elecstate::DensityMatrix<T, T> dm_gs(this->paraMat_, 1, this->kv.kvec_d, this->nk);
@@ -141,9 +183,12 @@ std::vector<ModuleBase::matrix> LR::ESolver_LR<T, TR>::cal_force(const int ispin
141183 const int offset = istate * this ->nloc_per_band ;
142184 // The imag part will be cancelled in the force calculation, so we use double DM(R) to calculate force.
143185 // But complex transition DM(R) is still used in energy density matrix calculation.
144- elecstate::DensityMatrix<T, double > dm_trans_real = // D(X)
145- LR_Util::build_dm_from_dmk<T, double >(
146- cal_dm_trans_pblas (this ->X [ispin].template data <T>() + offset, this ->paraX_ [ispin], c, this ->paraC_ , this ->nbasis , this ->nocc [ispin], this ->nvirt [ispin], this ->paraMat_ ),
186+ const auto & dm_trans_k = cal_dm_trans_pblas (this ->X [ispin].template data <T>() + offset, this ->paraX_ [ispin], c, this ->paraC_ , this ->nbasis , this ->nocc [ispin], this ->nvirt [ispin], this ->paraMat_ );
187+ const elecstate::DensityMatrix<T, double >& dm_trans_real = // D(X), double (FIXME: not enough for periodic system!)
188+ LR_Util::build_dm_from_dmk<T, double >(dm_trans_k,
189+ this ->paraMat_ , this ->nk , this ->kv .kvec_d , this ->ucell , this ->gd , this ->orb_cutoff_ );
190+ const elecstate::DensityMatrix<T, T>& dm_trans = // D(X) complex
191+ LR_Util::build_dm_from_dmk<T, T>(dm_trans_k,
147192 this ->paraMat_ , this ->nk , this ->kv .kvec_d , this ->ucell , this ->gd , this ->orb_cutoff_ );
148193
149194 elecstate::DensityMatrix<T, T> relaxed_diff_dm = // T+D(Z), (R) can be complex
@@ -199,13 +244,35 @@ std::vector<ModuleBase::matrix> LR::ESolver_LR<T, TR>::cal_force(const int ispin
199244 ModuleBase::matrix force_hxc_dmtrans = lr_force.cal_force_hxc_dmtrans (dm_trans_real, *this ->pot [ispin]);
200245 std::cout << " Force (Hxc-DMTrans term) of state " << istate << " : " << std::endl;
201246 LR_Util::print_value (force_hxc_dmtrans.c , ucell.nat , 3 );
247+
202248 ModuleBase::matrix force_hamiltgs_relaxed_diff = lr_force.cal_force_hamilt_gs_dm_relaxed_diff (relaxed_diff_dm_real, *pot_gs, /* with_ewald=*/ false );
203249 std::cout << " Force (GS-(T+Z) term) of state " << istate << " : " << std::endl;
204250 LR_Util::print_value (force_hamiltgs_relaxed_diff.c , ucell.nat , 3 );
251+
205252 ModuleBase::matrix force_overlap_edm = lr_force.cal_force_overlap_edm (edm_real); // "-" sign has been included in the force factor
206253 std::cout << " Force (Overlap-EDM term) of state " << istate << " : " << std::endl;
207254 LR_Util::print_value (force_overlap_edm.c , ucell.nat , 3 );
208- // remaining for EXX
255+
256+ #ifdef __EXX
257+ const double & alpha = this ->exx_info .info_global .hybrid_alpha ;
258+
259+ const auto & Ds_trans = get_exx_Ds_spin1 (dm_trans, this ->ucell , this ->kv , this ->paraMat_ );
260+ ModuleBase::matrix force_exx_dmtrans = lr_force.cal_force_exx_dm_trans (Ds_trans, alpha);
261+ std::cout << " Force (EXX-DMTrans term) of state " << istate << " : " << std::endl;
262+ LR_Util::print_value (force_exx_dmtrans.c , ucell.nat , 3 );
263+
264+ elecstate::DensityMatrix<T, double > dm_gs (&this ->paraMat_ , this ->nspin , this ->kv .kvec_d , this ->nk ); // DX
265+ elecstate::cal_dm_psi (&this ->paraMat_all_ , this ->wg_ks_all , *this ->psi_ks_all , dm_gs); // nbands is important here
266+ const auto & Ds_gs = get_exx_Ds_gs (dm_gs, this ->ucell , this ->kv , this ->paraMat_ , this ->nspin );
267+ const auto & Ds_relaxed_diff = get_exx_Ds_spin1 (relaxed_diff_dm, this ->ucell , this ->kv , this ->paraMat_ );
268+ ModuleBase::matrix force_exx_gs_diff = lr_force.cal_force_exx_gs_dm_relaxed_diff (Ds_gs, Ds_relaxed_diff, alpha);
269+ std::cout << " Force (EXX-GS-(T+Z) term) of state " << istate << " : " << std::endl;
270+ LR_Util::print_value (force_exx_gs_diff.c , ucell.nat , 3 );
271+
272+ // add exx force to the corresponding terms
273+ force_hxc_dmtrans += force_exx_dmtrans;
274+ force_hamiltgs_relaxed_diff += force_exx_gs_diff;
275+ #endif
209276 forces[istate] = force_hxc_dmtrans + force_hamiltgs_relaxed_diff + force_overlap_edm;
210277 }
211278 ModuleBase::timer::tick (" ESolver_LR" , " cal_force" );
@@ -217,7 +284,11 @@ std::vector<ModuleBase::matrix> LR::ESolver_LR<T, TR>::cal_force(const int ispin
217284template <typename T, typename TR>
218285void LR::ESolver_LR<T, TR>::test_force()
219286{
220- LR_Force<T> lr_force (this ->ucell , this ->kv .kvec_d , this ->paraMat_ , *this ->pw_rho , this ->locpp , this ->sf , this ->gd , this ->gint_ , this ->two_center_bundle_ );
287+ LR_Force<T> lr_force (this ->ucell , this ->kv .kvec_d , this ->paraMat_ , *this ->pw_rho , this ->locpp , this ->sf , this ->gd , this ->gint_ , this ->two_center_bundle_
288+ #ifdef __EXX
289+ , std::weak_ptr<Exx_LRI<T>>(this ->exx_lri ), this ->exx_info .info_global .hybrid_alpha
290+ #endif
291+ );
221292
222293 elecstate::DensityMatrix<T, double > dm_gs (&this ->paraMat_ , this ->nspin , this ->kv .kvec_d , this ->nk ); // DX
223294 elecstate::cal_dm_psi (&this ->paraMat_all_ , this ->wg_ks_all , *this ->psi_ks_all , dm_gs); // nbands is important here
0 commit comments