Skip to content

Commit 0be3cda

Browse files
committed
LR-Grad exx force
1 parent 7fe3a1d commit 0be3cda

File tree

5 files changed

+164
-22
lines changed

5 files changed

+164
-22
lines changed

source/module_lr/Grad/esolver_lr_grad.cpp

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
#include "module_lr/Grad/force/lr_force.h"
55
#include "module_elecstate/module_dm/cal_dm_psi.h"
66
#include "module_io/output_log.h"
7+
#ifdef __EXX
8+
#include <RI/ri/LRI_Cal_Aux.h>
9+
#endif
710

811
template <typename Tstream>
912
inline void print_force(const std::vector<ModuleBase::matrix>& force, Tstream& ofs)
@@ -59,6 +62,41 @@ inline void test_edm_H2(const T* const edm, const double* const eig_ks, const ps
5962
}
6063
}
6164

65+
#ifdef __EXX
66+
// convert DensityMatrix to maps of RI::Tensors
67+
template <typename T>
68+
inline auto get_exx_Ds_spin1(const elecstate::DensityMatrix<T, T>& dm,
69+
const UnitCell& ucell, const K_Vectors& kv, const Parallel_Orbitals& pmat)
70+
-> std::map<int, std::map<std::pair<int, std::array<int, 3>>, RI::Tensor<T>>>
71+
{
72+
const int& nk = dm.get_DMK_nks(); // nks/nspin
73+
std::vector<const std::vector<T>*> DMk_trans_pointer(nk);
74+
for (int ik = 0;ik < nk;++ik) { DMk_trans_pointer[ik] = &dm.get_DMK_vector()[ik]; }
75+
return RI_2D_Comm::split_m2D_ktoR<T>(ucell, kv, DMk_trans_pointer, pmat, /*nspin=*/1)[0];
76+
}
77+
template <typename T>
78+
inline auto get_exx_Ds_gs(elecstate::DensityMatrix<T, double>& dm,
79+
const UnitCell& ucell, const K_Vectors& kv, const Parallel_Orbitals& pmat, const int nspin)
80+
-> std::map<int, std::map<std::pair<int, std::array<int, 3>>, RI::Tensor<T>>>
81+
{
82+
const int& nk = dm.get_DMK_nks() / nspin; // nks/nspin
83+
std::vector<const std::vector<T>*> DMk_trans_pointer(nk);
84+
85+
std::map<int, std::map<std::pair<int, std::array<int, 3>>, RI::Tensor<T>>> Ds_allspin;
86+
for (int is = 0;is < nspin;++is)
87+
{
88+
for (int ik = 0;ik < nk;++ik)
89+
{
90+
DMk_trans_pointer[ik] = &dm.get_DMK_vector()[ik + is * nk];
91+
}
92+
std::map<int, std::map<std::pair<int, std::array<int, 3>>, RI::Tensor<T>>> Ds_tmp
93+
= RI_2D_Comm::split_m2D_ktoR<T>(ucell, kv, DMk_trans_pointer, pmat, /*nspin=*/1)[0];
94+
RI::LRI_Cal_Aux::add_Ds<int, std::map<std::pair<int, std::array<int, 3>>, RI::Tensor<T>>>(std::move(Ds_tmp), Ds_allspin);
95+
}
96+
return Ds_allspin;
97+
}
98+
#endif
99+
62100
template<typename T, typename TR>
63101
void LR::ESolver_LR<T, TR>::init_pot_groundstate(const Charge& chg_gs)
64102
{
@@ -129,7 +167,11 @@ std::vector<ModuleBase::matrix> LR::ESolver_LR<T, TR>::cal_force(const int ispin
129167
const auto& c = LR_Util::get_psi_spin(*this->psi_ks, ispin, this->nk); // wavefunction coefficients of ground state
130168

131169
// calculate the force (the partial gradient of Lagrangian)
132-
LR_Force<T> lr_force(this->ucell, this->kv.kvec_d, this->paraMat_, *this->pw_rho, this->locpp, this->sf, this->gd, this->gint_, this->two_center_bundle_);
170+
LR_Force<T> lr_force(this->ucell, this->kv.kvec_d, this->paraMat_, *this->pw_rho, this->locpp, this->sf, this->gd, this->gint_, this->two_center_bundle_
171+
#ifdef __EXX
172+
, std::weak_ptr<Exx_LRI<T>>(this->exx_lri), this->exx_info.info_global.hybrid_alpha
173+
#endif
174+
);
133175
GlobalV::ofs_running << "Start to calculate excited-state force of " << this->spin_types[ispin] << std::endl;
134176
// ground state dm for currrent spin (only for test the correctness of the force)
135177
// elecstate::DensityMatrix<T, T> dm_gs(this->paraMat_, 1, this->kv.kvec_d, this->nk);
@@ -141,9 +183,12 @@ std::vector<ModuleBase::matrix> LR::ESolver_LR<T, TR>::cal_force(const int ispin
141183
const int offset = istate * this->nloc_per_band;
142184
// The imag part will be cancelled in the force calculation, so we use double DM(R) to calculate force.
143185
// But complex transition DM(R) is still used in energy density matrix calculation.
144-
elecstate::DensityMatrix<T, double> dm_trans_real = // D(X)
145-
LR_Util::build_dm_from_dmk<T, double>(
146-
cal_dm_trans_pblas(this->X[ispin].template data<T>() + offset, this->paraX_[ispin], c, this->paraC_, this->nbasis, this->nocc[ispin], this->nvirt[ispin], this->paraMat_),
186+
const auto& dm_trans_k = cal_dm_trans_pblas(this->X[ispin].template data<T>() + offset, this->paraX_[ispin], c, this->paraC_, this->nbasis, this->nocc[ispin], this->nvirt[ispin], this->paraMat_);
187+
const elecstate::DensityMatrix<T, double>& dm_trans_real = // D(X), double (FIXME: not enough for periodic system!)
188+
LR_Util::build_dm_from_dmk<T, double>(dm_trans_k,
189+
this->paraMat_, this->nk, this->kv.kvec_d, this->ucell, this->gd, this->orb_cutoff_);
190+
const elecstate::DensityMatrix<T, T>& dm_trans = // D(X) complex
191+
LR_Util::build_dm_from_dmk<T, T>(dm_trans_k,
147192
this->paraMat_, this->nk, this->kv.kvec_d, this->ucell, this->gd, this->orb_cutoff_);
148193

149194
elecstate::DensityMatrix<T, T> relaxed_diff_dm = // T+D(Z), (R) can be complex
@@ -199,13 +244,35 @@ std::vector<ModuleBase::matrix> LR::ESolver_LR<T, TR>::cal_force(const int ispin
199244
ModuleBase::matrix force_hxc_dmtrans = lr_force.cal_force_hxc_dmtrans(dm_trans_real, *this->pot[ispin]);
200245
std::cout << "Force (Hxc-DMTrans term) of state " << istate << ": " << std::endl;
201246
LR_Util::print_value(force_hxc_dmtrans.c, ucell.nat, 3);
247+
202248
ModuleBase::matrix force_hamiltgs_relaxed_diff = lr_force.cal_force_hamilt_gs_dm_relaxed_diff(relaxed_diff_dm_real, *pot_gs, /*with_ewald=*/false);
203249
std::cout << "Force (GS-(T+Z) term) of state " << istate << ": " << std::endl;
204250
LR_Util::print_value(force_hamiltgs_relaxed_diff.c, ucell.nat, 3);
251+
205252
ModuleBase::matrix force_overlap_edm = lr_force.cal_force_overlap_edm(edm_real); // "-" sign has been included in the force factor
206253
std::cout << "Force (Overlap-EDM term) of state " << istate << ": " << std::endl;
207254
LR_Util::print_value(force_overlap_edm.c, ucell.nat, 3);
208-
// remaining for EXX
255+
256+
#ifdef __EXX
257+
const double& alpha = this->exx_info.info_global.hybrid_alpha;
258+
259+
const auto& Ds_trans = get_exx_Ds_spin1(dm_trans, this->ucell, this->kv, this->paraMat_);
260+
ModuleBase::matrix force_exx_dmtrans = lr_force.cal_force_exx_dm_trans(Ds_trans, alpha);
261+
std::cout << "Force (EXX-DMTrans term) of state " << istate << ": " << std::endl;
262+
LR_Util::print_value(force_exx_dmtrans.c, ucell.nat, 3);
263+
264+
elecstate::DensityMatrix<T, double> dm_gs(&this->paraMat_, this->nspin, this->kv.kvec_d, this->nk); //DX
265+
elecstate::cal_dm_psi(&this->paraMat_all_, this->wg_ks_all, *this->psi_ks_all, dm_gs); // nbands is important here
266+
const auto& Ds_gs = get_exx_Ds_gs(dm_gs, this->ucell, this->kv, this->paraMat_, this->nspin);
267+
const auto& Ds_relaxed_diff = get_exx_Ds_spin1(relaxed_diff_dm, this->ucell, this->kv, this->paraMat_);
268+
ModuleBase::matrix force_exx_gs_diff = lr_force.cal_force_exx_gs_dm_relaxed_diff(Ds_gs, Ds_relaxed_diff, alpha);
269+
std::cout << "Force (EXX-GS-(T+Z) term) of state " << istate << ": " << std::endl;
270+
LR_Util::print_value(force_exx_gs_diff.c, ucell.nat, 3);
271+
272+
// add exx force to the corresponding terms
273+
force_hxc_dmtrans += force_exx_dmtrans;
274+
force_hamiltgs_relaxed_diff += force_exx_gs_diff;
275+
#endif
209276
forces[istate] = force_hxc_dmtrans + force_hamiltgs_relaxed_diff + force_overlap_edm;
210277
}
211278
ModuleBase::timer::tick("ESolver_LR", "cal_force");
@@ -217,7 +284,11 @@ std::vector<ModuleBase::matrix> LR::ESolver_LR<T, TR>::cal_force(const int ispin
217284
template<typename T, typename TR>
218285
void LR::ESolver_LR<T, TR>::test_force()
219286
{
220-
LR_Force<T> lr_force(this->ucell, this->kv.kvec_d, this->paraMat_, *this->pw_rho, this->locpp, this->sf, this->gd, this->gint_, this->two_center_bundle_);
287+
LR_Force<T> lr_force(this->ucell, this->kv.kvec_d, this->paraMat_, *this->pw_rho, this->locpp, this->sf, this->gd, this->gint_, this->two_center_bundle_
288+
#ifdef __EXX
289+
, std::weak_ptr<Exx_LRI<T>>(this->exx_lri), this->exx_info.info_global.hybrid_alpha
290+
#endif
291+
);
221292

222293
elecstate::DensityMatrix<T, double> dm_gs(&this->paraMat_, this->nspin, this->kv.kvec_d, this->nk); //DX
223294
elecstate::cal_dm_psi(&this->paraMat_all_, this->wg_ks_all, *this->psi_ks_all, dm_gs); // nbands is important here

source/module_lr/Grad/force/lr_force.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,49 @@ namespace LR
112112
this->gint_->reset_DMRGint(1);
113113
return fvl_dphi;
114114
}
115+
116+
#ifdef __EXX
117+
template<typename TK>
118+
ModuleBase::matrix LR_Force<TK>::cal_force_exx_dm_trans(
119+
const std::map<int, std::map<TAC, RI::Tensor<TK>>>& dm_trans,
120+
const double& alpha)
121+
{
122+
ModuleBase::matrix f_exx_dmtrans(this->ucell_.nat, 3);
123+
auto& exx_lri_kernel = this->exx_lri_.lock()->get();
124+
exx_lri_kernel.set_Ds(dm_trans, this->exx_lri_.lock()->get_info().dm_threshold, "0");
125+
exx_lri_kernel.cal_Hs();
126+
exx_lri_kernel.cal_force();// using dm_trans
127+
for (std::size_t idim = 0; idim < 3; ++idim)
128+
for (const auto& force_item : exx_lri_kernel.force[idim])
129+
f_exx_dmtrans(force_item.first, idim) = std::real(force_item.second);
130+
return f_exx_dmtrans * alpha;
131+
}
132+
133+
template<typename TK>
134+
ModuleBase::matrix LR_Force<TK>::cal_force_exx_gs_dm_relaxed_diff(
135+
const std::map<int, std::map<TAC, RI::Tensor<TK>>>& dm_gs,
136+
const std::map<int, std::map<TAC, RI::Tensor<TK>>>& relaxed_diff_dm,
137+
const double& alpha)
138+
{
139+
ModuleBase::matrix f_exx_gs_diff(this->ucell_.nat, 3);
140+
auto& exx_lri_kernel = this->exx_lri_.lock()->get();
141+
exx_lri_kernel.set_Ds(dm_gs, this->exx_lri_.lock()->get_info().dm_threshold, "0");
142+
exx_lri_kernel.cal_Hs(); // using dm_gs
143+
// auto* lr_ptr = dynamic_cast<RI::LR<int, int, 3, TK>*>(&exx_lri_kernel); // wrong:
144+
// assert(lr_ptr != nullptr);
145+
RI::LR<int, int, 3, TK> lr_exx_kernel(std::move(exx_lri_kernel));
146+
std::cout << "post_2D adress moved: " << &(lr_exx_kernel.post_2D) << std::endl;
147+
std::cout << "begin RI::LR::cal_force" << std::endl;
148+
lr_exx_kernel.cal_force(relaxed_diff_dm); // using relaxed_diff_dm
149+
exx_lri_kernel = std::move(lr_exx_kernel); // move back
150+
std::cout << "post_2D adress after move back: " << &(exx_lri_kernel.post_2D) << std::endl;
151+
std::cout << "end RI::LR::cal_force" << std::endl;
152+
for (std::size_t idim = 0; idim < 3; ++idim)
153+
for (const auto& force_item : exx_lri_kernel.force[idim])
154+
f_exx_gs_diff(force_item.first, idim) = std::real(force_item.second);
155+
return f_exx_gs_diff * alpha;
156+
}
157+
#endif
115158
}
116159

117160
template class LR::LR_Force<double>;

source/module_lr/Grad/force/lr_force.h

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
#include "module_lr/potentials/pot_hxc_lrtd.h"
33
#include "module_lr/utils/gint_template.h"
44
// free functions, usefull for both ground and excited state
5-
5+
#ifdef __EXX
6+
#include "module_ri/Exx_LRI.h"
7+
#include <RI/physics/LR.h>
8+
using TAC = std::pair<int, std::array<int, 3>>;
9+
#endif
610
namespace LR
711
{
812

@@ -18,10 +22,19 @@ namespace LR
1822
const Structure_Factor& sf,
1923
const Grid_Driver& gd,
2024
typename TGint<TK>::type* gint, ///< for grid integrals
21-
const TwoCenterBundle& two_center_bundle)///< for 2-center integrals
25+
const TwoCenterBundle& two_center_bundle
26+
#ifdef __EXX
27+
, std::weak_ptr<Exx_LRI<TK>> exx_lri_in,
28+
const double& alpha
29+
#endif
30+
)///< for 2-center integrals
2231
: ucell_(ucell), kvec_d_(kvec_d), pv_(pv),
2332
rhopw_(rhopw), sf_(sf), locpp_(locpp), gd_(gd),
24-
gint_(gint), two_center_bundle_(two_center_bundle) {
33+
gint_(gint), two_center_bundle_(two_center_bundle)
34+
#ifdef __EXX
35+
, exx_lri_(exx_lri_in), alpha_(alpha)
36+
#endif
37+
{
2538
}
2639

2740
/// 1. $Tr[H_{GS}^x * (T+D^Z)]$, where GS=groud state and $(T+D^Z)$ is the relaxed difference density matrix
@@ -35,11 +48,18 @@ namespace LR
3548
ModuleBase::matrix cal_force_hxc_dmtrans(const elecstate::DensityMatrix<TK, double>& dm_trans, const PotHxcLR& pot_hxc);
3649

3750
#ifdef __EXX
51+
// auto* lrexx_ptr = dynamic_cast<RI::LR<int, std::array<int, 3>, 3, TK>*>(&exx_lri_in.get());
3852
/// 4. $\alpha \sum_{mnkl}(mk|nl)^x *D^X *D^X$
39-
// ModuleBase::matrix cal_force_exx_dm_diff(const elecstate::DensityMatrix<TK, double>& dm_trans);
53+
ModuleBase::matrix cal_force_exx_dm_trans(
54+
const std::map<int, std::map<TAC, RI::Tensor<TK>>>& dm_trans,
55+
const double& alpha);
56+
ModuleBase::matrix cal_force_exx_gs_dm_relaxed_diff(
57+
const std::map<int, std::map<TAC, RI::Tensor<TK>>>& dm_gs,
58+
const std::map<int, std::map<TAC, RI::Tensor<TK>>>& relaxed_diff_dm,
59+
const double& alpha);
4060
#endif
4161

42-
/// test functions
62+
// test functions
4363
/// reproduce the force of the ground state
4464
ModuleBase::matrix reproduce_force_gs(const elecstate::DensityMatrix<TK, double>& dm_gs,
4565
const elecstate::DensityMatrix<TK, double>& edm_gs,
@@ -58,6 +78,10 @@ namespace LR
5878
const Grid_Driver& gd_;
5979
typename TGint<TK>::type* gint_;
6080
const TwoCenterBundle& two_center_bundle_;
81+
#ifdef __EXX
82+
std::weak_ptr<Exx_LRI<TK>> exx_lri_;
83+
const double alpha_;
84+
#endif
6185

6286
Charge dm_to_charge(const elecstate::DensityMatrix<TK, double>& dm);
6387

source/module_lr/dm_band/dm_band.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,18 @@ namespace LR
2727
{
2828
const int iat1 = ucell_.itia2iat(it1, ia1);
2929
const int iat2 = ucell_.itia2iat(it2, ia2);
30-
auto& D2d = dm_band[iat1][std::make_pair(iat2, cell)];
31-
const int nw1 = ucell_.atoms[it1].nw;
32-
const int nw2 = ucell_.atoms[it2].nw;
30+
const std::size_t nw1 = ucell_.atoms[it1].nw;
31+
const std::size_t nw2 = ucell_.atoms[it2].nw;
32+
RI::Tensor<double> dm_tmp({ nw1, nw2 });
3333
for (int iw1 = 0;iw1 < nw1;++iw1)
3434
for (int iw2 = 0;iw2 < nw2;++iw2)
3535
{
3636
const int iwt1 = use_nws1 ? nws1[it1] : ucell_.itiaiw2iwt(it1, ia1, iw1);
3737
const int iwt2 = use_nws2 ? nws2[it2] : ucell_.itiaiw2iwt(it2, ia2, iw2);
3838
if (pmat_.in_this_processor(iwt1, iwt2))
39-
D2d(iw1, iw2) = fac * c1_(ik, iband1, iwt1) * c2_(ik, iband2, iwt2);
39+
dm_tmp(iw1, iw2) = fac * c1_(ik, iband1, iwt1) * c2_(ik, iband2, iwt2);
4040
}
41+
dm_band[iat1][std::make_pair(iat2, cell)] = dm_tmp;
4142
}
4243
}
4344
}
@@ -56,25 +57,26 @@ namespace LR
5657
for (auto cell : bvk_cells_)
5758
{
5859
std::complex<double> fac_phase = RI::Global_Func::convert<std::complex<double>>(std::exp(
59-
-ModuleBase::TWO_PI * ModuleBase::IMAG_UNIT * (kvec_c_.at(ik) * (RI_Util::array3_to_Vector3(cell) * ucell_.latvec))));
60+
-ModuleBase::TWO_PI * ModuleBase::IMAG_UNIT * (kvec_c_.at(ik) * (RI_Util::array3_to_Vector3(cell) * ucell_.latvec)))) * fac;
6061
for (int it1 = 0;it1 < ucell_.ntype;++it1)
6162
for (int ia1 = 0; ia1 < ucell_.atoms[it1].na; ++ia1)
6263
for (int it2 = 0;it2 < ucell_.ntype;++it2)
6364
for (int ia2 = 0;ia2 < ucell_.atoms[it2].na;++ia2)
6465
{
65-
int iat1 = ucell_.itia2iat(it1, ia1);
66-
int iat2 = ucell_.itia2iat(it2, ia2);
67-
auto& D2d = dm_band[iat1][std::make_pair(iat2, cell)];
68-
const int nw1 = ucell_.atoms[it1].nw;
69-
const int nw2 = ucell_.atoms[it2].nw;
66+
const int iat1 = ucell_.itia2iat(it1, ia1);
67+
const int iat2 = ucell_.itia2iat(it2, ia2);
68+
const std::size_t nw1 = ucell_.atoms[it1].nw;
69+
const std::size_t nw2 = ucell_.atoms[it2].nw;
70+
RI::Tensor<std::complex<double>> dm_tmp({ nw1, nw2 });
7071
for (int iw1 = 0;iw1 < nw1;++iw1)
7172
for (int iw2 = 0;iw2 < nw2;++iw2)
7273
{
7374
const int iwt1 = use_nws1 ? nws1[it1] : ucell_.itiaiw2iwt(it1, ia1, iw1);
7475
const int iwt2 = use_nws2 ? nws2[it2] : ucell_.itiaiw2iwt(it2, ia2, iw2);
7576
if (pmat_.in_this_processor(iwt1, iwt2))
76-
D2d(iw1, iw2) = fac * c1_(ik, iband1, iwt1) * std::conj(c2_(ik, iband2, iwt2));
77+
dm_tmp(iw1, iw2) = fac_phase * c1_(ik, iband1, iwt1) * std::conj(c2_(ik, iband2, iwt2));
7778
}
79+
dm_band[iat1][std::make_pair(iat2, cell)] = dm_tmp;
7880
}
7981
}
8082
}

source/module_ri/Exx_LRI.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ class Exx_LRI
5353
Exx_LRI operator=(const Exx_LRI&) = delete;
5454
Exx_LRI operator=(Exx_LRI&&);
5555

56+
RI::Exx<TA, Tcell, Ndim, Tdata>& get() { return this->exx_lri; }
57+
auto& get_info() const { return this->info; }
5658
void reset_Cs(const std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>& Cs_in) { this->exx_lri.set_Cs(Cs_in, this->info.C_threshold); }
5759
void reset_Vs(const std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>& Vs_in) { this->exx_lri.set_Vs(Vs_in, this->info.V_threshold); }
5860

0 commit comments

Comments
 (0)