diff --git a/source/module_esolver/esolver_ks_lcao.cpp b/source/module_esolver/esolver_ks_lcao.cpp index 666f587a99..297a89f40e 100644 --- a/source/module_esolver/esolver_ks_lcao.cpp +++ b/source/module_esolver/esolver_ks_lcao.cpp @@ -83,13 +83,11 @@ ESolver_KS_LCAO::ESolver_KS_LCAO() // because some members like two_level_step are used outside if(cal_exx) if (GlobalC::exx_info.info_ri.real_number) { - this->exx_lri_double = std::make_shared>(GlobalC::exx_info.info_ri); - this->exd = std::make_shared>(exx_lri_double); + this->exd = std::make_shared>(GlobalC::exx_info.info_ri); } else { - this->exx_lri_complex = std::make_shared>>(GlobalC::exx_info.info_ri); - this->exc = std::make_shared>>(exx_lri_complex); + this->exc = std::make_shared>>(GlobalC::exx_info.info_ri); } #endif } @@ -198,12 +196,12 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa // initialize 2-center radial tables for EXX-LRI if (GlobalC::exx_info.info_ri.real_number) { - this->exx_lri_double->init(MPI_COMM_WORLD, ucell, this->kv, orb_); + this->exd->init(MPI_COMM_WORLD, ucell, this->kv, orb_); this->exd->exx_before_all_runners(this->kv, ucell, this->pv); } else { - this->exx_lri_complex->init(MPI_COMM_WORLD, ucell, this->kv, orb_); + this->exc->init(MPI_COMM_WORLD, ucell, this->kv, orb_); this->exc->exx_before_all_runners(this->kv, ucell, this->pv); } } @@ -351,8 +349,8 @@ void ESolver_KS_LCAO::cal_force(UnitCell& ucell, ModuleBase::matrix& for this->ld, #endif #ifdef __EXX - *this->exx_lri_double, - *this->exx_lri_complex, + *this->exd, + *this->exc, #endif &ucell.symm); @@ -461,8 +459,8 @@ void ESolver_KS_LCAO::after_all_runners(UnitCell& ucell) this->gd #ifdef __EXX , - this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr, - this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr + this->exd ? &this->exd->get_Hexxs() : nullptr, + this->exc ? &this->exc->get_Hexxs() : nullptr #endif ); } @@ -484,8 +482,8 @@ void ESolver_KS_LCAO::after_all_runners(UnitCell& ucell) this->gd #ifdef __EXX , - this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr, - this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr + this->exd ? &this->exd->get_Hexxs() : nullptr, + this->exc ? &this->exc->get_Hexxs() : nullptr #endif ); } @@ -514,8 +512,8 @@ void ESolver_KS_LCAO::after_all_runners(UnitCell& ucell) this->two_center_bundle_ #ifdef __EXX , - this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr, - this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr + this->exd ? &this->exd->get_Hexxs() : nullptr, + this->exc ? &this->exc->get_Hexxs() : nullptr #endif ); } diff --git a/source/module_esolver/esolver_ks_lcao.h b/source/module_esolver/esolver_ks_lcao.h index 774a3bc16b..6bbbb9a2fe 100644 --- a/source/module_esolver/esolver_ks_lcao.h +++ b/source/module_esolver/esolver_ks_lcao.h @@ -116,8 +116,6 @@ class ESolver_KS_LCAO : public ESolver_KS #ifdef __EXX std::shared_ptr> exd = nullptr; std::shared_ptr>> exc = nullptr; - std::shared_ptr> exx_lri_double = nullptr; - std::shared_ptr>> exx_lri_complex = nullptr; #endif friend class LR::ESolver_LR; diff --git a/source/module_esolver/lcao_before_scf.cpp b/source/module_esolver/lcao_before_scf.cpp index 9a23bbfc91..ca5a411ac7 100644 --- a/source/module_esolver/lcao_before_scf.cpp +++ b/source/module_esolver/lcao_before_scf.cpp @@ -159,8 +159,8 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) , istep, GlobalC::exx_info.info_ri.real_number ? &this->exd->two_level_step : &this->exc->two_level_step, - GlobalC::exx_info.info_ri.real_number ? &exx_lri_double->Hexxs : nullptr, - GlobalC::exx_info.info_ri.real_number ? nullptr : &exx_lri_complex->Hexxs + GlobalC::exx_info.info_ri.real_number ? &this->exd->get_Hexxs() : nullptr, + GlobalC::exx_info.info_ri.real_number ? nullptr : &this->exc->get_Hexxs() #endif ); } diff --git a/source/module_esolver/lcao_others.cpp b/source/module_esolver/lcao_others.cpp index 6390d59ddc..af195e7ac5 100644 --- a/source/module_esolver/lcao_others.cpp +++ b/source/module_esolver/lcao_others.cpp @@ -217,8 +217,8 @@ void ESolver_KS_LCAO::others(UnitCell& ucell, const int istep) , istep, GlobalC::exx_info.info_ri.real_number ? &this->exd->two_level_step : &this->exc->two_level_step, - GlobalC::exx_info.info_ri.real_number ? &exx_lri_double->Hexxs : nullptr, - GlobalC::exx_info.info_ri.real_number ? nullptr : &exx_lri_complex->Hexxs + GlobalC::exx_info.info_ri.real_number ? &this->exd->get_Hexxs() : nullptr, + GlobalC::exx_info.info_ri.real_number ? nullptr : &this->exc->get_Hexxs() #endif ); } diff --git a/source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp b/source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp index 696b93a5ed..b0fcd3354c 100644 --- a/source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp +++ b/source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp @@ -55,8 +55,8 @@ void Force_Stress_LCAO::getForceStress(UnitCell& ucell, LCAO_Deepks& ld, #endif #ifdef __EXX - Exx_LRI& exx_lri_double, - Exx_LRI>& exx_lri_complex, + Exx_LRI_Interface& exd, + Exx_LRI_Interface>& exc, #endif ModuleSymmetry::Symmetry* symm) { @@ -377,26 +377,26 @@ void Force_Stress_LCAO::getForceStress(UnitCell& ucell, { if (GlobalC::exx_info.info_ri.real_number) { - exx_lri_double.cal_exx_force(ucell.nat); - force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_lri_double.force_exx; + exd.cal_exx_force(ucell.nat); + force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exd.get_force(); } else { - exx_lri_complex.cal_exx_force(ucell.nat); - force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_lri_complex.force_exx; + exc.cal_exx_force(ucell.nat); + force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exc.get_force(); } } if (isstress) { if (GlobalC::exx_info.info_ri.real_number) { - exx_lri_double.cal_exx_stress(ucell.omega, ucell.lat0); - stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_lri_double.stress_exx; + exd.cal_exx_stress(ucell.omega, ucell.lat0); + stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exd.get_stress(); } else { - exx_lri_complex.cal_exx_stress(ucell.omega, ucell.lat0); - stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_lri_complex.stress_exx; + exc.cal_exx_stress(ucell.omega, ucell.lat0); + stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exc.get_stress(); } } } diff --git a/source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.h b/source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.h index f1ba8f98b2..4e7ae06c9f 100644 --- a/source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.h +++ b/source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.h @@ -11,7 +11,7 @@ #include "module_io/input_conv.h" #include "module_psi/psi.h" #ifdef __EXX -#include "module_ri/Exx_LRI.h" +#include "module_ri/Exx_LRI_interface.h" #endif #include "force_stress_arrays.h" #include "module_hamilt_lcao/module_gint/gint_gamma.h" @@ -53,8 +53,8 @@ class Force_Stress_LCAO LCAO_Deepks& ld, #endif #ifdef __EXX - Exx_LRI& exx_lri_double, - Exx_LRI>& exx_lri_complex, + Exx_LRI_Interface& exd, + Exx_LRI_Interface>& exc, #endif ModuleSymmetry::Symmetry* symm); diff --git a/source/module_lr/esolver_lrtd_lcao.cpp b/source/module_lr/esolver_lrtd_lcao.cpp index 1db10b5caf..88df295962 100644 --- a/source/module_lr/esolver_lrtd_lcao.cpp +++ b/source/module_lr/esolver_lrtd_lcao.cpp @@ -272,10 +272,10 @@ LR::ESolver_LR::ESolver_LR(ModuleESolver::ESolver_KS_LCAO&& ks_sol { // if the same kernel is calculated in the esolver_ks, move it std::string dft_functional = LR_Util::tolower(input.dft_functional); - if (ks_sol.exx_lri_double && std::is_same::value && xc_kernel == dft_functional) { - this->move_exx_lri(ks_sol.exx_lri_double); - } else if (ks_sol.exx_lri_complex && std::is_same>::value && xc_kernel == dft_functional) { - this->move_exx_lri(ks_sol.exx_lri_complex); + if (ks_sol.exd && std::is_same::value && xc_kernel == dft_functional) { + this->move_exx_lri(ks_sol.exd->exx_ptr); + } else if (ks_sol.exc && std::is_same>::value && xc_kernel == dft_functional) { + this->move_exx_lri(ks_sol.exc->exx_ptr); } else // construct C, V from scratch { // set ccp_type according to the xc_kernel diff --git a/source/module_ri/Exx_LRI.h b/source/module_ri/Exx_LRI.h index cbcef837a9..ae53393ee6 100644 --- a/source/module_ri/Exx_LRI.h +++ b/source/module_ri/Exx_LRI.h @@ -21,21 +21,21 @@ #include "module_exx_symmetry/symmetry_rotation.h" class Parallel_Orbitals; - + template class RPA_LRI; template class Exx_LRI_Interface; - namespace LR - { - template - class ESolver_LR; + namespace LR + { + template + class ESolver_LR; - template - class OperatorLREXX; - } + template + class OperatorLREXX; + } template class Exx_LRI @@ -49,37 +49,39 @@ class Exx_LRI using TatomR = std::array; // tmp public: - Exx_LRI(const Exx_Info::Exx_Info_RI& info_in) :info(info_in) {} - Exx_LRI operator=(const Exx_LRI&) = delete; - Exx_LRI operator=(Exx_LRI&&); - - void reset_Cs(const std::map>>& Cs_in) { this->exx_lri.set_Cs(Cs_in, this->info.C_threshold); } - void reset_Vs(const std::map>>& Vs_in) { this->exx_lri.set_Vs(Vs_in, this->info.V_threshold); } - - void init(const MPI_Comm &mpi_comm_in, - const UnitCell &ucell, - const K_Vectors &kv_in, - const LCAO_Orbitals& orb); - void cal_exx_force(const int& nat); - void cal_exx_stress(const double& omega, const double& lat0); + Exx_LRI(const Exx_Info::Exx_Info_RI& info_in) :info(info_in) {} + Exx_LRI operator=(const Exx_LRI&) = delete; + Exx_LRI operator=(Exx_LRI&&); + + void init( + const MPI_Comm &mpi_comm_in, + const UnitCell &ucell, + const K_Vectors &kv_in, + const LCAO_Orbitals& orb); void cal_exx_ions(const UnitCell& ucell, const bool write_cv = false); - void cal_exx_elec(const std::vector>>>& Ds, + void cal_exx_elec( + const std::vector>>>& Ds, const UnitCell& ucell, - const Parallel_Orbitals& pv, - const ModuleSymmetry::Symmetry_rotation* p_symrot = nullptr); - std::vector> get_abfs_nchis() const; + const Parallel_Orbitals& pv, + const ModuleSymmetry::Symmetry_rotation* p_symrot = nullptr); + void cal_exx_force(const int& nat); + void cal_exx_stress(const double& omega, const double& lat0); + + void reset_Cs(const std::map>>& Cs_in) { this->exx_lri.set_Cs(Cs_in, this->info.C_threshold); } + void reset_Vs(const std::map>>& Vs_in) { this->exx_lri.set_Vs(Vs_in, this->info.V_threshold); } + //std::vector> get_abfs_nchis() const; std::vector< std::map>>> Hexxs; - double Eexx; + double Eexx; ModuleBase::matrix force_exx; ModuleBase::matrix stress_exx; - + private: const Exx_Info::Exx_Info_RI &info; MPI_Comm mpi_comm; const K_Vectors *p_kv = nullptr; - std::vector orb_cutoff_; + std::vector orb_cutoff_; std::vector>> lcaos; std::vector>> abfs; @@ -89,16 +91,16 @@ class Exx_LRI RI::Exx exx_lri; void post_process_Hexx( std::map>> &Hexxs_io ) const; - double post_process_Eexx(const double& Eexx_in) const; + double post_process_Eexx(const double& Eexx_in) const; friend class RPA_LRI; friend class RPA_LRI, Tdata>; friend class Exx_LRI_Interface; friend class Exx_LRI_Interface, Tdata>; - friend class LR::ESolver_LR; - friend class LR::ESolver_LR, double>; - friend class LR::OperatorLREXX; - friend class LR::OperatorLREXX>; + friend class LR::ESolver_LR; + friend class LR::ESolver_LR, double>; + friend class LR::OperatorLREXX; + friend class LR::OperatorLREXX>; }; #include "Exx_LRI.hpp" diff --git a/source/module_ri/Exx_LRI.hpp b/source/module_ri/Exx_LRI.hpp index 26c494a805..0ac61c6f8e 100644 --- a/source/module_ri/Exx_LRI.hpp +++ b/source/module_ri/Exx_LRI.hpp @@ -26,9 +26,9 @@ #include template -void Exx_LRI::init(const MPI_Comm &mpi_comm_in, +void Exx_LRI::init(const MPI_Comm &mpi_comm_in, const UnitCell &ucell, - const K_Vectors &kv_in, + const K_Vectors &kv_in, const LCAO_Orbitals& orb) { ModuleBase::TITLE("Exx_LRI","init"); @@ -130,7 +130,7 @@ void Exx_LRI::cal_exx_ions(const UnitCell& ucell, this->exx_lri.set_parallel(this->mpi_comm, atoms_pos, latvec, period); // std::max(3) for gamma_only, list_A2 should contain cell {-1,0,1}. In the future distribute will be neighbour. - const std::array period_Vs = LRI_CV_Tools::cal_latvec_range(1+this->info.ccp_rmesh_times, ucell, orb_cutoff_); + const std::array period_Vs = LRI_CV_Tools::cal_latvec_range(1+this->info.ccp_rmesh_times, ucell, orb_cutoff_); const std::pair, std::vector>>>> list_As_Vs = RI::Distribute_Equally::distribute_atoms_periods(this->mpi_comm, atoms, period_Vs, 2, false); @@ -237,7 +237,7 @@ void Exx_LRI::cal_exx_elec(const std::vectorEexx = post_process_Eexx(this->Eexx); this->exx_lri.set_symmetry(false, {}); - ModuleBase::timer::tick("Exx_LRI", "cal_exx_elec"); + ModuleBase::timer::tick("Exx_LRI", "cal_exx_elec"); } template @@ -283,11 +283,6 @@ void Exx_LRI::cal_exx_force(const int& nat) ModuleBase::TITLE("Exx_LRI","cal_exx_force"); ModuleBase::timer::tick("Exx_LRI", "cal_exx_force"); - if (!this->exx_lri.flag_finish.D) - { - ModuleBase::WARNING_QUIT("Force_Stress_LCAO", "Cannot calculate EXX force when the first PBE loop is not converged."); - } - this->force_exx.create(nat, Ndim); for(int is=0; is::cal_exx_stress(const double& omega, const double& lat0) ModuleBase::timer::tick("Exx_LRI", "cal_exx_stress"); } +/* template std::vector> Exx_LRI::get_abfs_nchis() const { @@ -341,5 +337,6 @@ std::vector> Exx_LRI::get_abfs_nchis() const } return abfs_nchis; } +*/ #endif diff --git a/source/module_ri/Exx_LRI_interface.h b/source/module_ri/Exx_LRI_interface.h index 82ec91b181..765bc7f627 100644 --- a/source/module_ri/Exx_LRI_interface.h +++ b/source/module_ri/Exx_LRI_interface.h @@ -31,25 +31,51 @@ template class Exx_LRI_Interface { public: + using TA = int; using TC = std::array; - using TAC = std::pair; + using TAC = std::pair; /// @brief Constructor for Exx_LRI_Interface - /// @param exx_ptr - Exx_LRI_Interface(std::shared_ptr> exx_ptr) : exx_ptr(exx_ptr) {} + Exx_LRI_Interface(const Exx_Info::Exx_Info_RI& info) + { + this->exx_ptr = std::make_shared>(info); + } Exx_LRI_Interface() = delete; /// read and write Hexxs using cereal void write_Hexxs_cereal(const std::string& file_name) const; void read_Hexxs_cereal(const std::string& file_name); - std::vector>>>& get_Hexxs() const { return this->exx_ptr->Hexxs; } - - double& get_Eexx() const { return this->exx_ptr->Eexx; } + std::vector>>>& get_Hexxs() const { return this->exx_ptr->Hexxs; } + double &get_Eexx() const { return this->exx_ptr->Eexx; } + ModuleBase::matrix &get_force() const { return this->exx_ptr->force_exx; } + ModuleBase::matrix &get_stress() const { return this->exx_ptr->stress_exx; } + + // Processes in ESolver_KS_LCAO + /// @brief in init: Exx_LRI::init() + void init(const MPI_Comm &mpi_comm, + const UnitCell &ucell, + const K_Vectors &kv, + const LCAO_Orbitals& orb); + + /// @brief: in cal_exx_ions: Exx_LRI::cal_exx_ions() + void cal_exx_ions(const UnitCell& ucell, const bool write_cv = false); + + /// @brief: in cal_exx_elec: Exx_LRI::cal_exx_elec() + void cal_exx_elec(const std::vector>>>& Ds, + const UnitCell& ucell, + const Parallel_Orbitals& pv, + const ModuleSymmetry::Symmetry_rotation* p_symrot = nullptr); + + /// @brief: in cal_exx_force: Exx_LRI::cal_exx_force() + void cal_exx_force(const int& nat); + + /// @brief: in cal_exx_stress: Exx_LRI::cal_exx_stress() + void cal_exx_stress(const double& omega, const double& lat0); // Processes in ESolver_KS_LCAO /// @brief in before_all_runners: set symmetry according to irreducible k-points - /// since k-points are not reduced again after the variation of the cell and exx-symmetry must be consistent with k-points. + /// since k-points are not reduced again after the variation of the cell and exx-symmetry must be consistent with k-points. /// In the future, we will reduce k-points again during cell-relax, then this setting can be moved to `exx_beforescf`. void exx_before_all_runners(const K_Vectors& kv, const UnitCell& ucell, const Parallel_2D& pv); @@ -58,23 +84,23 @@ class Exx_LRI_Interface /// @brief in eachiterinit: do DM mixing and calculate Hexx when entering 2nd SCF void exx_eachiterinit(const int istep, - const UnitCell& ucell, + const UnitCell& ucell, const elecstate::DensityMatrix& dm/**< double should be Tdata if complex-PBE-DM is supported*/, - const K_Vectors& kv, + const K_Vectors& kv, const int& iter); /// @brief in hamilt2rho: calculate Hexx and Eexx void exx_hamilt2rho(elecstate::ElecState& elec, const Parallel_Orbitals& pv, const int iter); /// @brief in iter_finish: write Hexx, do something according to whether SCF is converged - void exx_iter_finish(const K_Vectors& kv, + void exx_iter_finish(const K_Vectors& kv, const UnitCell& ucell, - hamilt::Hamilt& hamilt, - elecstate::ElecState& elec, + hamilt::Hamilt& hamilt, + elecstate::ElecState& elec, Charge_Mixing& chgmix, - const double& scf_ene_thr, - int& iter, - const int istep, + const double& scf_ene_thr, + int& iter, + const int istep, bool& conv_esolver); /// @brief: in do_after_converge: add exx operators; do DM mixing if seperate loop bool exx_after_converge(const UnitCell& ucell, @@ -86,15 +112,28 @@ class Exx_LRI_Interface const int& istep, const double& etot, const double& scf_ene_thr); + int two_level_step = 0; double etot_last_outer_loop = 0.0; elecstate::DensityMatrix* dm_last_step; -private: + std::shared_ptr> exx_ptr; + +private: Mix_DMk_2D mix_DMk_2D; bool exx_spacegroup_symmetry = false; ModuleSymmetry::Symmetry_rotation symrot_; + + struct Flag_Finish + { + bool init = false; + bool ions = false; + bool elec = false; + bool force = false; + bool stress = false; + }; + Flag_Finish flag_finish; }; #include "Exx_LRI_interface.hpp" diff --git a/source/module_ri/Exx_LRI_interface.hpp b/source/module_ri/Exx_LRI_interface.hpp index 6226eaef12..8c1db5d513 100644 --- a/source/module_ri/Exx_LRI_interface.hpp +++ b/source/module_ri/Exx_LRI_interface.hpp @@ -10,18 +10,21 @@ #include "module_base/parallel_common.h" #include "module_base/formatter.h" -#include #include "module_io/csr_reader.h" #include "module_io/write_HS_sparse.h" #include "module_elecstate/elecstate_lcao.h" +#include +#include +#include + template void Exx_LRI_Interface::write_Hexxs_cereal(const std::string& file_name) const { ModuleBase::TITLE("Exx_LRI", "write_Hexxs_cereal"); ModuleBase::timer::tick("Exx_LRI", "write_Hexxs_cereal"); std::ofstream ofs(file_name + "_" + std::to_string(GlobalV::MY_RANK), std::ofstream::binary); - cereal::BinaryOutputArchive oar(ofs); + cereal::BinaryOutputArchive oar(ofs); oar(this->exx_ptr->Hexxs); ModuleBase::timer::tick("Exx_LRI", "write_Hexxs_cereal"); } @@ -32,14 +35,81 @@ void Exx_LRI_Interface::read_Hexxs_cereal(const std::string& file_name ModuleBase::TITLE("Exx_LRI", "read_Hexxs_cereal"); ModuleBase::timer::tick("Exx_LRI", "read_Hexxs_cereal"); std::ifstream ifs(file_name + "_" + std::to_string(GlobalV::MY_RANK), std::ofstream::binary); - cereal::BinaryInputArchive iar(ifs); - iar(this->exx_ptr->Hexxs); + cereal::BinaryInputArchive iar(ifs); + iar(this->exx_ptr->Hexxs); ModuleBase::timer::tick("Exx_LRI", "read_Hexxs_cereal"); } +template +void Exx_LRI_Interface::init(const MPI_Comm &mpi_comm, + const UnitCell &ucell, + const K_Vectors &kv, + const LCAO_Orbitals& orb) +{ + ModuleBase::TITLE("Exx_LRI_Interface","init"); + this->exx_ptr->init(mpi_comm, ucell, kv, orb); + this->flag_finish.init = true; +} + +template +void Exx_LRI_Interface::cal_exx_ions(const UnitCell& ucell, const bool write_cv) +{ + ModuleBase::TITLE("Exx_LRI_Interface","cal_exx_ions"); + if(!this->flag_finish.init) + { throw std::runtime_error("Exx init unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); } + + this->exx_ptr->cal_exx_ions(ucell, write_cv); + + this->flag_finish.ions = true; +} + +template +void Exx_LRI_Interface::cal_exx_elec(const std::vector>>>& Ds, + const UnitCell& ucell, + const Parallel_Orbitals& pv, + const ModuleSymmetry::Symmetry_rotation* p_symrot) +{ + ModuleBase::TITLE("Exx_LRI_Interface","cal_exx_elec"); + if(!this->flag_finish.init || !this->flag_finish.ions) + { throw std::runtime_error("Exx init unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); } + + this->exx_ptr->cal_exx_elec(Ds, ucell, pv, p_symrot); + + this->flag_finish.elec = true; +} + +template +void Exx_LRI_Interface::cal_exx_force(const int& nat) +{ + ModuleBase::TITLE("Exx_LRI_Interface","cal_exx_force"); + if(!this->flag_finish.init || !this->flag_finish.ions) + { throw std::runtime_error("Exx init unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); } + if(!this->flag_finish.elec) + { throw std::runtime_error("Exx Hamiltonian unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); } + + this->exx_ptr->cal_exx_force(nat); + + this->flag_finish.force = true; +} + +template +void Exx_LRI_Interface::cal_exx_stress(const double& omega, const double& lat0) +{ + ModuleBase::TITLE("Exx_LRI_Interface","cal_exx_stress"); + if(!this->flag_finish.init || !this->flag_finish.ions) + { throw std::runtime_error("Exx init unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); } + if(!this->flag_finish.elec) + { throw std::runtime_error("Exx Hamiltonian unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__)); } + + this->exx_ptr->cal_exx_stress(omega, lat0); + + this->flag_finish.stress = true; +} + template void Exx_LRI_Interface::exx_before_all_runners(const K_Vectors& kv, const UnitCell& ucell, const Parallel_2D& pv) { + ModuleBase::TITLE("Exx_LRI_Interface","exx_before_all_runners"); // initialize the rotation matrix in AO representation this->exx_spacegroup_symmetry = (PARAM.inp.nspin < 4 && ModuleSymmetry::Symmetry::symm_flag == 1); if (this->exx_spacegroup_symmetry) @@ -53,12 +123,13 @@ void Exx_LRI_Interface::exx_before_all_runners(const K_Vectors& kv, co } template -void Exx_LRI_Interface::exx_beforescf(const int istep, - const K_Vectors& kv, - const Charge_Mixing& chgmix, - const UnitCell& ucell, +void Exx_LRI_Interface::exx_beforescf(const int istep, + const K_Vectors& kv, + const Charge_Mixing& chgmix, + const UnitCell& ucell, const LCAO_Orbitals& orb) { + ModuleBase::TITLE("Exx_LRI_Interface","exx_beforescf"); #ifdef __MPI if (GlobalC::exx_info.info_global.cal_exx) { @@ -79,49 +150,51 @@ void Exx_LRI_Interface::exx_beforescf(const int istep, XC_Functional::set_xc_type("scan"); } // added by jghan, 2024-07-07 - else if ( ucell.atoms[0].ncpp.xc_func == "MULLER" || ucell.atoms[0].ncpp.xc_func == "POWER" + else if ( ucell.atoms[0].ncpp.xc_func == "MULLER" || ucell.atoms[0].ncpp.xc_func == "POWER" || ucell.atoms[0].ncpp.xc_func == "WP22" || ucell.atoms[0].ncpp.xc_func == "CWP22" ) { XC_Functional::set_xc_type("pbe"); } } - this->exx_ptr->cal_exx_ions(ucell,PARAM.inp.out_ri_cv); + + this->cal_exx_ions(ucell,PARAM.inp.out_ri_cv); } - if (Exx_Abfs::Jle::generate_matrix) - { - //program should be stopped after this judgement - Exx_Opt_Orb exx_opt_orb; - exx_opt_orb.generate_matrix(kv, ucell,orb); - ModuleBase::timer::tick("ESolver_KS_LCAO", "beforescf"); - return; - } - - // set initial parameter for mix_DMk_2D - if(GlobalC::exx_info.info_global.cal_exx) - { - if (this->exx_spacegroup_symmetry) - {this->mix_DMk_2D.set_nks(kv.get_nkstot_full() * (PARAM.inp.nspin == 2 ? 2 : 1), PARAM.globalv.gamma_only_local);} - else - {this->mix_DMk_2D.set_nks(kv.get_nks(), PARAM.globalv.gamma_only_local);} - if(GlobalC::exx_info.info_global.separate_loop) { - this->mix_DMk_2D.set_mixing(nullptr); - } else { - this->mix_DMk_2D.set_mixing(chgmix.get_mixing()); - } - // for exx two_level scf - this->two_level_step = 0; - } + if (Exx_Abfs::Jle::generate_matrix) + { + //program should be stopped after this judgement + Exx_Opt_Orb exx_opt_orb; + exx_opt_orb.generate_matrix(kv, ucell,orb); + ModuleBase::timer::tick("ESolver_KS_LCAO", "beforescf"); + return; + } + + // set initial parameter for mix_DMk_2D + if(GlobalC::exx_info.info_global.cal_exx) + { + if (this->exx_spacegroup_symmetry) + {this->mix_DMk_2D.set_nks(kv.get_nkstot_full() * (PARAM.inp.nspin == 2 ? 2 : 1), PARAM.globalv.gamma_only_local);} + else + {this->mix_DMk_2D.set_nks(kv.get_nks(), PARAM.globalv.gamma_only_local);} + + if(GlobalC::exx_info.info_global.separate_loop) + { this->mix_DMk_2D.set_mixing(nullptr); } + else + { this->mix_DMk_2D.set_mixing(chgmix.get_mixing()); } + // for exx two_level scf + this->two_level_step = 0; + } #endif // __MPI } template -void Exx_LRI_Interface::exx_eachiterinit(const int istep, - const UnitCell& ucell, - const elecstate::DensityMatrix& dm, - const K_Vectors& kv, +void Exx_LRI_Interface::exx_eachiterinit(const int istep, + const UnitCell& ucell, + const elecstate::DensityMatrix& dm, + const K_Vectors& kv, const int& iter) { + ModuleBase::TITLE("Exx_LRI_Interface","exx_eachiterinit"); if (GlobalC::exx_info.info_global.cal_exx) { if (!GlobalC::exx_info.info_global.separate_loop @@ -134,61 +207,48 @@ void Exx_LRI_Interface::exx_eachiterinit(const int istep, && iter == 1) ) // the first iter in separate loop case { - std::cout << " UPDATE EXX" << std::endl; - const bool flag_restart = (iter == 1) ? true : false; auto cal = [this, &ucell,&kv, &flag_restart](const elecstate::DensityMatrix& dm_in) { - if (this->exx_spacegroup_symmetry) - { - this->mix_DMk_2D.mix(symrot_.restore_dm(kv,dm_in.get_DMK_vector(), - *dm_in.get_paraV_pointer()), - flag_restart); - } - else - { - this->mix_DMk_2D.mix(dm_in.get_DMK_vector(), flag_restart); - } - - const std::vector>,RI::Tensor>>> - Ds = PARAM.globalv.gamma_only_local - ? RI_2D_Comm::split_m2D_ktoR(ucell, - *this->exx_ptr->p_kv, - this->mix_DMk_2D.get_DMk_gamma_out(), - *dm_in.get_paraV_pointer(), + if (this->exx_spacegroup_symmetry) + { this->mix_DMk_2D.mix(symrot_.restore_dm(kv,dm_in.get_DMK_vector(), *dm_in.get_paraV_pointer()), flag_restart); } + else + { this->mix_DMk_2D.mix(dm_in.get_DMK_vector(), flag_restart); } + const std::vector>>> + Ds = PARAM.globalv.gamma_only_local + ? RI_2D_Comm::split_m2D_ktoR( + ucell, + *this->exx_ptr->p_kv, + this->mix_DMk_2D.get_DMk_gamma_out(), + *dm_in.get_paraV_pointer(), PARAM.inp.nspin) - : RI_2D_Comm::split_m2D_ktoR(ucell, - *this->exx_ptr->p_kv, - this->mix_DMk_2D.get_DMk_k_out(), - *dm_in.get_paraV_pointer(), - PARAM.inp.nspin, + : RI_2D_Comm::split_m2D_ktoR( + ucell, + *this->exx_ptr->p_kv, + this->mix_DMk_2D.get_DMk_k_out(), + *dm_in.get_paraV_pointer(), + PARAM.inp.nspin, this->exx_spacegroup_symmetry); - if (this->exx_spacegroup_symmetry && GlobalC::exx_info.info_global.exx_symmetry_realspace) - { - this->exx_ptr->cal_exx_elec(Ds, ucell,*dm_in.get_paraV_pointer(), &this->symrot_); - } - else - { - this->exx_ptr->cal_exx_elec(Ds, ucell,*dm_in.get_paraV_pointer()); - } + if (this->exx_spacegroup_symmetry && GlobalC::exx_info.info_global.exx_symmetry_realspace) + { this->cal_exx_elec(Ds, ucell,*dm_in.get_paraV_pointer(), &this->symrot_); } + else + { this->cal_exx_elec(Ds, ucell,*dm_in.get_paraV_pointer()); } }; - if(istep > 0 && flag_restart) - { - cal(*dm_last_step); - } - else - { - cal(dm); - } - } + + if(istep > 0 && flag_restart) + { cal(*dm_last_step); } + else + { cal(dm); } + } } } template void Exx_LRI_Interface::exx_hamilt2rho(elecstate::ElecState& elec, const Parallel_Orbitals& pv, const int iter) { + ModuleBase::TITLE("Exx_LRI_Interface","exx_hamilt2density"); // Peize Lin add 2020.04.04 if (XC_Functional::get_func_type() == 4 || XC_Functional::get_func_type() == 5) { @@ -217,16 +277,17 @@ void Exx_LRI_Interface::exx_hamilt2rho(elecstate::ElecState& elec, con } template -void Exx_LRI_Interface::exx_iter_finish(const K_Vectors& kv, +void Exx_LRI_Interface::exx_iter_finish(const K_Vectors& kv, const UnitCell& ucell, - hamilt::Hamilt& hamilt, - elecstate::ElecState& elec, + hamilt::Hamilt& hamilt, + elecstate::ElecState& elec, Charge_Mixing& chgmix, - const double& scf_ene_thr, - int& iter, - const int istep, + const double& scf_ene_thr, + int& iter, + const int istep, bool& conv_esolver) { + ModuleBase::TITLE("Exx_LRI_Interface","exx_iter_finish"); if (GlobalC::restart.info_save.save_H && (this->two_level_step > 0 || istep > 0) && (!GlobalC::exx_info.info_global.separate_loop || iter == 1)) // to avoid saving the same value repeatedly { @@ -294,124 +355,91 @@ bool Exx_LRI_Interface::exx_after_converge( const double& etot, const double& scf_ene_thr) { // only called if (GlobalC::exx_info.info_global.cal_exx) + ModuleBase::TITLE("Exx_LRI_Interface","exx_after_converge"); auto restart_reset = [this]() - { // avoid calling restart related procedure in the subsequent ion steps - GlobalC::restart.info_load.restart_exx = true; - this->exx_ptr->Eexx = 0; - }; - - // no separate_loop case - if (!GlobalC::exx_info.info_global.separate_loop) - { - GlobalC::exx_info.info_global.hybrid_step = 1; + { // avoid calling restart related procedure in the subsequent ion steps + GlobalC::restart.info_load.restart_exx = true; + this->exx_ptr->Eexx = 0; + }; - // in no_separate_loop case, scf loop only did twice - // in first scf loop, exx updated once in beginning, - // in second scf loop, exx updated every iter + // no separate_loop case + if (!GlobalC::exx_info.info_global.separate_loop) + { + GlobalC::exx_info.info_global.hybrid_step = 1; - if (this->two_level_step || istep > 0) - { - restart_reset(); - return true; - } - else - { - // update exx and redo scf - XC_Functional::set_xc_type(ucell.atoms[0].ncpp.xc_func); - iter = 0; - std::cout << " Entering 2nd SCF, where EXX is updated" << std::endl; - this->two_level_step++; - return false; - } + // in no_separate_loop case, scf loop only did twice + // in first scf loop, exx updated once in beginning, + // in second scf loop, exx updated every iter + + if (this->two_level_step || istep > 0) + { + restart_reset(); + return true; } else - { // has separate_loop case - const double ediff = std::abs(etot - etot_last_outer_loop) * ModuleBase::Ry_to_eV; - if (two_level_step) - { - std::cout << FmtCore::format(" deltaE (eV) from outer loop: %.8e \n", ediff); - } - // exx converged or get max exx steps - if (this->two_level_step == GlobalC::exx_info.info_global.hybrid_step - || (iter == 1 && this->two_level_step != 0) // density convergence of outer loop - || (ediff < scf_ene_thr && this->two_level_step != 0)) //energy convergence of outer loop - { - restart_reset(); - return true; - } - else - { - this->etot_last_outer_loop = etot; - // update exx and redo scf - if (this->two_level_step == 0) - { - XC_Functional::set_xc_type(ucell.atoms[0].ncpp.xc_func); - } - - std::cout << " Updating EXX " << std::flush; - timeval t_start; gettimeofday(&t_start, nullptr); - - // if init_wfc == "file", DM is calculated in the 1st iter of the 1st two-level step, so we mix it here - const bool flag_restart = (this->two_level_step == 0 && PARAM.inp.init_wfc != "file") ? true : false; + { + // update exx and redo scf + XC_Functional::set_xc_type(ucell.atoms[0].ncpp.xc_func); + iter = 0; + std::cout << " Entering 2nd SCF, where EXX is updated" << std::endl; + this->two_level_step++; + return false; + } + } + else + { // has separate_loop case + const double ediff = std::abs(etot - etot_last_outer_loop) * ModuleBase::Ry_to_eV; + if (two_level_step) + { std::cout << FmtCore::format(" deltaE (eV) from outer loop: %.8e \n", ediff); } + // exx converged or get max exx steps + if (this->two_level_step == GlobalC::exx_info.info_global.hybrid_step + || (iter == 1 && this->two_level_step != 0) // density convergence of outer loop + || (ediff < scf_ene_thr && this->two_level_step != 0)) //energy convergence of outer loop + { + restart_reset(); + return true; + } + else + { + this->etot_last_outer_loop = etot; + // update exx and redo scf + if (this->two_level_step == 0) + { XC_Functional::set_xc_type(ucell.atoms[0].ncpp.xc_func); } - if (this->exx_spacegroup_symmetry) - {this->mix_DMk_2D.mix(symrot_.restore_dm(kv, dm.get_DMK_vector(), *dm.get_paraV_pointer()), flag_restart);} - else - {this->mix_DMk_2D.mix(dm.get_DMK_vector(), flag_restart);} + std::cout << " Updating EXX " << std::flush; + timeval t_start; gettimeofday(&t_start, nullptr); - // GlobalC::exx_lcao.cal_exx_elec(p_esolver->LOC, p_esolver->LOWF.wfc_k_grid); - const std::vector>, RI::Tensor>>> - Ds = std::is_same::value //gamma_only_local - ? RI_2D_Comm::split_m2D_ktoR(ucell,*this->exx_ptr->p_kv, this->mix_DMk_2D.get_DMk_gamma_out(), *dm.get_paraV_pointer(), nspin) - : RI_2D_Comm::split_m2D_ktoR(ucell,*this->exx_ptr->p_kv, this->mix_DMk_2D.get_DMk_k_out(), *dm.get_paraV_pointer(), nspin, this->exx_spacegroup_symmetry); + // if init_wfc == "file", DM is calculated in the 1st iter of the 1st two-level step, so we mix it here + const bool flag_restart = (this->two_level_step == 0 && PARAM.inp.init_wfc != "file") ? true : false; - // check the rotation of Ds - // this->symrot_.test_HR_rotation(ucell.symm, ucell.atoms, ucell.st, 'D', Ds[0]); + if (this->exx_spacegroup_symmetry) + {this->mix_DMk_2D.mix(symrot_.restore_dm(kv, dm.get_DMK_vector(), *dm.get_paraV_pointer()), flag_restart);} + else + {this->mix_DMk_2D.mix(dm.get_DMK_vector(), flag_restart);} - // check the rotation of H(R) before adding exx - // this->symrot_.find_irreducible_sector(ucell.symm, ucell.atoms, ucell.st, this->symrot_.get_Rs_from_adjacent_list(ucell, GlobalC::GridD, *lm.ParaV)); - // this->symrot_.test_HR_rotation(ucell.symm, ucell.atoms, ucell.st, 'H', *(dynamic_cast*>(&hamilt)->getHR())); - // exit(0); + // GlobalC::exx_lcao.cal_exx_elec(p_esolver->LOC, p_esolver->LOWF.wfc_k_grid); + const std::vector>, RI::Tensor>>> + Ds = std::is_same::value //gamma_only_local + ? RI_2D_Comm::split_m2D_ktoR(ucell,*this->exx_ptr->p_kv, this->mix_DMk_2D.get_DMk_gamma_out(), *dm.get_paraV_pointer(), nspin) + : RI_2D_Comm::split_m2D_ktoR(ucell,*this->exx_ptr->p_kv, this->mix_DMk_2D.get_DMk_k_out(), *dm.get_paraV_pointer(), nspin, this->exx_spacegroup_symmetry); if (this->exx_spacegroup_symmetry && GlobalC::exx_info.info_global.exx_symmetry_realspace) - { - this->exx_ptr->cal_exx_elec(Ds, ucell, *dm.get_paraV_pointer(), &this->symrot_); - // this->symrot_.print_HR(this->exx_ptr->Hexxs[0], "Hexxs_irreducible"); // test - // this->symrot_.print_HR(this->exx_ptr->Hexxs[0], "Hexxs_restored", 1e-10); // test - } + { this->cal_exx_elec(Ds, ucell, *dm.get_paraV_pointer(), &this->symrot_); } else - { - this->exx_ptr->cal_exx_elec(Ds, ucell, *dm.get_paraV_pointer()); // restore DM but not Hexx - // this->symrot_.print_HR(this->exx_ptr->Hexxs[0], "Hexxs_restore-DM-only"); // test - // this->symrot_.print_HR(this->exx_ptr->Hexxs[0], "Hexxs_ref"); // test - } - // ======================== test ======================== - // if (this->two_level_step)exit(0); - // check the rotation of S(R) - // this->symrot_.find_irreducible_sector(ucell.symm, ucell.atoms, ucell.st, this->symrot_.get_Rs_from_adjacent_list(ucell, GlobalC::GridD, *lm.ParaV)); - // this->symrot_.test_HR_rotation(ucell.symm, ucell.atoms, ucell.st, 'H', *(dynamic_cast*>(&hamilt)->getSR())); - - // check the rotation of D(R): no atom pair? - // symrot_.find_irreducible_sector(ucell.symm, ucell.atoms, ucell.st, symrot_.get_Rs_from_adjacent_list(ucell, GlobalC::GridD, *this->DM->get_paraV_pointer())); - // symrot_.test_HR_rotation(ucell.symm, ucell.atoms, ucell.st, 'D', *(this->DM->get_DMR_pointer(0))); - - // check the rotation of Hexx - // this->symrot_.test_HR_rotation(ucell.symm, ucell.atoms, ucell.st, 'H', this->exx_ptr->Hexxs[0]); - // exit(0);// break after test - // ======================== test ======================== - iter = 0; - this->two_level_step++; - - timeval t_end; gettimeofday(&t_end, nullptr); - std::cout << "and rerun SCF\t" - << std::defaultfloat - << std::setprecision(3) << std::setiosflags(std::ios::scientific) - << (double)(t_end.tv_sec-t_start.tv_sec) + (double)(t_end.tv_usec-t_start.tv_usec)/1000000.0 - << std::defaultfloat << " (s)" << std::endl; + { this->cal_exx_elec(Ds, ucell, *dm.get_paraV_pointer()); } // restore DM but not Hexx + iter = 0; + this->two_level_step++; + + timeval t_end; gettimeofday(&t_end, nullptr); + std::cout << "and rerun SCF\t" + << std::setprecision(3) << std::setiosflags(std::ios::scientific) + << (double)(t_end.tv_sec-t_start.tv_sec) + (double)(t_end.tv_usec-t_start.tv_usec)/1000000.0 + << std::defaultfloat << " (s)" << std::endl; return false; - } } + } // if(GlobalC::exx_info.info_global.separate_loop) restart_reset(); return true; } + #endif diff --git a/source/module_ri/Mix_Matrix.cpp b/source/module_ri/Mix_Matrix.cpp index e474511d53..21d5b5e3c2 100644 --- a/source/module_ri/Mix_Matrix.cpp +++ b/source/module_ri/Mix_Matrix.cpp @@ -14,7 +14,6 @@ template<> void Mix_Matrix::mix(const ModuleBase::matrix& data_in, const bool flag_restart) { - ModuleBase::TITLE("Mix_Matrix","mix"); if(separate_loop) { this->mixing = new Base_Mixing::Plain_Mixing(this->mixing_beta); @@ -40,7 +39,6 @@ void Mix_Matrix::mix(const ModuleBase::matrix& data_in, cons template<> void Mix_Matrix::mix(const ModuleBase::ComplexMatrix& data_in, const bool flag_restart) { - ModuleBase::TITLE("Mix_Matrix", "mix"); if (separate_loop) { this->mixing = new Base_Mixing::Plain_Mixing(this->mixing_beta); @@ -67,7 +65,6 @@ void Mix_Matrix::mix(const ModuleBase::ComplexMatrix& template<> void Mix_Matrix>::mix(const std::vector& data_in, const bool flag_restart) { - ModuleBase::TITLE("Mix_Matrix", "mix"); if (separate_loop) { this->mixing = new Base_Mixing::Plain_Mixing(this->mixing_beta); @@ -93,7 +90,6 @@ void Mix_Matrix>::mix(const std::vector& data_in, co template<> void Mix_Matrix>>::mix(const std::vector>& data_in, const bool flag_restart) { - ModuleBase::TITLE("Mix_Matrix", "mix"); if (separate_loop) { this->mixing = new Base_Mixing::Plain_Mixing(this->mixing_beta);