Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,11 @@ ESolver_KS_LCAO<TK, TR>::ESolver_KS_LCAO()
// because some members like two_level_step are used outside if(cal_exx)
if (GlobalC::exx_info.info_ri.real_number)
{
this->exx_lri_double = std::make_shared<Exx_LRI<double>>(GlobalC::exx_info.info_ri);
this->exd = std::make_shared<Exx_LRI_Interface<TK, double>>(exx_lri_double);
this->exd = std::make_shared<Exx_LRI_Interface<TK, double>>(GlobalC::exx_info.info_ri);
}
else
{
this->exx_lri_complex = std::make_shared<Exx_LRI<std::complex<double>>>(GlobalC::exx_info.info_ri);
this->exc = std::make_shared<Exx_LRI_Interface<TK, std::complex<double>>>(exx_lri_complex);
this->exc = std::make_shared<Exx_LRI_Interface<TK, std::complex<double>>>(GlobalC::exx_info.info_ri);
}
#endif
}
Expand Down Expand Up @@ -198,12 +196,12 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
// initialize 2-center radial tables for EXX-LRI
if (GlobalC::exx_info.info_ri.real_number)
{
this->exx_lri_double->init(MPI_COMM_WORLD, ucell, this->kv, orb_);
this->exd->init(MPI_COMM_WORLD, ucell, this->kv, orb_);
this->exd->exx_before_all_runners(this->kv, ucell, this->pv);
}
else
{
this->exx_lri_complex->init(MPI_COMM_WORLD, ucell, this->kv, orb_);
this->exc->init(MPI_COMM_WORLD, ucell, this->kv, orb_);
this->exc->exx_before_all_runners(this->kv, ucell, this->pv);
}
}
Expand Down Expand Up @@ -351,8 +349,8 @@ void ESolver_KS_LCAO<TK, TR>::cal_force(UnitCell& ucell, ModuleBase::matrix& for
this->ld,
#endif
#ifdef __EXX
*this->exx_lri_double,
*this->exx_lri_complex,
*this->exd,
*this->exc,
#endif
&ucell.symm);

Expand Down Expand Up @@ -461,8 +459,8 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
this->gd
#ifdef __EXX
,
this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr,
this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr
this->exd ? &this->exd->get_Hexxs() : nullptr,
this->exc ? &this->exc->get_Hexxs() : nullptr
#endif
);
}
Expand All @@ -484,8 +482,8 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
this->gd
#ifdef __EXX
,
this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr,
this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr
this->exd ? &this->exd->get_Hexxs() : nullptr,
this->exc ? &this->exc->get_Hexxs() : nullptr
#endif
);
}
Expand Down Expand Up @@ -514,8 +512,8 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
this->two_center_bundle_
#ifdef __EXX
,
this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr,
this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr
this->exd ? &this->exd->get_Hexxs() : nullptr,
this->exc ? &this->exc->get_Hexxs() : nullptr
#endif
);
}
Expand Down
2 changes: 0 additions & 2 deletions source/module_esolver/esolver_ks_lcao.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ class ESolver_KS_LCAO : public ESolver_KS<TK>
#ifdef __EXX
std::shared_ptr<Exx_LRI_Interface<TK, double>> exd = nullptr;
std::shared_ptr<Exx_LRI_Interface<TK, std::complex<double>>> exc = nullptr;
std::shared_ptr<Exx_LRI<double>> exx_lri_double = nullptr;
std::shared_ptr<Exx_LRI<std::complex<double>>> exx_lri_complex = nullptr;
#endif

friend class LR::ESolver_LR<double, double>;
Expand Down
4 changes: 2 additions & 2 deletions source/module_esolver/lcao_before_scf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ void ESolver_KS_LCAO<TK, TR>::before_scf(UnitCell& ucell, const int istep)
,
istep,
GlobalC::exx_info.info_ri.real_number ? &this->exd->two_level_step : &this->exc->two_level_step,
GlobalC::exx_info.info_ri.real_number ? &exx_lri_double->Hexxs : nullptr,
GlobalC::exx_info.info_ri.real_number ? nullptr : &exx_lri_complex->Hexxs
GlobalC::exx_info.info_ri.real_number ? &this->exd->get_Hexxs() : nullptr,
GlobalC::exx_info.info_ri.real_number ? nullptr : &this->exc->get_Hexxs()
#endif
);
}
Expand Down
4 changes: 2 additions & 2 deletions source/module_esolver/lcao_others.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ void ESolver_KS_LCAO<TK, TR>::others(UnitCell& ucell, const int istep)
,
istep,
GlobalC::exx_info.info_ri.real_number ? &this->exd->two_level_step : &this->exc->two_level_step,
GlobalC::exx_info.info_ri.real_number ? &exx_lri_double->Hexxs : nullptr,
GlobalC::exx_info.info_ri.real_number ? nullptr : &exx_lri_complex->Hexxs
GlobalC::exx_info.info_ri.real_number ? &this->exd->get_Hexxs() : nullptr,
GlobalC::exx_info.info_ri.real_number ? nullptr : &this->exc->get_Hexxs()
#endif
);
}
Expand Down
20 changes: 10 additions & 10 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
LCAO_Deepks<T>& ld,
#endif
#ifdef __EXX
Exx_LRI<double>& exx_lri_double,
Exx_LRI<std::complex<double>>& exx_lri_complex,
Exx_LRI_Interface<T, double>& exd,
Exx_LRI_Interface<T, std::complex<double>>& exc,
#endif
ModuleSymmetry::Symmetry* symm)
{
Expand Down Expand Up @@ -377,26 +377,26 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
{
if (GlobalC::exx_info.info_ri.real_number)
{
exx_lri_double.cal_exx_force(ucell.nat);
force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_lri_double.force_exx;
exd.cal_exx_force(ucell.nat);
force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exd.get_force();
}
else
{
exx_lri_complex.cal_exx_force(ucell.nat);
force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_lri_complex.force_exx;
exc.cal_exx_force(ucell.nat);
force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exc.get_force();
}
}
if (isstress)
{
if (GlobalC::exx_info.info_ri.real_number)
{
exx_lri_double.cal_exx_stress(ucell.omega, ucell.lat0);
stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_lri_double.stress_exx;
exd.cal_exx_stress(ucell.omega, ucell.lat0);
stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exd.get_stress();
}
else
{
exx_lri_complex.cal_exx_stress(ucell.omega, ucell.lat0);
stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_lri_complex.stress_exx;
exc.cal_exx_stress(ucell.omega, ucell.lat0);
stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exc.get_stress();
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "module_io/input_conv.h"
#include "module_psi/psi.h"
#ifdef __EXX
#include "module_ri/Exx_LRI.h"
#include "module_ri/Exx_LRI_interface.h"
#endif
#include "force_stress_arrays.h"
#include "module_hamilt_lcao/module_gint/gint_gamma.h"
Expand Down Expand Up @@ -53,8 +53,8 @@ class Force_Stress_LCAO
LCAO_Deepks<T>& ld,
#endif
#ifdef __EXX
Exx_LRI<double>& exx_lri_double,
Exx_LRI<std::complex<double>>& exx_lri_complex,
Exx_LRI_Interface<T, double>& exd,
Exx_LRI_Interface<T, std::complex<double>>& exc,
#endif
ModuleSymmetry::Symmetry* symm);

Expand Down
8 changes: 4 additions & 4 deletions source/module_lr/esolver_lrtd_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,10 @@ LR::ESolver_LR<T, TR>::ESolver_LR(ModuleESolver::ESolver_KS_LCAO<T, TR>&& ks_sol
{
// if the same kernel is calculated in the esolver_ks, move it
std::string dft_functional = LR_Util::tolower(input.dft_functional);
if (ks_sol.exx_lri_double && std::is_same<T, double>::value && xc_kernel == dft_functional) {
this->move_exx_lri(ks_sol.exx_lri_double);
} else if (ks_sol.exx_lri_complex && std::is_same<T, std::complex<double>>::value && xc_kernel == dft_functional) {
this->move_exx_lri(ks_sol.exx_lri_complex);
if (ks_sol.exd && std::is_same<T, double>::value && xc_kernel == dft_functional) {
this->move_exx_lri(ks_sol.exd->exx_ptr);
} else if (ks_sol.exc && std::is_same<T, std::complex<double>>::value && xc_kernel == dft_functional) {
this->move_exx_lri(ks_sol.exc->exx_ptr);
} else // construct C, V from scratch
{
// set ccp_type according to the xc_kernel
Expand Down
68 changes: 35 additions & 33 deletions source/module_ri/Exx_LRI.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,21 @@
#include "module_exx_symmetry/symmetry_rotation.h"

class Parallel_Orbitals;

template<typename T, typename Tdata>
class RPA_LRI;

template<typename T, typename Tdata>
class Exx_LRI_Interface;

namespace LR
{
template<typename T, typename TR>
class ESolver_LR;
namespace LR
{
template<typename T, typename TR>
class ESolver_LR;

template<typename T>
class OperatorLREXX;
}
template<typename T>
class OperatorLREXX;
}

template<typename Tdata>
class Exx_LRI
Expand All @@ -49,37 +49,39 @@ class Exx_LRI
using TatomR = std::array<double,Ndim>; // tmp

public:
Exx_LRI(const Exx_Info::Exx_Info_RI& info_in) :info(info_in) {}
Exx_LRI operator=(const Exx_LRI&) = delete;
Exx_LRI operator=(Exx_LRI&&);

void reset_Cs(const std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>& Cs_in) { this->exx_lri.set_Cs(Cs_in, this->info.C_threshold); }
void reset_Vs(const std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>& Vs_in) { this->exx_lri.set_Vs(Vs_in, this->info.V_threshold); }

void init(const MPI_Comm &mpi_comm_in,
const UnitCell &ucell,
const K_Vectors &kv_in,
const LCAO_Orbitals& orb);
void cal_exx_force(const int& nat);
void cal_exx_stress(const double& omega, const double& lat0);
Exx_LRI(const Exx_Info::Exx_Info_RI& info_in) :info(info_in) {}
Exx_LRI operator=(const Exx_LRI&) = delete;
Exx_LRI operator=(Exx_LRI&&);

void init(
const MPI_Comm &mpi_comm_in,
const UnitCell &ucell,
const K_Vectors &kv_in,
const LCAO_Orbitals& orb);
void cal_exx_ions(const UnitCell& ucell, const bool write_cv = false);
void cal_exx_elec(const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& Ds,
void cal_exx_elec(
const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& Ds,
const UnitCell& ucell,
const Parallel_Orbitals& pv,
const ModuleSymmetry::Symmetry_rotation* p_symrot = nullptr);
std::vector<std::vector<int>> get_abfs_nchis() const;
const Parallel_Orbitals& pv,
const ModuleSymmetry::Symmetry_rotation* p_symrot = nullptr);
void cal_exx_force(const int& nat);
void cal_exx_stress(const double& omega, const double& lat0);

void reset_Cs(const std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>& Cs_in) { this->exx_lri.set_Cs(Cs_in, this->info.C_threshold); }
void reset_Vs(const std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>& Vs_in) { this->exx_lri.set_Vs(Vs_in, this->info.V_threshold); }
//std::vector<std::vector<int>> get_abfs_nchis() const;

std::vector< std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>> Hexxs;
double Eexx;
double Eexx;
ModuleBase::matrix force_exx;
ModuleBase::matrix stress_exx;


private:
const Exx_Info::Exx_Info_RI &info;
MPI_Comm mpi_comm;
const K_Vectors *p_kv = nullptr;
std::vector<double> orb_cutoff_;
std::vector<double> orb_cutoff_;

std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> lcaos;
std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> abfs;
Expand All @@ -89,16 +91,16 @@ class Exx_LRI
RI::Exx<TA,Tcell,Ndim,Tdata> exx_lri;

void post_process_Hexx( std::map<TA, std::map<TAC, RI::Tensor<Tdata>>> &Hexxs_io ) const;
double post_process_Eexx(const double& Eexx_in) const;
double post_process_Eexx(const double& Eexx_in) const;

friend class RPA_LRI<double, Tdata>;
friend class RPA_LRI<std::complex<double>, Tdata>;
friend class Exx_LRI_Interface<double, Tdata>;
friend class Exx_LRI_Interface<std::complex<double>, Tdata>;
friend class LR::ESolver_LR<double, double>;
friend class LR::ESolver_LR<std::complex<double>, double>;
friend class LR::OperatorLREXX<double>;
friend class LR::OperatorLREXX<std::complex<double>>;
friend class LR::ESolver_LR<double, double>;
friend class LR::ESolver_LR<std::complex<double>, double>;
friend class LR::OperatorLREXX<double>;
friend class LR::OperatorLREXX<std::complex<double>>;
};

#include "Exx_LRI.hpp"
Expand Down
15 changes: 6 additions & 9 deletions source/module_ri/Exx_LRI.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
#include <string>

template<typename Tdata>
void Exx_LRI<Tdata>::init(const MPI_Comm &mpi_comm_in,
void Exx_LRI<Tdata>::init(const MPI_Comm &mpi_comm_in,
const UnitCell &ucell,
const K_Vectors &kv_in,
const K_Vectors &kv_in,
const LCAO_Orbitals& orb)
{
ModuleBase::TITLE("Exx_LRI","init");
Expand Down Expand Up @@ -130,7 +130,7 @@ void Exx_LRI<Tdata>::cal_exx_ions(const UnitCell& ucell,
this->exx_lri.set_parallel(this->mpi_comm, atoms_pos, latvec, period);

// std::max(3) for gamma_only, list_A2 should contain cell {-1,0,1}. In the future distribute will be neighbour.
const std::array<Tcell,Ndim> period_Vs = LRI_CV_Tools::cal_latvec_range<Tcell>(1+this->info.ccp_rmesh_times, ucell, orb_cutoff_);
const std::array<Tcell,Ndim> period_Vs = LRI_CV_Tools::cal_latvec_range<Tcell>(1+this->info.ccp_rmesh_times, ucell, orb_cutoff_);
const std::pair<std::vector<TA>, std::vector<std::vector<std::pair<TA,std::array<Tcell,Ndim>>>>>
list_As_Vs = RI::Distribute_Equally::distribute_atoms_periods(this->mpi_comm, atoms, period_Vs, 2, false);

Expand Down Expand Up @@ -237,7 +237,7 @@ void Exx_LRI<Tdata>::cal_exx_elec(const std::vector<std::map<TA, std::map<TAC, R
}
this->Eexx = post_process_Eexx(this->Eexx);
this->exx_lri.set_symmetry(false, {});
ModuleBase::timer::tick("Exx_LRI", "cal_exx_elec");
ModuleBase::timer::tick("Exx_LRI", "cal_exx_elec");
}

template<typename Tdata>
Expand Down Expand Up @@ -283,11 +283,6 @@ void Exx_LRI<Tdata>::cal_exx_force(const int& nat)
ModuleBase::TITLE("Exx_LRI","cal_exx_force");
ModuleBase::timer::tick("Exx_LRI", "cal_exx_force");

if (!this->exx_lri.flag_finish.D)
{
ModuleBase::WARNING_QUIT("Force_Stress_LCAO", "Cannot calculate EXX force when the first PBE loop is not converged.");
}

this->force_exx.create(nat, Ndim);
for(int is=0; is<PARAM.inp.nspin; ++is)
{
Expand Down Expand Up @@ -328,6 +323,7 @@ void Exx_LRI<Tdata>::cal_exx_stress(const double& omega, const double& lat0)
ModuleBase::timer::tick("Exx_LRI", "cal_exx_stress");
}

/*
template<typename Tdata>
std::vector<std::vector<int>> Exx_LRI<Tdata>::get_abfs_nchis() const
{
Expand All @@ -341,5 +337,6 @@ std::vector<std::vector<int>> Exx_LRI<Tdata>::get_abfs_nchis() const
}
return abfs_nchis;
}
*/

#endif
Loading
Loading