Skip to content

Commit a2bf20d

Browse files
committed
update runtime check in Exx_LRI_Interface
1 parent 8054108 commit a2bf20d

File tree

3 files changed

+97
-74
lines changed

3 files changed

+97
-74
lines changed

source/module_ri/Exx_LRI.h

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,20 +53,22 @@ class Exx_LRI
5353
Exx_LRI operator=(const Exx_LRI&) = delete;
5454
Exx_LRI operator=(Exx_LRI&&);
5555

56-
void reset_Cs(const std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>& Cs_in) { this->exx_lri.set_Cs(Cs_in, this->info.C_threshold); }
57-
void reset_Vs(const std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>& Vs_in) { this->exx_lri.set_Vs(Vs_in, this->info.V_threshold); }
58-
59-
void init(const MPI_Comm &mpi_comm_in,
60-
const UnitCell &ucell,
61-
const K_Vectors &kv_in,
62-
const LCAO_Orbitals& orb);
63-
void cal_exx_force(const int& nat);
64-
void cal_exx_stress(const double& omega, const double& lat0);
56+
void init(
57+
const MPI_Comm &mpi_comm_in,
58+
const UnitCell &ucell,
59+
const K_Vectors &kv_in,
60+
const LCAO_Orbitals& orb);
6561
void cal_exx_ions(const UnitCell& ucell, const bool write_cv = false);
66-
void cal_exx_elec(const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& Ds,
62+
void cal_exx_elec(
63+
const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& Ds,
6764
const UnitCell& ucell,
6865
const Parallel_Orbitals& pv,
6966
const ModuleSymmetry::Symmetry_rotation* p_symrot = nullptr);
67+
void cal_exx_force(const int& nat);
68+
void cal_exx_stress(const double& omega, const double& lat0);
69+
70+
void reset_Cs(const std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>& Cs_in) { this->exx_lri.set_Cs(Cs_in, this->info.C_threshold); }
71+
void reset_Vs(const std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>& Vs_in) { this->exx_lri.set_Vs(Vs_in, this->info.V_threshold); }
7072
std::vector<std::vector<int>> get_abfs_nchis() const;
7173

7274
std::vector< std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>> Hexxs;

source/module_ri/Exx_LRI_interface.h

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ template<typename T, typename Tdata>
3131
class Exx_LRI_Interface
3232
{
3333
public:
34+
using TA = int;
3435
using TC = std::array<int, 3>;
35-
using TAC = std::pair<int, TC>;
36+
using TAC = std::pair<TA, TC>;
3637

3738
/// @brief Constructor for Exx_LRI_Interface
3839
/// @param exx_ptr
@@ -43,18 +44,33 @@ class Exx_LRI_Interface
4344
void write_Hexxs_cereal(const std::string& file_name) const;
4445
void read_Hexxs_cereal(const std::string& file_name);
4546

46-
std::vector<std::map<int, std::map<TAC, RI::Tensor<Tdata>>>>& get_Hexxs() const { return this->exx_ptr->Hexxs; }
47+
std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& get_Hexxs() const { return this->exx_ptr->Hexxs; }
4748
double &get_Eexx() const { return this->exx_ptr->Eexx; }
4849
ModuleBase::matrix &get_force() const { return this->exx_ptr->force_exx; }
4950
ModuleBase::matrix &get_stress() const { return this->exx_ptr->stress_exx; }
5051

5152
// Processes in ESolver_KS_LCAO
52-
/// @brief Exx_LRI::init()
53+
/// @brief in init: Exx_LRI::init()
5354
void init(const MPI_Comm &mpi_comm,
5455
const UnitCell &ucell,
5556
const K_Vectors &kv,
5657
const LCAO_Orbitals& orb);
5758

59+
/// @brief: in cal_exx_ions: Exx_LRI::cal_exx_ions()
60+
void cal_exx_ions(const UnitCell& ucell, const bool write_cv = false);
61+
62+
/// @brief: in cal_exx_elec: Exx_LRI::cal_exx_elec()
63+
void cal_exx_elec(const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& Ds,
64+
const UnitCell& ucell,
65+
const Parallel_Orbitals& pv,
66+
const ModuleSymmetry::Symmetry_rotation* p_symrot = nullptr);
67+
68+
/// @brief: in cal_exx_force: Exx_LRI::cal_exx_force()
69+
void cal_exx_force(const int& nat);
70+
71+
/// @brief: in cal_exx_stress: Exx_LRI::cal_exx_stress()
72+
void cal_exx_stress(const double& omega, const double& lat0);
73+
5874
// Processes in ESolver_KS_LCAO
5975
/// @brief in before_all_runners: set symmetry according to irreducible k-points
6076
/// since k-points are not reduced again after the variation of the cell and exx-symmetry must be consistent with k-points.
@@ -95,12 +111,6 @@ class Exx_LRI_Interface
95111
const double& etot,
96112
const double& scf_ene_thr);
97113

98-
/// @brief: in cal_exx_force: Exx_LRI::cal_exx_force()
99-
void cal_exx_force(const int& nat);
100-
101-
/// @brief: in cal_exx_stress: Exx_LRI::cal_exx_stress()
102-
void cal_exx_stress(const double& omega, const double& lat0);
103-
104114
int two_level_step = 0;
105115
double etot_last_outer_loop = 0.0;
106116
elecstate::DensityMatrix<T, double>* dm_last_step;

source/module_ri/Exx_LRI_interface.hpp

Lines changed: 66 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,61 @@ void Exx_LRI_Interface<T, Tdata>::init(const MPI_Comm &mpi_comm,
5151
this->flag_finish.init = true;
5252
}
5353

54+
template<typename T, typename Tdata>
55+
void Exx_LRI_Interface<T, Tdata>::cal_exx_ions(const UnitCell& ucell, const bool write_cv)
56+
{
57+
ModuleBase::TITLE("Exx_LRI_Interface","cal_exx_ions");
58+
if(!this->flag_finish.init)
59+
{ throw std::runtime_error("Exx init unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); }
60+
61+
this->exx_ptr->cal_exx_ions(ucell, write_cv);
62+
63+
this->flag_finish.ions = true;
64+
}
65+
66+
template<typename T, typename Tdata>
67+
void Exx_LRI_Interface<T, Tdata>::cal_exx_elec(const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& Ds,
68+
const UnitCell& ucell,
69+
const Parallel_Orbitals& pv,
70+
const ModuleSymmetry::Symmetry_rotation* p_symrot)
71+
{
72+
ModuleBase::TITLE("Exx_LRI_Interface","cal_exx_elec");
73+
if(!this->flag_finish.init || !this->flag_finish.ions)
74+
{ throw std::runtime_error("Exx init unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); }
75+
76+
this->exx_ptr->cal_exx_elec(Ds, ucell, pv, p_symrot);
77+
78+
this->flag_finish.elec = true;
79+
}
80+
81+
template<typename T, typename Tdata>
82+
void Exx_LRI_Interface<T, Tdata>::cal_exx_force(const int& nat)
83+
{
84+
ModuleBase::TITLE("Exx_LRI_Interface","cal_exx_force");
85+
if(!this->flag_finish.init || !this->flag_finish.ions)
86+
{ throw std::runtime_error("Exx init unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); }
87+
if(!this->flag_finish.elec)
88+
{ throw std::runtime_error("Exx Hamiltonian unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); }
89+
90+
this->exx_ptr->cal_exx_force(nat);
91+
92+
this->flag_finish.force = true;
93+
}
94+
95+
template<typename T, typename Tdata>
96+
void Exx_LRI_Interface<T, Tdata>::cal_exx_stress(const double& omega, const double& lat0)
97+
{
98+
ModuleBase::TITLE("Exx_LRI_Interface","cal_exx_stress");
99+
if(!this->flag_finish.init || !this->flag_finish.ions)
100+
{ throw std::runtime_error("Exx init unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); }
101+
if(!this->flag_finish.elec)
102+
{ throw std::runtime_error("Exx Hamiltonian unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); }
103+
104+
this->exx_ptr->cal_exx_stress(omega, lat0);
105+
106+
this->flag_finish.stress = true;
107+
}
108+
54109
template<typename T, typename Tdata>
55110
void Exx_LRI_Interface<T, Tdata>::exx_before_all_runners(const K_Vectors& kv, const UnitCell& ucell, const Parallel_2D& pv)
56111
{
@@ -102,10 +157,7 @@ void Exx_LRI_Interface<T, Tdata>::exx_beforescf(const int istep,
102157
}
103158
}
104159

105-
if(!this->flag_finish.init)
106-
{ throw std::runtime_error("Exx init unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); }
107-
this->exx_ptr->cal_exx_ions(ucell,PARAM.inp.out_ri_cv);
108-
this->flag_finish.ions = true;
160+
this->cal_exx_ions(ucell,PARAM.inp.out_ri_cv);
109161
}
110162

111163
if (Exx_Abfs::Jle::generate_matrix)
@@ -124,11 +176,11 @@ void Exx_LRI_Interface<T, Tdata>::exx_beforescf(const int istep,
124176
{this->mix_DMk_2D.set_nks(kv.get_nkstot_full() * (PARAM.inp.nspin == 2 ? 2 : 1), PARAM.globalv.gamma_only_local);}
125177
else
126178
{this->mix_DMk_2D.set_nks(kv.get_nks(), PARAM.globalv.gamma_only_local);}
127-
if(GlobalC::exx_info.info_global.separate_loop) {
128-
this->mix_DMk_2D.set_mixing(nullptr);
129-
} else {
130-
this->mix_DMk_2D.set_mixing(chgmix.get_mixing());
131-
}
179+
180+
if(GlobalC::exx_info.info_global.separate_loop)
181+
{ this->mix_DMk_2D.set_mixing(nullptr); }
182+
else
183+
{ this->mix_DMk_2D.set_mixing(chgmix.get_mixing()); }
132184
// for exx two_level scf
133185
this->two_level_step = 0;
134186
}
@@ -151,24 +203,19 @@ void Exx_LRI_Interface<T, Tdata>::exx_eachiterinit(const int istep,
151203
const bool flag_restart = (iter == 1) ? true : false;
152204
auto cal = [this, &ucell,&kv, &flag_restart](const elecstate::DensityMatrix<T, double>& dm_in)
153205
{
154-
if(!this->flag_finish.init || !this->flag_finish.ions)
155-
{ throw std::runtime_error("Exx init unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); }
156-
157206
if (this->exx_spacegroup_symmetry)
158207
{ this->mix_DMk_2D.mix(symrot_.restore_dm(kv,dm_in.get_DMK_vector(), *dm_in.get_paraV_pointer()), flag_restart); }
159208
else
160209
{ this->mix_DMk_2D.mix(dm_in.get_DMK_vector(), flag_restart); }
161-
const std::vector<std::map<int,std::map<std::pair<int, std::array<int, 3>>,RI::Tensor<Tdata>>>>
210+
const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>
162211
Ds = PARAM.globalv.gamma_only_local
163212
? RI_2D_Comm::split_m2D_ktoR<Tdata>(ucell,*this->exx_ptr->p_kv, this->mix_DMk_2D.get_DMk_gamma_out(), *dm_in.get_paraV_pointer(), PARAM.inp.nspin)
164213
: RI_2D_Comm::split_m2D_ktoR<Tdata>(ucell,*this->exx_ptr->p_kv, this->mix_DMk_2D.get_DMk_k_out(), *dm_in.get_paraV_pointer(), PARAM.inp.nspin, this->exx_spacegroup_symmetry);
165214

166215
if (this->exx_spacegroup_symmetry && GlobalC::exx_info.info_global.exx_symmetry_realspace)
167-
{ this->exx_ptr->cal_exx_elec(Ds, ucell,*dm_in.get_paraV_pointer(), &this->symrot_); }
216+
{ this->cal_exx_elec(Ds, ucell,*dm_in.get_paraV_pointer(), &this->symrot_); }
168217
else
169-
{ this->exx_ptr->cal_exx_elec(Ds, ucell,*dm_in.get_paraV_pointer()); }
170-
171-
this->flag_finish.elec = true;
218+
{ this->cal_exx_elec(Ds, ucell,*dm_in.get_paraV_pointer()); }
172219
};
173220

174221
if(istep > 0 && flag_restart)
@@ -335,9 +382,6 @@ bool Exx_LRI_Interface<T, Tdata>::exx_after_converge(
335382
}
336383
else
337384
{
338-
if(!this->flag_finish.init || !this->flag_finish.ions)
339-
{ throw std::runtime_error("Exx init unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); }
340-
341385
this->etot_last_outer_loop = etot;
342386
// update exx and redo scf
343387
if (this->two_level_step == 0)
@@ -361,16 +405,11 @@ bool Exx_LRI_Interface<T, Tdata>::exx_after_converge(
361405
: RI_2D_Comm::split_m2D_ktoR<Tdata>(ucell,*this->exx_ptr->p_kv, this->mix_DMk_2D.get_DMk_k_out(), *dm.get_paraV_pointer(), nspin, this->exx_spacegroup_symmetry);
362406

363407
if (this->exx_spacegroup_symmetry && GlobalC::exx_info.info_global.exx_symmetry_realspace)
364-
{
365-
this->exx_ptr->cal_exx_elec(Ds, ucell, *dm.get_paraV_pointer(), &this->symrot_);
366-
}
408+
{ this->cal_exx_elec(Ds, ucell, *dm.get_paraV_pointer(), &this->symrot_); }
367409
else
368-
{
369-
this->exx_ptr->cal_exx_elec(Ds, ucell, *dm.get_paraV_pointer()); // restore DM but not Hexx
370-
}
410+
{ this->cal_exx_elec(Ds, ucell, *dm.get_paraV_pointer()); } // restore DM but not Hexx
371411
iter = 0;
372412
this->two_level_step++;
373-
this->flag_finish.elec = true;
374413

375414
timeval t_end; gettimeofday(&t_end, nullptr);
376415
std::cout << "and rerun SCF\t"
@@ -384,32 +423,4 @@ bool Exx_LRI_Interface<T, Tdata>::exx_after_converge(
384423
return true;
385424
}
386425

387-
template<typename T, typename Tdata>
388-
void Exx_LRI_Interface<T, Tdata>::cal_exx_force(const int& nat)
389-
{
390-
ModuleBase::TITLE("Exx_LRI_Interface","cal_exx_force");
391-
if(!this->flag_finish.init || !this->flag_finish.ions)
392-
{ throw std::runtime_error("Exx init unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); }
393-
if(!this->flag_finish.elec)
394-
{ throw std::runtime_error("Exx Hamiltonian unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); }
395-
396-
this->exx_ptr->cal_exx_force(nat);
397-
398-
this->flag_finish.force = true;
399-
}
400-
401-
template<typename T, typename Tdata>
402-
void Exx_LRI_Interface<T, Tdata>::cal_exx_stress(const double& omega, const double& lat0)
403-
{
404-
ModuleBase::TITLE("Exx_LRI_Interface","cal_exx_stress");
405-
if(!this->flag_finish.init || !this->flag_finish.ions)
406-
{ throw std::runtime_error("Exx init unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); }
407-
if(!this->flag_finish.elec)
408-
{ throw std::runtime_error("Exx Hamiltonian unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); }
409-
410-
this->exx_ptr->cal_exx_stress(omega, lat0);
411-
412-
this->flag_finish.stress = true;
413-
}
414-
415426
#endif

0 commit comments

Comments
 (0)