Skip to content

Commit 14822f8

Browse files
committed
change ucell in module_ri/rpa_lri_tool.cpp
1 parent 579e58a commit 14822f8

File tree

4 files changed

+39
-29
lines changed

4 files changed

+39
-29
lines changed

source/module_ri/Exx_LRI.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -133,15 +133,15 @@ void Exx_LRI<Tdata>::cal_exx_ions(const UnitCell& ucell,
133133
this->exx_lri.set_parallel(this->mpi_comm, atoms_pos, latvec, period);
134134

135135
// std::max(3) for gamma_only, list_A2 should contain cell {-1,0,1}. In the future distribute will be neighbour.
136-
const std::array<Tcell,Ndim> period_Vs = LRI_CV_Tools::cal_latvec_range<Tcell>(1+this->info.ccp_rmesh_times, orb_cutoff_);
136+
const std::array<Tcell,Ndim> period_Vs = LRI_CV_Tools::cal_latvec_range<Tcell>(1+this->info.ccp_rmesh_times, ucell, orb_cutoff_);
137137
const std::pair<std::vector<TA>, std::vector<std::vector<std::pair<TA,std::array<Tcell,Ndim>>>>>
138138
list_As_Vs = RI::Distribute_Equally::distribute_atoms_periods(this->mpi_comm, atoms, period_Vs, 2, false);
139139

140140
std::map<TA,std::map<TAC,RI::Tensor<Tdata>>>
141141
Vs = this->cv.cal_Vs(ucell,
142142
list_As_Vs.first, list_As_Vs.second[0],
143143
{{"writable_Vws",true}});
144-
this->cv.Vws = LRI_CV_Tools::get_CVws(Vs);
144+
this->cv.Vws = LRI_CV_Tools::get_CVws(ucell,Vs);
145145
if (write_cv && GlobalV::MY_RANK == 0) { LRI_CV_Tools::write_Vs_abf(Vs, PARAM.globalv.global_out_dir + "Vs"); }
146146
this->exx_lri.set_Vs(std::move(Vs), this->info.V_threshold);
147147

@@ -151,16 +151,16 @@ void Exx_LRI<Tdata>::cal_exx_ions(const UnitCell& ucell,
151151
dVs = this->cv.cal_dVs(ucell,
152152
list_As_Vs.first, list_As_Vs.second[0],
153153
{{"writable_dVws",true}});
154-
this->cv.dVws = LRI_CV_Tools::get_dCVws(dVs);
154+
this->cv.dVws = LRI_CV_Tools::get_dCVws(ucell,dVs);
155155
this->exx_lri.set_dVs(std::move(dVs), this->info.V_grad_threshold);
156156
if(PARAM.inp.cal_stress)
157157
{
158-
std::array<std::array<std::map<TA,std::map<TAC,RI::Tensor<Tdata>>>,3>,3> dVRs = LRI_CV_Tools::cal_dMRs(dVs);
158+
std::array<std::array<std::map<TA,std::map<TAC,RI::Tensor<Tdata>>>,3>,3> dVRs = LRI_CV_Tools::cal_dMRs(ucell,dVs);
159159
this->exx_lri.set_dVRs(std::move(dVRs), this->info.V_grad_R_threshold);
160160
}
161161
}
162162

163-
const std::array<Tcell,Ndim> period_Cs = LRI_CV_Tools::cal_latvec_range<Tcell>(2, orb_cutoff_);
163+
const std::array<Tcell,Ndim> period_Cs = LRI_CV_Tools::cal_latvec_range<Tcell>(2, ucell,orb_cutoff_);
164164
const std::pair<std::vector<TA>, std::vector<std::vector<std::pair<TA,std::array<Tcell,Ndim>>>>>
165165
list_As_Cs = RI::Distribute_Equally::distribute_atoms_periods(this->mpi_comm, atoms, period_Cs, 2, false);
166166

@@ -171,18 +171,18 @@ void Exx_LRI<Tdata>::cal_exx_ions(const UnitCell& ucell,
171171
{{"cal_dC",PARAM.inp.cal_force||PARAM.inp.cal_stress},
172172
{"writable_Cws",true}, {"writable_dCws",true}, {"writable_Vws",false}, {"writable_dVws",false}});
173173
std::map<TA,std::map<TAC,RI::Tensor<Tdata>>> &Cs = std::get<0>(Cs_dCs);
174-
this->cv.Cws = LRI_CV_Tools::get_CVws(Cs);
174+
this->cv.Cws = LRI_CV_Tools::get_CVws(ucell,Cs);
175175
if (write_cv && GlobalV::MY_RANK == 0) { LRI_CV_Tools::write_Cs_ao(Cs, PARAM.globalv.global_out_dir + "Cs"); }
176176
this->exx_lri.set_Cs(std::move(Cs), this->info.C_threshold);
177177

178178
if(PARAM.inp.cal_force || PARAM.inp.cal_stress)
179179
{
180180
std::array<std::map<TA,std::map<TAC,RI::Tensor<Tdata>>>,3> &dCs = std::get<1>(Cs_dCs);
181-
this->cv.dCws = LRI_CV_Tools::get_dCVws(dCs);
181+
this->cv.dCws = LRI_CV_Tools::get_dCVws(ucell,dCs);
182182
this->exx_lri.set_dCs(std::move(dCs), this->info.C_grad_threshold);
183183
if(PARAM.inp.cal_stress)
184184
{
185-
std::array<std::array<std::map<TA,std::map<TAC,RI::Tensor<Tdata>>>,3>,3> dCRs = LRI_CV_Tools::cal_dMRs(dCs);
185+
std::array<std::array<std::map<TA,std::map<TAC,RI::Tensor<Tdata>>>,3>,3> dCRs = LRI_CV_Tools::cal_dMRs(ucell,dCs);
186186
this->exx_lri.set_dCRs(std::move(dCRs), this->info.C_grad_R_threshold);
187187
}
188188
}

source/module_ri/LRI_CV_Tools.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,20 +80,25 @@ namespace LRI_CV_Tools
8080
std::map<TkeyA,std::map<TkeyB,std::array<Tvalue,N>>> && ds_in);
8181

8282
template<typename Tcell>
83-
extern std::array<Tcell,3> cal_latvec_range(const double &rcut_times, const std::vector<double>& orb_cutoff);
83+
extern std::array<Tcell,3> cal_latvec_range(const double &rcut_times,
84+
const UnitCell &ucell,
85+
const std::vector<double>& orb_cutoff);
8486

8587
template<typename TA, typename Tcell, typename Tdata>
8688
extern std::map<int,std::map<int,std::map<Abfs::Vector3_Order<double>,RI::Tensor<Tdata>>>>
8789
get_CVws(
90+
const UnitCell &ucell,
8891
const std::map<TA,std::map<std::pair<TA,std::array<Tcell,3>>,RI::Tensor<Tdata>>> &CVs);
8992
template<typename TA, typename Tcell, typename Tdata>
9093
extern std::map<int,std::map<int,std::map<Abfs::Vector3_Order<double>,std::array<RI::Tensor<Tdata>,3>>>>
9194
get_dCVws(
95+
const UnitCell &ucell,
9296
const std::array<std::map<TA,std::map<std::pair<TA,std::array<Tcell,3>>,RI::Tensor<Tdata>>>,3> &dCVs);
9397

9498
template<typename TA, typename TC, typename Tdata>
9599
extern std::array<std::array<std::map<TA,std::map<std::pair<TA,TC>,RI::Tensor<Tdata>>>,3>,3>
96100
cal_dMRs(
101+
const UnitCell &ucell,
97102
const std::array<std::map<TA,std::map<std::pair<TA,TC>,RI::Tensor<Tdata>>>,3> &dMs);
98103

99104
using TC = std::array<int, 3>;

source/module_ri/LRI_CV_Tools.hpp

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -245,14 +245,16 @@ LRI_CV_Tools::change_order(std::map<TkeyA,std::map<TkeyB,std::array<Tvalue,N>>>
245245

246246
template<typename Tcell>
247247
std::array<Tcell,3>
248-
LRI_CV_Tools::cal_latvec_range(const double &rcut_times, const std::vector<double>& orb_cutoff)
248+
LRI_CV_Tools::cal_latvec_range(const double &rcut_times,
249+
const UnitCell &ucell,
250+
const std::vector<double>& orb_cutoff)
249251
{
250252
double Rcut_max = 0;
251-
for(int T=0; T<GlobalC::ucell.ntype; ++T)
253+
for(int T=0; T<ucell.ntype; ++T)
252254
Rcut_max = std::max(Rcut_max, orb_cutoff[T]);
253255
const ModuleBase::Vector3<double> proj = ModuleBase::Mathzone::latvec_projection(
254-
std::array<ModuleBase::Vector3<double>,3>{GlobalC::ucell.a1, GlobalC::ucell.a2, GlobalC::ucell.a3});
255-
const ModuleBase::Vector3<double> latvec_times = Rcut_max * rcut_times / (proj * GlobalC::ucell.lat0);
256+
std::array<ModuleBase::Vector3<double>,3>{ucell.a1, ucell.a2, ucell.a3});
257+
const ModuleBase::Vector3<double> latvec_times = Rcut_max * rcut_times / (proj * ucell.lat0);
256258
const ModuleBase::Vector3<Tcell> latvec_times_ceil = {static_cast<Tcell>(std::ceil(latvec_times.x)),
257259
static_cast<Tcell>(std::ceil(latvec_times.y)),
258260
static_cast<Tcell>(std::ceil(latvec_times.z))};
@@ -263,23 +265,24 @@ LRI_CV_Tools::cal_latvec_range(const double &rcut_times, const std::vector<doubl
263265
template<typename TA, typename Tcell, typename Tdata>
264266
std::map<int,std::map<int,std::map<Abfs::Vector3_Order<double>,RI::Tensor<Tdata>>>>
265267
LRI_CV_Tools::get_CVws(
268+
const UnitCell &ucell,
266269
const std::map<TA,std::map<std::pair<TA,std::array<Tcell,3>>,RI::Tensor<Tdata>>> &CVs)
267270
{
268271
std::map<int,std::map<int,std::map<Abfs::Vector3_Order<double>,RI::Tensor<Tdata>>>> CVws;
269272
for(const auto &CVs_A : CVs)
270273
{
271274
const TA iat0 = CVs_A.first;
272-
const int it0 = GlobalC::ucell.iat2it[iat0];
273-
const int ia0 = GlobalC::ucell.iat2ia[iat0];
274-
const ModuleBase::Vector3<double> tau0 = GlobalC::ucell.atoms[it0].tau[ia0];
275+
const int it0 = ucell.iat2it[iat0];
276+
const int ia0 = ucell.iat2ia[iat0];
277+
const ModuleBase::Vector3<double> tau0 = ucell.atoms[it0].tau[ia0];
275278
for(const auto &CVs_B : CVs_A.second)
276279
{
277280
const TA iat1 = CVs_B.first.first;
278-
const int it1 = GlobalC::ucell.iat2it[iat1];
279-
const int ia1 = GlobalC::ucell.iat2ia[iat1];
281+
const int it1 = ucell.iat2it[iat1];
282+
const int ia1 = ucell.iat2ia[iat1];
280283
const std::array<int,3> &cell1 = CVs_B.first.second;
281-
const ModuleBase::Vector3<double> tau1 = GlobalC::ucell.atoms[it1].tau[ia1];
282-
const Abfs::Vector3_Order<double> R_delta = -tau0+tau1+(RI_Util::array3_to_Vector3(cell1)*GlobalC::ucell.latvec);
284+
const ModuleBase::Vector3<double> tau1 = ucell.atoms[it1].tau[ia1];
285+
const Abfs::Vector3_Order<double> R_delta = -tau0+tau1+(RI_Util::array3_to_Vector3(cell1)*ucell.latvec);
283286
CVws[it0][it1][R_delta] = CVs_B.second;
284287
}
285288
}
@@ -289,6 +292,7 @@ LRI_CV_Tools::get_CVws(
289292
template<typename TA, typename Tcell, typename Tdata>
290293
std::map<int,std::map<int,std::map<Abfs::Vector3_Order<double>,std::array<RI::Tensor<Tdata>,3>>>>
291294
LRI_CV_Tools::get_dCVws(
295+
const UnitCell &ucell,
292296
const std::array<std::map<TA,std::map<std::pair<TA,std::array<Tcell,3>>,RI::Tensor<Tdata>>>,3> &dCVs)
293297
{
294298
std::map<int,std::map<int,std::map<Abfs::Vector3_Order<double>,std::array<RI::Tensor<Tdata>,3>>>> dCVws;
@@ -297,17 +301,17 @@ LRI_CV_Tools::get_dCVws(
297301
for(const auto &dCVs_A : dCVs[ix])
298302
{
299303
const TA iat0 = dCVs_A.first;
300-
const int it0 = GlobalC::ucell.iat2it[iat0];
301-
const int ia0 = GlobalC::ucell.iat2ia[iat0];
302-
const ModuleBase::Vector3<double> tau0 = GlobalC::ucell.atoms[it0].tau[ia0];
304+
const int it0 = ucell.iat2it[iat0];
305+
const int ia0 = ucell.iat2ia[iat0];
306+
const ModuleBase::Vector3<double> tau0 = ucell.atoms[it0].tau[ia0];
303307
for(const auto &dCVs_B : dCVs_A.second)
304308
{
305309
const TA iat1 = dCVs_B.first.first;
306-
const int it1 = GlobalC::ucell.iat2it[iat1];
307-
const int ia1 = GlobalC::ucell.iat2ia[iat1];
310+
const int it1 = ucell.iat2it[iat1];
311+
const int ia1 = ucell.iat2ia[iat1];
308312
const std::array<int,3> &cell1 = dCVs_B.first.second;
309-
const ModuleBase::Vector3<double> tau1 = GlobalC::ucell.atoms[it1].tau[ia1];
310-
const Abfs::Vector3_Order<double> R_delta = -tau0+tau1+(RI_Util::array3_to_Vector3(cell1)*GlobalC::ucell.latvec);
313+
const ModuleBase::Vector3<double> tau1 = ucell.atoms[it1].tau[ia1];
314+
const Abfs::Vector3_Order<double> R_delta = -tau0+tau1+(RI_Util::array3_to_Vector3(cell1)*ucell.latvec);
311315
dCVws[it0][it1][R_delta][ix] = dCVs_B.second;
312316
}
313317
}
@@ -320,6 +324,7 @@ LRI_CV_Tools::get_dCVws(
320324
template<typename TA, typename TC, typename Tdata>
321325
std::array<std::array<std::map<TA,std::map<std::pair<TA,TC>,RI::Tensor<Tdata>>>,3>,3>
322326
LRI_CV_Tools::cal_dMRs(
327+
const UnitCell &ucell,
323328
const std::array<std::map<TA,std::map<std::pair<TA,TC>,RI::Tensor<Tdata>>>,3> &dMs)
324329
{
325330
auto get_R_delta = [&](const TA &iat0, const std::pair<TA,TC> &A1) -> std::array<Tdata,3>

source/module_ri/RPA_LRI.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ void RPA_LRI<T, Tdata>::cal_rpa_cv(const UnitCell& ucell)
4040
}
4141
const std::array<Tcell, Ndim> period = {p_kv->nmp[0], p_kv->nmp[1], p_kv->nmp[2]};
4242

43-
const std::array<Tcell, Ndim> period_Vs = LRI_CV_Tools::cal_latvec_range<Tcell>(1 + this->info.ccp_rmesh_times, orb_cutoff_);
43+
const std::array<Tcell, Ndim> period_Vs = LRI_CV_Tools::cal_latvec_range<Tcell>(1 + this->info.ccp_rmesh_times, ucell,orb_cutoff_);
4444
const std::pair<std::vector<TA>, std::vector<std::vector<std::pair<TA, std::array<Tcell, Ndim>>>>> list_As_Vs
4545
= RI::Distribute_Equally::distribute_atoms(this->mpi_comm, atoms, period_Vs, 2, false);
4646

@@ -52,7 +52,7 @@ void RPA_LRI<T, Tdata>::cal_rpa_cv(const UnitCell& ucell)
5252
});
5353
this->Vs_period = RI::RI_Tools::cal_period(Vs, period);
5454

55-
const std::array<Tcell, Ndim> period_Cs = LRI_CV_Tools::cal_latvec_range<Tcell>(2, orb_cutoff_);
55+
const std::array<Tcell, Ndim> period_Cs = LRI_CV_Tools::cal_latvec_range<Tcell>(2, ucell,orb_cutoff_);
5656
const std::pair<std::vector<TA>, std::vector<std::vector<std::pair<TA, std::array<Tcell, Ndim>>>>> list_As_Cs
5757
= RI::Distribute_Equally::distribute_atoms_periods(this->mpi_comm, atoms, period_Cs, 2, false);
5858

0 commit comments

Comments
 (0)