diff --git a/source/Makefile.Objects b/source/Makefile.Objects index d1f07e4b0c..6fd3eeedf9 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -222,6 +222,7 @@ OBJS_ELECSTAT=elecstate.o\ elecstate_pw_sdft.o\ elecstate_pw_cal_tau.o\ elecstate_op.o\ + elecstate_tools.o\ efield.o\ gatefield.o\ potential_new.o\ diff --git a/source/module_elecstate/CMakeLists.txt b/source/module_elecstate/CMakeLists.txt index b9f55f34c2..d2c8d32a48 100644 --- a/source/module_elecstate/CMakeLists.txt +++ b/source/module_elecstate/CMakeLists.txt @@ -4,6 +4,7 @@ list(APPEND objects elecstate_energy.cpp elecstate_exx.cpp elecstate_print.cpp + elecstate_tools.cpp elecstate_pw.cpp elecstate_pw_sdft.cpp elecstate_pw_cal_tau.cpp diff --git a/source/module_elecstate/elecstate.cpp b/source/module_elecstate/elecstate.cpp index 4c06c99213..2d928660a2 100644 --- a/source/module_elecstate/elecstate.cpp +++ b/source/module_elecstate/elecstate.cpp @@ -15,43 +15,6 @@ const double* ElecState::getRho(int spin) const return &(this->charge->rho[spin][0]); } -void ElecState::fixed_weights(const std::vector& ocp_kb, const int& nbands, const double& nelec) -{ - assert(nbands > 0); - assert(nelec > 0.0); - - const double ne_thr = 1.0e-5; - - const int num = this->klist->get_nks() * nbands; - if (num != ocp_kb.size()) - { - ModuleBase::WARNING_QUIT("ElecState::fixed_weights", - "size of occupation array is wrong , please check ocp_set"); - } - - double num_elec = 0.0; - for (int i = 0; i < ocp_kb.size(); ++i) - { - num_elec += ocp_kb[i]; - } - - if (std::abs(num_elec - nelec) > ne_thr) - { - ModuleBase::WARNING_QUIT("ElecState::fixed_weights", - "total number of occupations is wrong , please check ocp_set"); - } - - for (int ik = 0; ik < this->wg.nr; ++ik) - { - for (int ib = 0; ib < this->wg.nc; ++ib) - { - this->wg(ik, ib) = ocp_kb[ik * this->wg.nc + ib]; - } - } - this->skip_weights = true; - - return; -} void ElecState::init_nelec_spin() @@ -64,143 +27,6 @@ void ElecState::init_nelec_spin() } } - -void ElecState::calculate_weights() -{ - ModuleBase::TITLE("ElecState", "calculate_weights"); - if (this->skip_weights) - { - return; - } - - const int nbands = this->ekb.nc; - const int nks = this->ekb.nr; - - if (!Occupy::use_gaussian_broadening && !Occupy::fixed_occupations) - { - if (PARAM.globalv.two_fermi) - { - Occupy::iweights(nks, - this->klist->wk, - nbands, - this->nelec_spin[0], - this->ekb, - this->eferm.ef_up, - this->wg, - 0, - this->klist->isk); - Occupy::iweights(nks, - this->klist->wk, - nbands, - this->nelec_spin[1], - this->ekb, - this->eferm.ef_dw, - this->wg, - 1, - this->klist->isk); - // ef = ( ef_up + ef_dw ) / 2.0_dp need??? mohan add 2012-04-16 - } - else - { - // -1 means don't need to consider spin. - Occupy::iweights(nks, - this->klist->wk, - nbands, - PARAM.inp.nelec, - this->ekb, - this->eferm.ef, - this->wg, - -1, - this->klist->isk); - } - } - else if (Occupy::use_gaussian_broadening) - { - if (PARAM.globalv.two_fermi) - { - double demet_up = 0.0; - double demet_dw = 0.0; - Occupy::gweights(nks, - this->klist->wk, - nbands, - this->nelec_spin[0], - Occupy::gaussian_parameter, - Occupy::gaussian_type, - this->ekb, - this->eferm.ef_up, - demet_up, - this->wg, - 0, - this->klist->isk); - Occupy::gweights(nks, - this->klist->wk, - nbands, - this->nelec_spin[1], - Occupy::gaussian_parameter, - Occupy::gaussian_type, - this->ekb, - this->eferm.ef_dw, - demet_dw, - this->wg, - 1, - this->klist->isk); - this->f_en.demet = demet_up + demet_dw; - } - else - { - // -1 means is no related to spin. - Occupy::gweights(nks, - this->klist->wk, - nbands, - PARAM.inp.nelec, - Occupy::gaussian_parameter, - Occupy::gaussian_type, - this->ekb, - this->eferm.ef, - this->f_en.demet, - this->wg, - -1, - this->klist->isk); - } -#ifdef __MPI - const int npool = GlobalV::KPAR * PARAM.inp.bndpar; - Parallel_Reduce::reduce_double_allpool(npool, GlobalV::NPROC_IN_POOL, this->f_en.demet); -#endif - } - else if (Occupy::fixed_occupations) - { - ModuleBase::WARNING_QUIT("calculate_weights", "other occupations, not implemented"); - } - - return; -} - - -void ElecState::calEBand() -{ - ModuleBase::TITLE("ElecState", "calEBand"); - // calculate ebands using wg and ekb - double eband = 0.0; -#ifdef _OPENMP -#pragma omp parallel for collapse(2) reduction(+ : eband) -#endif - for (int ik = 0; ik < this->ekb.nr; ++ik) - { - for (int ibnd = 0; ibnd < this->ekb.nc; ibnd++) - { - eband += this->ekb(ik, ibnd) * this->wg(ik, ibnd); - } - } - this->f_en.eband = eband; - -#ifdef __MPI - const int npool = GlobalV::KPAR * PARAM.inp.bndpar; - Parallel_Reduce::reduce_double_allpool(npool, GlobalV::NPROC_IN_POOL, this->f_en.eband); -#endif - return; -} - - void ElecState::init_scf(const int istep, const UnitCell& ucell, const Parallel_Grid& pgrid, diff --git a/source/module_elecstate/elecstate.h b/source/module_elecstate/elecstate.h index f4954a8d58..3d99ca4910 100644 --- a/source/module_elecstate/elecstate.h +++ b/source/module_elecstate/elecstate.h @@ -72,12 +72,11 @@ class ElecState return; } - // calculate wg from ekb - virtual void calculate_weights(); + // use occupied weights from INPUT and skip calculate_weights // mohan updated on 2024-06-08 - void fixed_weights(const std::vector& ocp_kb, const int& nbands, const double& nelec); + // if nupdown is not 0(TWO_EFERMI case), // nelec_spin will be fixed and weights will be constrained @@ -167,11 +166,11 @@ class ElecState ModuleBase::matrix wg; ///< occupation weight for each k-point and band public: - // calculate ebands for all k points and all occupied bands - void calEBand(); bool skip_weights = false; }; + + } // namespace elecstate #endif diff --git a/source/module_elecstate/elecstate_tools.cpp b/source/module_elecstate/elecstate_tools.cpp new file mode 100644 index 0000000000..446007a296 --- /dev/null +++ b/source/module_elecstate/elecstate_tools.cpp @@ -0,0 +1,182 @@ +#include "elecstate_tools.h" +#include "occupy.h" +namespace elecstate +{ + void calEBand(const ModuleBase::matrix& ekb,const ModuleBase::matrix& wg,fenergy& f_en) + { + ModuleBase::TITLE("ElecState", "calEBand"); + // calculate ebands using wg and ekb + double eband = 0.0; + #ifdef _OPENMP + #pragma omp parallel for collapse(2) reduction(+ : eband) + #endif + for (int ik = 0; ik < ekb.nr; ++ik) + { + for (int ibnd = 0; ibnd < ekb.nc; ibnd++) + { + eband += ekb(ik, ibnd) * wg(ik, ibnd); + } + } + f_en.eband = eband; + + #ifdef __MPI + const int npool = GlobalV::KPAR * PARAM.inp.bndpar; + Parallel_Reduce::reduce_double_allpool(npool, GlobalV::NPROC_IN_POOL, f_en.eband); + #endif + return; + } + + void calculate_weights(const ModuleBase::matrix& ekb, + ModuleBase::matrix& wg, + const K_Vectors* klist, + efermi& eferm, + fenergy& f_en, + std::vector& nelec_spin, + const bool skip_weights=false) + { + ModuleBase::TITLE("ElecState", "calculate_weights"); + if (skip_weights==true) return; + + const int nbands = ekb.nc; + const int nks = ekb.nr; + if (!(Occupy::use_gaussian_broadening || Occupy::fixed_occupations)) + { + if (PARAM.globalv.two_fermi) + { + Occupy::iweights(nks, + klist->wk, + nbands, + nelec_spin[0], + ekb, + eferm.ef_up, + wg, + 0, + klist->isk); + Occupy::iweights(nks, + klist->wk, + nbands, + nelec_spin[1], + ekb, + eferm.ef_dw, + wg, + 1, + klist->isk); + // ef = ( ef_up + ef_dw ) / 2.0_dp need??? mohan add 2012-04-16 + } + else + { + // -1 means don't need to consider spin. + Occupy::iweights(nks, + klist->wk, + nbands, + PARAM.inp.nelec, + ekb, + eferm.ef, + wg, + -1, + klist->isk); + } + } + else if (Occupy::use_gaussian_broadening) + { + if (PARAM.globalv.two_fermi) + { + double demet_up = 0.0; + double demet_dw = 0.0; + Occupy::gweights(nks, + klist->wk, + nbands, + nelec_spin[0], + Occupy::gaussian_parameter, + Occupy::gaussian_type, + ekb, + eferm.ef_up, + demet_up, + wg, + 0, + klist->isk); + Occupy::gweights(nks, + klist->wk, + nbands, + nelec_spin[1], + Occupy::gaussian_parameter, + Occupy::gaussian_type, + ekb, + eferm.ef_dw, + demet_dw, + wg, + 1, + klist->isk); + f_en.demet = demet_up + demet_dw; + } + else + { + // -1 means is no related to spin. + Occupy::gweights(nks, + klist->wk, + nbands, + PARAM.inp.nelec, + Occupy::gaussian_parameter, + Occupy::gaussian_type, + ekb, + eferm.ef, + f_en.demet, + wg, + -1, + klist->isk); + } + #ifdef __MPI + const int npool = GlobalV::KPAR * PARAM.inp.bndpar; + Parallel_Reduce::reduce_double_allpool(npool, GlobalV::NPROC_IN_POOL, f_en.demet); + #endif + } + else if (Occupy::fixed_occupations) + { + ModuleBase::WARNING_QUIT("calculate_weights", "other occupations, not implemented"); + } + + return; + } + + void fixed_weights(const std::vector& ocp_kb, + const int& nbands, + const double& nelec, + const K_Vectors* klist, + ModuleBase::matrix& wg, + bool& skip_weights) + { + assert(nbands > 0); + assert(nelec > 0.0); + + const double ne_thr = 1.0e-5; + + const int num = klist->get_nks() * nbands; + if (num != ocp_kb.size()) + { + ModuleBase::WARNING_QUIT("ElecState::fixed_weights", + "size of occupation array is wrong , please check ocp_set"); + } + + double num_elec = 0.0; + for (int i = 0; i < ocp_kb.size(); ++i) + { + num_elec += ocp_kb[i]; + } + + if (std::abs(num_elec - nelec) > ne_thr) + { + ModuleBase::WARNING_QUIT("ElecState::fixed_weights", + "total number of occupations is wrong , please check ocp_set"); + } + + for (int ik = 0; ik < wg.nr; ++ik) + { + for (int ib = 0; ib < wg.nc; ++ib) + { + wg(ik, ib) = ocp_kb[ik * wg.nc + ib]; + } + } + skip_weights = true; + + } +} \ No newline at end of file diff --git a/source/module_elecstate/elecstate_tools.h b/source/module_elecstate/elecstate_tools.h new file mode 100644 index 0000000000..884217f778 --- /dev/null +++ b/source/module_elecstate/elecstate_tools.h @@ -0,0 +1,26 @@ +#ifndef ELECSTATE_TOOLS_H +#define ELECSTATE_TOOLS_H +#include "elecstate.h" +#include "module_base/matrix.h" + +namespace elecstate +{ +void calEBand(const ModuleBase::matrix& ekb, const ModuleBase::matrix& wg, fenergy& f_en); + +void calculate_weights(const ModuleBase::matrix& ekb, + ModuleBase::matrix& wg, + const K_Vectors* klist, + efermi& eferm, + fenergy& f_en, + std::vector& nelec_spin, + const bool skip_weights); + +void fixed_weights(const std::vector& ocp_kb, + const int& nbands, + const double& nelec, + const K_Vectors* klist, + ModuleBase::matrix& wg, + bool& skip_weights); +} // namespace elecstate + +#endif \ No newline at end of file diff --git a/source/module_elecstate/test/CMakeLists.txt b/source/module_elecstate/test/CMakeLists.txt index 6a10494d4d..90d7ed0155 100644 --- a/source/module_elecstate/test/CMakeLists.txt +++ b/source/module_elecstate/test/CMakeLists.txt @@ -45,7 +45,7 @@ AddTest( AddTest( TARGET elecstate_base LIBS parameter ${math_libs} base device - SOURCES elecstate_base_test.cpp ../elecstate.cpp ../occupy.cpp ../../module_psi/psi.cpp + SOURCES elecstate_base_test.cpp ../elecstate.cpp ../elecstate_tools.cpp ../occupy.cpp ../../module_psi/psi.cpp ) AddTest( diff --git a/source/module_elecstate/test/elecstate_base_test.cpp b/source/module_elecstate/test/elecstate_base_test.cpp index 9a8cd34d66..dfe9aab345 100644 --- a/source/module_elecstate/test/elecstate_base_test.cpp +++ b/source/module_elecstate/test/elecstate_base_test.cpp @@ -1,12 +1,12 @@ -#include - #include "gmock/gmock.h" #include "gtest/gtest.h" +#include #define private public #define protected public -#include "module_parameter/parameter.h" #include "module_elecstate/elecstate.h" +#include "module_elecstate/elecstate_tools.h" #include "module_elecstate/occupy.h" +#include "module_parameter/parameter.h" #undef protected #undef private @@ -32,14 +32,26 @@ Charge::Charge() Charge::~Charge() { } -UnitCell::UnitCell(){} -UnitCell::~UnitCell(){} -Parallel_Grid::Parallel_Grid(){}; -Parallel_Grid::~Parallel_Grid(){}; -Magnetism::Magnetism(){} -Magnetism::~Magnetism(){} -InfoNonlocal::InfoNonlocal(){} -InfoNonlocal::~InfoNonlocal(){} +UnitCell::UnitCell() +{ +} +UnitCell::~UnitCell() +{ +} +Parallel_Grid::Parallel_Grid() {}; +Parallel_Grid::~Parallel_Grid() {}; +Magnetism::Magnetism() +{ +} +Magnetism::~Magnetism() +{ +} +InfoNonlocal::InfoNonlocal() +{ +} +InfoNonlocal::~InfoNonlocal() +{ +} #include "module_cell/klist.h" ModulePW::PW_Basis::PW_Basis() @@ -51,7 +63,7 @@ ModulePW::PW_Basis::~PW_Basis() ModulePW::PW_Basis_Sup::~PW_Basis_Sup() { } -ModulePW::FFT_Bundle::~FFT_Bundle(){}; +ModulePW::FFT_Bundle::~FFT_Bundle() {}; void ModulePW::PW_Basis::initgrids(double, ModuleBase::Matrix3, double) { } @@ -125,7 +137,7 @@ class MockElecState : public ElecState { PARAM.input.nspin = 1; PARAM.input.nelec = 10.0; - PARAM.input.nupdown = 0.0; + PARAM.input.nupdown = 0.0; PARAM.sys.two_fermi = false; PARAM.input.nbands = 6; PARAM.sys.nbands_l = 6; @@ -260,7 +272,7 @@ TEST_F(ElecStateTest, InitSCF) delete charge; } -TEST_F(ElecStateTest,FixedWeights) +TEST_F(ElecStateTest, FixedWeights) { EXPECT_EQ(PARAM.input.nbands, 6); PARAM.input.nelec = 30; @@ -269,18 +281,18 @@ TEST_F(ElecStateTest,FixedWeights) elecstate->klist = klist; elecstate->wg.create(klist->get_nks(), PARAM.input.nbands); std::vector ocp_kb; - ocp_kb.resize(PARAM.input.nbands*elecstate->klist->get_nks()); + ocp_kb.resize(PARAM.input.nbands * elecstate->klist->get_nks()); for (int i = 0; i < ocp_kb.size(); ++i) { ocp_kb[i] = 1.0; } - elecstate->fixed_weights(ocp_kb, PARAM.input.nbands, PARAM.input.nelec); + elecstate::fixed_weights(ocp_kb, PARAM.input.nbands, PARAM.input.nelec,klist,elecstate->wg,elecstate->skip_weights); EXPECT_EQ(elecstate->wg(0, 0), 1.0); - EXPECT_EQ(elecstate->wg(klist->get_nks()-1, PARAM.input.nbands-1), 1.0); + EXPECT_EQ(elecstate->wg(klist->get_nks() - 1, PARAM.input.nbands - 1), 1.0); EXPECT_TRUE(elecstate->skip_weights); } -TEST_F(ElecStateDeathTest,FixedWeightsWarning1) +TEST_F(ElecStateDeathTest, FixedWeightsWarning1) { EXPECT_EQ(PARAM.input.nbands, 6); PARAM.input.nelec = 30; @@ -289,18 +301,20 @@ TEST_F(ElecStateDeathTest,FixedWeightsWarning1) elecstate->klist = klist; elecstate->wg.create(klist->get_nks(), PARAM.input.nbands); std::vector ocp_kb; - ocp_kb.resize(PARAM.input.nbands*elecstate->klist->get_nks()-1); + ocp_kb.resize(PARAM.input.nbands * elecstate->klist->get_nks() - 1); for (int i = 0; i < ocp_kb.size(); ++i) { ocp_kb[i] = 1.0; } testing::internal::CaptureStdout(); - EXPECT_EXIT(elecstate->fixed_weights(ocp_kb, PARAM.input.nbands, PARAM.input.nelec), ::testing::ExitedWithCode(1), ""); + EXPECT_EXIT(elecstate::fixed_weights(ocp_kb, PARAM.input.nbands, PARAM.input.nelec,klist,elecstate->wg,elecstate->skip_weights), + ::testing::ExitedWithCode(1), + ""); output = testing::internal::GetCapturedStdout(); EXPECT_THAT(output, testing::HasSubstr("size of occupation array is wrong , please check ocp_set")); } -TEST_F(ElecStateDeathTest,FixedWeightsWarning2) +TEST_F(ElecStateDeathTest, FixedWeightsWarning2) { EXPECT_EQ(PARAM.input.nbands, 6); PARAM.input.nelec = 29; @@ -309,13 +323,15 @@ TEST_F(ElecStateDeathTest,FixedWeightsWarning2) elecstate->klist = klist; elecstate->wg.create(klist->get_nks(), PARAM.input.nbands); std::vector ocp_kb; - ocp_kb.resize(PARAM.input.nbands*elecstate->klist->get_nks()); + ocp_kb.resize(PARAM.input.nbands * elecstate->klist->get_nks()); for (int i = 0; i < ocp_kb.size(); ++i) { ocp_kb[i] = 1.0; } testing::internal::CaptureStdout(); - EXPECT_EXIT(elecstate->fixed_weights(ocp_kb, PARAM.input.nbands, PARAM.input.nelec), ::testing::ExitedWithCode(1), ""); + EXPECT_EXIT(elecstate::fixed_weights(ocp_kb, PARAM.input.nbands, PARAM.input.nelec,klist,elecstate->wg,elecstate->skip_weights), + ::testing::ExitedWithCode(1), + ""); output = testing::internal::GetCapturedStdout(); EXPECT_THAT(output, testing::HasSubstr("total number of occupations is wrong , please check ocp_set")); } @@ -335,7 +351,7 @@ TEST_F(ElecStateTest, CalEBand) } } GlobalV::KPAR = 2; - elecstate->calEBand(); + elecstate::calEBand(elecstate->ekb, elecstate->wg, elecstate->f_en); EXPECT_DOUBLE_EQ(elecstate->f_en.eband, 60.0); } @@ -343,14 +359,28 @@ TEST_F(ElecStateTest, CalculateWeightsSkipWeights) { EXPECT_FALSE(elecstate->skip_weights); elecstate->skip_weights = true; - EXPECT_NO_THROW(elecstate->calculate_weights()); + EXPECT_NO_THROW(elecstate::calculate_weights(elecstate->ekb, + elecstate->wg, + elecstate->klist, + elecstate->eferm, + elecstate->f_en, + elecstate->nelec_spin, + elecstate->skip_weights)); } TEST_F(ElecStateDeathTest, CalculateWeightsFixedOccupations) { Occupy::fixed_occupations = true; testing::internal::CaptureStdout(); - EXPECT_EXIT(elecstate->calculate_weights(), ::testing::ExitedWithCode(1), ""); + EXPECT_EXIT(elecstate::calculate_weights(elecstate->ekb, + elecstate->wg, + elecstate->klist, + elecstate->eferm, + elecstate->f_en, + elecstate->nelec_spin, + elecstate->skip_weights), + ::testing::ExitedWithCode(1), + ""); output = testing::internal::GetCapturedStdout(); EXPECT_THAT(output, testing::HasSubstr("other occupations, not implemented")); Occupy::fixed_occupations = false; @@ -383,10 +413,16 @@ TEST_F(ElecStateTest, CalculateWeightsIWeights) } } elecstate->wg.create(nks, PARAM.input.nbands); - elecstate->calculate_weights(); + elecstate::calculate_weights(elecstate->ekb, + elecstate->wg, + elecstate->klist, + elecstate->eferm, + elecstate->f_en, + elecstate->nelec_spin, + elecstate->skip_weights); EXPECT_DOUBLE_EQ(elecstate->wg(0, 0), 2.0); - EXPECT_DOUBLE_EQ(elecstate->wg(nks-1, PARAM.input.nelec/2-1), 2.0); - EXPECT_DOUBLE_EQ(elecstate->wg(nks-1, PARAM.input.nbands-1), 0.0); + EXPECT_DOUBLE_EQ(elecstate->wg(nks - 1, PARAM.input.nelec / 2 - 1), 2.0); + EXPECT_DOUBLE_EQ(elecstate->wg(nks - 1, PARAM.input.nbands - 1), 0.0); EXPECT_DOUBLE_EQ(elecstate->eferm.ef, 100.0); delete klist; } @@ -401,13 +437,13 @@ TEST_F(ElecStateTest, CalculateWeightsIWeightsTwoFermi) EXPECT_EQ(elecstate->nelec_spin[1], 5.0); // EXPECT_FALSE(elecstate->skip_weights); - int nks = 5*PARAM.input.nspin; + int nks = 5 * PARAM.input.nspin; K_Vectors* klist = new K_Vectors; klist->set_nks(nks); klist->wk.resize(nks); for (int ik = 0; ik < nks; ++ik) { - if(ik<5) + if (ik < 5) { klist->wk[ik] = 1.1; } @@ -419,7 +455,7 @@ TEST_F(ElecStateTest, CalculateWeightsIWeightsTwoFermi) klist->isk.resize(nks); for (int ik = 0; ik < nks; ++ik) { - if(ik < 5) + if (ik < 5) { klist->isk[ik] = 0; } @@ -436,7 +472,7 @@ TEST_F(ElecStateTest, CalculateWeightsIWeightsTwoFermi) { for (int ib = 0; ib < PARAM.input.nbands; ++ib) { - if(ik < 5) + if (ik < 5) { elecstate->ekb(ik, ib) = 100.0; } @@ -447,10 +483,16 @@ TEST_F(ElecStateTest, CalculateWeightsIWeightsTwoFermi) } } elecstate->wg.create(nks, PARAM.input.nbands); - elecstate->calculate_weights(); + elecstate::calculate_weights(elecstate->ekb, + elecstate->wg, + elecstate->klist, + elecstate->eferm, + elecstate->f_en, + elecstate->nelec_spin, + elecstate->skip_weights); EXPECT_DOUBLE_EQ(elecstate->wg(0, 0), 1.1); - EXPECT_DOUBLE_EQ(elecstate->wg(nks-1, PARAM.input.nelec/2-1), 1.0); - EXPECT_DOUBLE_EQ(elecstate->wg(nks-1, PARAM.input.nbands-1), 0.0); + EXPECT_DOUBLE_EQ(elecstate->wg(nks - 1, PARAM.input.nelec / 2 - 1), 1.0); + EXPECT_DOUBLE_EQ(elecstate->wg(nks - 1, PARAM.input.nbands - 1), 0.0); EXPECT_DOUBLE_EQ(elecstate->eferm.ef_up, 100.0); EXPECT_DOUBLE_EQ(elecstate->eferm.ef_dw, 200.0); delete klist; @@ -484,14 +526,20 @@ TEST_F(ElecStateTest, CalculateWeightsGWeights) } } elecstate->wg.create(nks, PARAM.input.nbands); - elecstate->calculate_weights(); + elecstate::calculate_weights(elecstate->ekb, + elecstate->wg, + elecstate->klist, + elecstate->eferm, + elecstate->f_en, + elecstate->nelec_spin, + elecstate->skip_weights); // PARAM.input.nelec = 10; // PARAM.input.nbands = 6; // nks = 5; // wg = 10/(5*6) = 0.33333333333 EXPECT_NEAR(elecstate->wg(0, 0), 0.33333333333, 1e-10); - EXPECT_NEAR(elecstate->wg(nks-1, PARAM.input.nelec/2-1), 0.33333333333, 1e-10); - EXPECT_NEAR(elecstate->wg(nks-1, PARAM.input.nbands-1), 0.33333333333,1e-10); + EXPECT_NEAR(elecstate->wg(nks - 1, PARAM.input.nelec / 2 - 1), 0.33333333333, 1e-10); + EXPECT_NEAR(elecstate->wg(nks - 1, PARAM.input.nbands - 1), 0.33333333333, 1e-10); EXPECT_NEAR(elecstate->eferm.ef, 99.993159296503, 1e-10); delete klist; Occupy::use_gaussian_broadening = false; @@ -508,13 +556,13 @@ TEST_F(ElecStateTest, CalculateWeightsGWeightsTwoFermi) EXPECT_EQ(elecstate->nelec_spin[1], 5.0); // EXPECT_FALSE(elecstate->skip_weights); - int nks = 5*PARAM.input.nspin; + int nks = 5 * PARAM.input.nspin; K_Vectors* klist = new K_Vectors; klist->set_nks(nks); klist->wk.resize(nks); for (int ik = 0; ik < nks; ++ik) { - if(ik<5) + if (ik < 5) { klist->wk[ik] = 1.1; } @@ -526,7 +574,7 @@ TEST_F(ElecStateTest, CalculateWeightsGWeightsTwoFermi) klist->isk.resize(nks); for (int ik = 0; ik < nks; ++ik) { - if(ik < 5) + if (ik < 5) { klist->isk[ik] = 0; } @@ -543,7 +591,7 @@ TEST_F(ElecStateTest, CalculateWeightsGWeightsTwoFermi) { for (int ib = 0; ib < PARAM.input.nbands; ++ib) { - if(ik < 5) + if (ik < 5) { elecstate->ekb(ik, ib) = 100.0; } @@ -554,14 +602,20 @@ TEST_F(ElecStateTest, CalculateWeightsGWeightsTwoFermi) } } elecstate->wg.create(nks, PARAM.input.nbands); - elecstate->calculate_weights(); + elecstate::calculate_weights(elecstate->ekb, + elecstate->wg, + elecstate->klist, + elecstate->eferm, + elecstate->f_en, + elecstate->nelec_spin, + elecstate->skip_weights); // PARAM.input.nelec = 10; // PARAM.input.nbands = 6; // nks = 10; // wg = 10/(10*6) = 0.16666666666 EXPECT_NEAR(elecstate->wg(0, 0), 0.16666666666, 1e-10); - EXPECT_NEAR(elecstate->wg(nks-1, PARAM.input.nelec/2-1), 0.16666666666, 1e-10); - EXPECT_NEAR(elecstate->wg(nks-1, PARAM.input.nbands-1), 0.16666666666, 1e-10); + EXPECT_NEAR(elecstate->wg(nks - 1, PARAM.input.nelec / 2 - 1), 0.16666666666, 1e-10); + EXPECT_NEAR(elecstate->wg(nks - 1, PARAM.input.nbands - 1), 0.16666666666, 1e-10); EXPECT_NEAR(elecstate->eferm.ef_up, 99.992717105890961, 1e-10); EXPECT_NEAR(elecstate->eferm.ef_dw, 199.99315929650351, 1e-10); delete klist; diff --git a/source/module_elecstate/test/elecstate_energy_test.cpp b/source/module_elecstate/test/elecstate_energy_test.cpp index b2040447ac..782ccfb3a5 100644 --- a/source/module_elecstate/test/elecstate_energy_test.cpp +++ b/source/module_elecstate/test/elecstate_energy_test.cpp @@ -89,10 +89,6 @@ const double* ElecState::getRho(int spin) const { return &(this->eferm.ef); } // just for mock -void ElecState::calculate_weights() -{ - return; -} // just for mock } // namespace elecstate class ElecStateEnergyTest : public ::testing::Test diff --git a/source/module_elecstate/test/elecstate_print_test.cpp b/source/module_elecstate/test/elecstate_print_test.cpp index 226e84c176..4904188cf0 100644 --- a/source/module_elecstate/test/elecstate_print_test.cpp +++ b/source/module_elecstate/test/elecstate_print_test.cpp @@ -21,10 +21,6 @@ const double* ElecState::getRho(int spin) const { return &(this->eferm.ef); } // just for mock -void ElecState::calculate_weights() -{ - return; -} // just for mock double Efield::etotefield = 1.1; double elecstate::Gatefield::etotgatefield = 2.2; diff --git a/source/module_esolver/esolver_ks_lcao.cpp b/source/module_esolver/esolver_ks_lcao.cpp index bab50b9531..35c2eea9b1 100644 --- a/source/module_esolver/esolver_ks_lcao.cpp +++ b/source/module_esolver/esolver_ks_lcao.cpp @@ -27,7 +27,7 @@ #include "module_io/write_proj_band_lcao.h" #include "module_io/write_wfc_nao.h" #include "module_parameter/parameter.h" - +#include "module_elecstate/elecstate_tools.h" //be careful of hpp, there may be multiple definitions of functions, 20250302, mohan #include "module_io/write_eband_terms.hpp" @@ -216,7 +216,12 @@ void ESolver_KS_LCAO::before_all_runners(UnitCell& ucell, const Input_pa // tddft does not need to set occupations in the first scf if (PARAM.inp.ocp && inp.esolver_type != "tddft") { - this->pelec->fixed_weights(PARAM.inp.ocp_kb, PARAM.inp.nbands, PARAM.inp.nelec); + elecstate::fixed_weights(PARAM.inp.ocp_kb, + PARAM.inp.nbands, + PARAM.inp.nelec, + this->pelec->klist, + this->pelec->wg, + this->pelec->skip_weights); } // 12) if kpar is not divisible by nks, print a warning @@ -573,11 +578,18 @@ void ESolver_KS_LCAO::iter_init(UnitCell& ucell, const int istep, const // and then calculate the charge density on grid. this->pelec->skip_weights = true; - this->pelec->calculate_weights(); + elecstate::calculate_weights(this->pelec->ekb, + this->pelec->wg, + this->pelec->klist, + this->pelec->eferm, + this->pelec->f_en, + this->pelec->nelec_spin, + this->pelec->skip_weights); + if (!PARAM.inp.dm_to_rho) { auto _pelec = dynamic_cast*>(this->pelec); - _pelec->calEBand(); + elecstate::calEBand(_pelec->ekb,_pelec->wg,_pelec->f_en); elecstate::cal_dm_psi(_pelec->DM->get_paraV_pointer(), _pelec->wg, *this->psi, *(_pelec->DM)); _pelec->DM->cal_DMR(); } diff --git a/source/module_esolver/esolver_ks_lcao_tddft.cpp b/source/module_esolver/esolver_ks_lcao_tddft.cpp index c1ec58ba6c..5628f5463d 100644 --- a/source/module_esolver/esolver_ks_lcao_tddft.cpp +++ b/source/module_esolver/esolver_ks_lcao_tddft.cpp @@ -6,6 +6,7 @@ #include "module_io/write_HS.h" #include "module_io/write_HS_R.h" #include "module_io/write_wfc_nao.h" +#include "module_elecstate/elecstate_tools.h" //--------------temporary---------------------------- #include "module_base/blas_connector.h" @@ -401,11 +402,16 @@ void ESolver_KS_LCAO_TDDFT::weight_dm_rho() { if (PARAM.inp.ocp == 1) { - this->pelec->fixed_weights(PARAM.inp.ocp_kb, PARAM.inp.nbands, PARAM.inp.nelec); + elecstate::fixed_weights(PARAM.inp.ocp_kb, + PARAM.inp.nbands, + PARAM.inp.nelec, + this->pelec->klist, + this->pelec->wg, + this->pelec->skip_weights); } // calculate Eband energy - this->pelec->calEBand(); + elecstate::calEBand(this->pelec->ekb,this->pelec->wg,this->pelec->f_en); // calculate the density matrix ModuleBase::GlobalFunc::NOTE("Calculate the density matrix."); diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index 7c7a7a1b41..67b1be64ab 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -7,6 +7,7 @@ #include "module_elecstate/cal_ux.h" #include "module_elecstate/elecstate_pw.h" #include "module_elecstate/elecstate_pw_sdft.h" +#include "module_elecstate/elecstate_tools.h" #include "module_elecstate/module_charge/symmetry_rho.h" #include "module_hamilt_general/module_ewald/H_Ewald_pw.h" #include "module_hamilt_general/module_vdw/vdw.h" @@ -217,7 +218,12 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p //! 9) setup occupations if (PARAM.inp.ocp) { - this->pelec->fixed_weights(PARAM.inp.ocp_kb, PARAM.globalv.nbands_l, PARAM.inp.nelec); + elecstate::fixed_weights(PARAM.inp.ocp_kb, + PARAM.inp.nbands, + PARAM.inp.nelec, + this->pelec->klist, + this->pelec->wg, + this->pelec->skip_weights); } } diff --git a/source/module_esolver/lcao_before_scf.cpp b/source/module_esolver/lcao_before_scf.cpp index e2ee0da116..bc471cd771 100644 --- a/source/module_esolver/lcao_before_scf.cpp +++ b/source/module_esolver/lcao_before_scf.cpp @@ -15,6 +15,7 @@ #include "module_io/to_wannier90_lcao_in_pw.h" #include "module_io/write_HS_R.h" #include "module_parameter/parameter.h" +#include "module_elecstate/elecstate_tools.h" #ifdef __DEEPKS #include "module_hamilt_lcao/module_deepks/LCAO_deepks.h" #endif @@ -312,10 +313,14 @@ void ESolver_KS_LCAO::before_scf(UnitCell& ucell, const int istep) ModuleIO::read_mat_npz(&(this->pv), ucell, zipname, *(dm->get_DMR_pointer(2))); } - // calculate weights - this->pelec->calculate_weights(); - - // use psi to calculate charge density + elecstate::calculate_weights(this->pelec->ekb, + this->pelec->wg, + this->pelec->klist, + this->pelec->eferm, + this->pelec->f_en, + this->pelec->nelec_spin, + this->pelec->skip_weights); + this->pelec->psiToRho(*this->psi); int nspin0 = PARAM.inp.nspin == 2 ? 2 : 1; diff --git a/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp b/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp index 5167605343..6e07134337 100644 --- a/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp +++ b/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp @@ -9,9 +9,11 @@ #include "module_hsolver/hsolver_lcao.h" #include "module_hsolver/hsolver_pw.h" #include "module_elecstate/elecstate_pw.h" +#include "module_elecstate/elecstate_tools.h" #ifdef __LCAO #include "module_elecstate/elecstate_lcao.h" +#include "module_elecstate/elecstate_tools.h" #include "module_elecstate/module_dm/cal_dm_psi.h" #include "module_hamilt_lcao/hamilt_lcaodft/operator_lcao/dspin_lcao.h" #endif @@ -150,8 +152,14 @@ void spinconstrain::SpinConstrain>::cal_mw_from_lambda(int } // diagonalization without update charge hsolver_t.solve(hamilt_t, psi_t[0], this->pelec, true); - this->pelec->calculate_weights(); - this->pelec->calEBand(); + elecstate::calculate_weights(this->pelec->ekb, + this->pelec->wg, + this->pelec->klist, + this->pelec->eferm, + this->pelec->f_en, + this->pelec->nelec_spin, + this->pelec->skip_weights); + elecstate::calEBand(this->pelec->ekb,this->pelec->wg,this->pelec->f_en); elecstate::ElecStateLCAO>* pelec_lcao = dynamic_cast>*>(this->pelec); elecstate::cal_dm_psi(this->ParaV, pelec_lcao->wg, *psi_t, *(pelec_lcao->get_DM())); @@ -307,7 +315,13 @@ void spinconstrain::SpinConstrain>::cal_mw_from_lambda(int } #endif // calculate weights from ekb to update wg - this->pelec->calculate_weights(); + elecstate::calculate_weights(this->pelec->ekb, + this->pelec->wg, + this->pelec->klist, + this->pelec->eferm, + this->pelec->f_en, + this->pelec->nelec_spin, + this->pelec->skip_weights); // calculate Mi from existed becp for (int ik = 0; ik < nk; ik++) { diff --git a/source/module_hsolver/hsolver_lcao.cpp b/source/module_hsolver/hsolver_lcao.cpp index 4c8ff835de..1e0cfa6acc 100644 --- a/source/module_hsolver/hsolver_lcao.cpp +++ b/source/module_hsolver/hsolver_lcao.cpp @@ -25,6 +25,7 @@ #endif #include "module_base/global_variable.h" +#include "module_elecstate/elecstate_tools.h" #include "module_base/memory.h" #include "module_base/timer.h" #include "module_elecstate/elecstate_lcao.h" @@ -75,13 +76,19 @@ void HSolverLCAO::solve(hamilt::Hamilt* pHamilt, "This method and KPAR setting is not supported for lcao basis in ABACUS!"); } - pes->calculate_weights(); + elecstate::calculate_weights(pes->ekb, + pes->wg, + pes->klist, + pes->eferm, + pes->f_en, + pes->nelec_spin, + pes->skip_weights); if (!PARAM.inp.dm_to_rho) { - auto _pes = dynamic_cast*>(pes); - _pes->calEBand(); - elecstate::cal_dm_psi(_pes->DM->get_paraV_pointer(), _pes->wg, psi, *(_pes->DM)); - _pes->DM->cal_DMR(); + auto _pes_lcao = dynamic_cast*>(pes); + elecstate::calEBand(_pes_lcao->ekb,_pes_lcao->wg,_pes_lcao->f_en); + elecstate::cal_dm_psi(_pes_lcao->DM->get_paraV_pointer(), _pes_lcao->wg, psi, *(_pes_lcao->DM)); + _pes_lcao->DM->cal_DMR(); } if (!skip_charge) diff --git a/source/module_hsolver/hsolver_lcaopw.cpp b/source/module_hsolver/hsolver_lcaopw.cpp index b6e95b4c03..260c94203c 100644 --- a/source/module_hsolver/hsolver_lcaopw.cpp +++ b/source/module_hsolver/hsolver_lcaopw.cpp @@ -9,6 +9,7 @@ #include "module_hamilt_pw/hamilt_pwdft/hamilt_pw.h" #include "module_hsolver/diago_iter_assist.h" #include "module_parameter/parameter.h" +#include "module_elecstate/elecstate_tools.h" #ifdef USE_PAW #include "module_cell/module_paw/paw_cell.h" @@ -274,8 +275,14 @@ void HSolverLIP::solve(hamilt::Hamilt* pHamilt, // ESolver_KS_PW::p_hamilt eigenvalues.data(), pes->ekb.nr * pes->ekb.nc); - reinterpret_cast*>(pes)->calculate_weights(); - reinterpret_cast*>(pes)->calEBand(); + elecstate::calculate_weights(pes->ekb, + pes->wg, + pes->klist, + pes->eferm, + pes->f_en, + pes->nelec_spin, + pes->skip_weights); + elecstate::calEBand(pes->ekb,pes->wg,pes->f_en); if (skip_charge) { if (PARAM.globalv.use_uspp) diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 3206a85672..62bff8a915 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -13,6 +13,8 @@ #include "module_hsolver/diago_iter_assist.h" #include "module_parameter/parameter.h" #include "module_psi/psi.h" +#include "module_elecstate/elecstate_tools.h" + #include #include @@ -335,8 +337,15 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, // pes->ekb.nr * pes->ekb.nc this->wfc_basis->nks * psi.get_nbands()); - reinterpret_cast*>(pes)->calculate_weights(); - reinterpret_cast*>(pes)->calEBand(); + auto _pes_pw = reinterpret_cast*>(pes); + elecstate::calculate_weights(_pes_pw->ekb, + _pes_pw->wg, + _pes_pw->klist, + _pes_pw->eferm, + _pes_pw->f_en, + _pes_pw->nelec_spin, + _pes_pw->skip_weights); + elecstate::calEBand(_pes_pw->ekb,_pes_pw->wg,_pes_pw->f_en); if (skip_charge) { if (PARAM.globalv.use_uspp) diff --git a/source/module_hsolver/hsolver_pw_sdft.cpp b/source/module_hsolver/hsolver_pw_sdft.cpp index b2df935ad6..3408284777 100644 --- a/source/module_hsolver/hsolver_pw_sdft.cpp +++ b/source/module_hsolver/hsolver_pw_sdft.cpp @@ -5,6 +5,7 @@ #include "module_base/timer.h" #include "module_base/tool_title.h" #include "module_elecstate/module_charge/symmetry_rho.h" +#include "module_elecstate/elecstate_tools.h" #include @@ -89,7 +90,7 @@ void HSolverPW_SDFT::solve(const UnitCell& ucell, // calculate eband = \sum_{ik,ib} w(ik)f(ik,ib)e_{ikib}, demet = -TS elecstate::ElecStatePW* pes_pw = static_cast*>(pes); - pes_pw->calEBand(); + elecstate::calEBand(pes_pw->ekb,pes_pw->wg,pes_pw->f_en); if(!PARAM.globalv.all_ks_run) { pes->f_en.eband /= PARAM.inp.bndpar; diff --git a/source/module_hsolver/test/CMakeLists.txt b/source/module_hsolver/test/CMakeLists.txt index 7165e895a7..3150eab9b8 100644 --- a/source/module_hsolver/test/CMakeLists.txt +++ b/source/module_hsolver/test/CMakeLists.txt @@ -78,13 +78,15 @@ if (ENABLE_MPI) TARGET HSolver_pw LIBS parameter ${math_libs} psi device base container SOURCES test_hsolver_pw.cpp ../hsolver_pw.cpp ../hsolver_lcaopw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp ../para_linear_transform.cpp + ../../module_elecstate/elecstate_tools.cpp ../../module_elecstate/occupy.cpp ) AddTest( TARGET HSolver_sdft LIBS parameter ${math_libs} psi device base container SOURCES test_hsolver_sdft.cpp ../hsolver_pw_sdft.cpp ../hsolver_pw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp ../para_linear_transform.cpp - ) + ../../module_elecstate/elecstate_tools.cpp ../../module_elecstate/occupy.cpp + ) if(ENABLE_LCAO) if(USE_ELPA) diff --git a/source/module_hsolver/test/hsolver_supplementary_mock.h b/source/module_hsolver/test/hsolver_supplementary_mock.h index c7ead79fff..1529202dab 100644 --- a/source/module_hsolver/test/hsolver_supplementary_mock.h +++ b/source/module_hsolver/test/hsolver_supplementary_mock.h @@ -11,25 +11,13 @@ const double* ElecState::getRho(int spin) const return &(this->charge->rho[spin][0]); } -void ElecState::fixed_weights(const std::vector& ocp_kb, const int& nbands, const double& nelec) -{ - return; -} void ElecState::init_nelec_spin() { return; } -void ElecState::calculate_weights() -{ - return; -} -void ElecState::calEBand() -{ - return; -} void ElecState::init_scf(const int istep, const UnitCell& ucell, diff --git a/source/module_hsolver/test/test_hsolver_pw.cpp b/source/module_hsolver/test/test_hsolver_pw.cpp index 530491a669..6b0e8c5a80 100644 --- a/source/module_hsolver/test/test_hsolver_pw.cpp +++ b/source/module_hsolver/test/test_hsolver_pw.cpp @@ -267,6 +267,9 @@ TEST_F(TestHSolverPW, SolveLcaoInPW) { pwbk.nks = 1; // initial memory and data elecstate_test.ekb.create(1, 2); + elecstate_test.wg.create(1,2); + elecstate_test.klist=new K_Vectors; + elecstate_test.skip_weights=true; elecstate_test.pot = new elecstate::Potential; // 1 kpt, 2 bands, 3 basis psi_test_cf.resize(1, 2, 3); @@ -300,7 +303,7 @@ TEST_F(TestHSolverPW, SolveLcaoInPW) { // check solve() elecstate_test.ekb.c[0] = 1.0; elecstate_test.ekb.c[1] = 2.0; - + hsolver::HSolverLIP> hs_f_lip = hsolver::HSolverLIP>(&pwbk); hsolver::HSolverLIP> hs_d_lip diff --git a/source/module_io/test/read_wfc_nao_test.cpp b/source/module_io/test/read_wfc_nao_test.cpp index 3c50fb178a..4a98a187d1 100644 --- a/source/module_io/test/read_wfc_nao_test.cpp +++ b/source/module_io/test/read_wfc_nao_test.cpp @@ -12,7 +12,6 @@ namespace elecstate { const double* ElecState::getRho(int spin) const{return &(this->eferm.ef);}//just for mock - void ElecState::calculate_weights(){} } // mock wfc_lcao_gen_fname diff --git a/source/module_rdmft/update_state_rdmft.cpp b/source/module_rdmft/update_state_rdmft.cpp index ed08ba0be9..2a22b18864 100644 --- a/source/module_rdmft/update_state_rdmft.cpp +++ b/source/module_rdmft/update_state_rdmft.cpp @@ -155,8 +155,6 @@ void RDMFT::update_charge(UnitCell& ucell) } // charge density symmetrization - // this->pelec->calculate_weights(); - // this->pelec->calEBand(); Symmetry_rho srho; for (int is = 0; is < nspin; is++) {