Skip to content

Commit c03c879

Browse files
authored
Refactor: runner() of esolver_ks (#5445)
* Refactor: runner() of esolver_ks * rename hamilt2density and diag as hamilt2density_single and hamilt2density
1 parent 32ba8d4 commit c03c879

12 files changed

+332
-403
lines changed

source/module_esolver/esolver_ks.cpp

Lines changed: 180 additions & 144 deletions
Large diffs are not rendered by default.

source/module_esolver/esolver_ks.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,15 @@ class ESolver_KS : public ESolver_FP
4646
//! Something to do after hamilt2density function in each iter loop.
4747
virtual void iter_finish(const int istep, int& iter);
4848

49-
// calculate electron density from a specific Hamiltonian
50-
virtual void hamilt2density(const int istep, const int iter, const double ethr);
49+
// calculate electron density from a specific Hamiltonian with ethr
50+
virtual void hamilt2density_single(const int istep, const int iter, const double ethr);
5151

5252
// calculate electron states from a specific Hamiltonian
5353
virtual void hamilt2estates(const double ethr) {};
5454

55+
// calculate electron density from a specific Hamiltonian
56+
void hamilt2density(const int istep, const int iter, const double ethr);
57+
5558
//! Something to do after SCF iterations when SCF is converged or comes to the max iter step.
5659
virtual void after_scf(const int istep) override;
5760

@@ -82,6 +85,7 @@ class ESolver_KS : public ESolver_FP
8285
double scf_thr; // scf density threshold
8386
double scf_ene_thr; // scf energy threshold
8487
double drho; // the difference between rho_in (before HSolver) and rho_out (After HSolver)
88+
double hsolver_error; // the error of HSolver
8589
int maxniter; // maximum iter steps for scf
8690
int niter; // iter steps actually used in scf
8791
int out_freq_elec; // frequency for output

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 52 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -681,10 +681,17 @@ void ESolver_KS_LCAO<TK, TR>::iter_init(const int istep, const int iter)
681681
SpinConstrain<TK, base_device::DEVICE_CPU>& sc = SpinConstrain<TK, base_device::DEVICE_CPU>::getScInstance();
682682
sc.run_lambda_loop(iter - 1);
683683
}
684+
685+
// save density matrix DMR for mixing
686+
if (PARAM.inp.mixing_restart > 0 && PARAM.inp.mixing_dmr && this->p_chgmix->mixing_restart_count > 0)
687+
{
688+
elecstate::DensityMatrix<TK, double>* dm = dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM();
689+
dm->save_DMR();
690+
}
684691
}
685692

686693
//------------------------------------------------------------------------------
687-
//! the 11th function of ESolver_KS_LCAO: hamilt2density
694+
//! the 11th function of ESolver_KS_LCAO: hamilt2density_single
688695
//! mohan add 2024-05-11
689696
//! 1) save input rho
690697
//! 2) save density matrix DMR for mixing
@@ -700,47 +707,16 @@ void ESolver_KS_LCAO<TK, TR>::iter_init(const int istep, const int iter)
700707
//! 12) calculate delta energy
701708
//------------------------------------------------------------------------------
702709
template <typename TK, typename TR>
703-
void ESolver_KS_LCAO<TK, TR>::hamilt2density(int istep, int iter, double ethr)
710+
void ESolver_KS_LCAO<TK, TR>::hamilt2density_single(int istep, int iter, double ethr)
704711
{
705-
ModuleBase::TITLE("ESolver_KS_LCAO", "hamilt2density");
706-
707-
// 1) save input rho
708-
this->pelec->charge->save_rho_before_sum_band();
709-
710-
// 2) save density matrix DMR for mixing
711-
if (PARAM.inp.mixing_restart > 0 && PARAM.inp.mixing_dmr && this->p_chgmix->mixing_restart_count > 0)
712-
{
713-
elecstate::DensityMatrix<TK, double>* dm = dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM();
714-
dm->save_DMR();
715-
}
712+
ModuleBase::TITLE("ESolver_KS_LCAO", "hamilt2density_single");
716713

717-
// 3) solve the Hamiltonian and output band gap
718-
{
719-
// reset energy
720-
this->pelec->f_en.eband = 0.0;
721-
this->pelec->f_en.demet = 0.0;
722-
723-
hsolver::HSolverLCAO<TK> hsolver_lcao_obj(&(this->pv), PARAM.inp.ks_solver);
724-
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, false);
714+
// reset energy
715+
this->pelec->f_en.eband = 0.0;
716+
this->pelec->f_en.demet = 0.0;
725717

726-
if (PARAM.inp.out_bandgap)
727-
{
728-
if (!PARAM.globalv.two_fermi)
729-
{
730-
this->pelec->cal_bandgap();
731-
}
732-
else
733-
{
734-
this->pelec->cal_bandgap_updw();
735-
}
736-
}
737-
}
738-
739-
// 4) print bands for each k-point and each band
740-
for (int ik = 0; ik < this->kv.get_nks(); ++ik)
741-
{
742-
this->pelec->print_band(ik, PARAM.inp.printe, iter);
743-
}
718+
hsolver::HSolverLCAO<TK> hsolver_lcao_obj(&(this->pv), PARAM.inp.ks_solver);
719+
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, false);
744720

745721
// 5) what's the exd used for?
746722
#ifdef __EXX
@@ -754,59 +730,13 @@ void ESolver_KS_LCAO<TK, TR>::hamilt2density(int istep, int iter, double ethr)
754730
}
755731
#endif
756732

757-
// 6) calculate the local occupation number matrix and energy correction in
758-
// DFT+U
759-
if (PARAM.inp.dft_plus_u)
760-
{
761-
// only old DFT+U method should calculated energy correction in esolver,
762-
// new DFT+U method will calculate energy in calculating Hamiltonian
763-
if (PARAM.inp.dft_plus_u == 2)
764-
{
765-
if (GlobalC::dftu.omc != 2)
766-
{
767-
const std::vector<std::vector<TK>>& tmp_dm
768-
= dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM()->get_DMK_vector();
769-
this->dftu_cal_occup_m(iter, tmp_dm);
770-
}
771-
GlobalC::dftu.cal_energy_correction(istep);
772-
}
773-
GlobalC::dftu.output();
774-
}
775-
776-
// (7) for deepks, calculate delta_e
777-
#ifdef __DEEPKS
778-
if (PARAM.inp.deepks_scf)
779-
{
780-
const std::vector<std::vector<TK>>& dm
781-
= dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM()->get_DMK_vector();
782-
783-
this->dpks_cal_e_delta_band(dm);
784-
}
785-
#endif
786-
787-
// 8) for delta spin
788-
if (PARAM.inp.sc_mag_switch)
789-
{
790-
SpinConstrain<TK, base_device::DEVICE_CPU>& sc = SpinConstrain<TK, base_device::DEVICE_CPU>::getScInstance();
791-
sc.cal_MW(iter, this->p_hamilt);
792-
}
793-
794-
// 9) use new charge density to calculate energy
795-
this->pelec->cal_energies(1);
796-
797733
// 10) symmetrize the charge density
798734
Symmetry_rho srho;
799735
for (int is = 0; is < PARAM.inp.nspin; is++)
800736
{
801737
srho.begin(is, *(this->pelec->charge), this->pw_rho, GlobalC::ucell.symm);
802738
}
803739

804-
// 11) compute magnetization, only for spin==2
805-
GlobalC::ucell.magnet.compute_magnetization(this->pelec->charge->nrxx,
806-
this->pelec->charge->nxyz,
807-
this->pelec->charge->rho,
808-
this->pelec->nelec_spin.data());
809-
810740
// 12) calculate delta energy
811741
this->pelec->f_en.deband = this->pelec->cal_delta_eband();
812742
}
@@ -923,6 +853,43 @@ void ESolver_KS_LCAO<TK, TR>::iter_finish(const int istep, int& iter)
923853
{
924854
ModuleBase::TITLE("ESolver_KS_LCAO", "iter_finish");
925855

856+
// 6) calculate the local occupation number matrix and energy correction in
857+
// DFT+U
858+
if (PARAM.inp.dft_plus_u)
859+
{
860+
// only old DFT+U method should calculated energy correction in esolver,
861+
// new DFT+U method will calculate energy in calculating Hamiltonian
862+
if (PARAM.inp.dft_plus_u == 2)
863+
{
864+
if (GlobalC::dftu.omc != 2)
865+
{
866+
const std::vector<std::vector<TK>>& tmp_dm
867+
= dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM()->get_DMK_vector();
868+
this->dftu_cal_occup_m(iter, tmp_dm);
869+
}
870+
GlobalC::dftu.cal_energy_correction(istep);
871+
}
872+
GlobalC::dftu.output();
873+
}
874+
875+
// (7) for deepks, calculate delta_e
876+
#ifdef __DEEPKS
877+
if (PARAM.inp.deepks_scf)
878+
{
879+
const std::vector<std::vector<TK>>& dm
880+
= dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM()->get_DMK_vector();
881+
882+
this->dpks_cal_e_delta_band(dm);
883+
}
884+
#endif
885+
886+
// 8) for delta spin
887+
if (PARAM.inp.sc_mag_switch)
888+
{
889+
SpinConstrain<TK, base_device::DEVICE_CPU>& sc = SpinConstrain<TK, base_device::DEVICE_CPU>::getScInstance();
890+
sc.cal_MW(iter, this->p_hamilt);
891+
}
892+
926893
// call iter_finish() of ESolver_KS
927894
ESolver_KS<TK>::iter_finish(istep, iter);
928895

source/module_esolver/esolver_ks_lcao.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,7 @@ class ESolver_KS_LCAO : public ESolver_KS<TK> {
5050

5151
virtual void iter_init(const int istep, const int iter) override;
5252

53-
virtual void hamilt2density(const int istep,
54-
const int iter,
55-
const double ethr) override;
53+
virtual void hamilt2density_single(const int istep, const int iter, const double ethr) override;
5654

5755
virtual void update_pot(const int istep, const int iter) override;
5856

source/module_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 18 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,8 @@ void ESolver_KS_LCAO_TDDFT::before_all_runners(const Input_para& inp, UnitCell&
118118
this->pelec_td = dynamic_cast<elecstate::ElecStateLCAO_TDDFT*>(this->pelec);
119119
}
120120

121-
void ESolver_KS_LCAO_TDDFT::hamilt2density(const int istep, const int iter, const double ethr)
121+
void ESolver_KS_LCAO_TDDFT::hamilt2density_single(const int istep, const int iter, const double ethr)
122122
{
123-
pelec->charge->save_rho_before_sum_band();
124-
125123
if (wf.init_wfc == "file")
126124
{
127125
if (istep >= 1)
@@ -171,11 +169,23 @@ void ESolver_KS_LCAO_TDDFT::hamilt2density(const int istep, const int iter, cons
171169
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec_td, false);
172170
}
173171
}
174-
// else
175-
// {
176-
// ModuleBase::WARNING_QUIT("ESolver_KS_LCAO", "HSolver has not been initialed!");
177-
// }
178172

173+
// symmetrize the charge density only for ground state
174+
if (istep <= 1)
175+
{
176+
Symmetry_rho srho;
177+
for (int is = 0; is < PARAM.inp.nspin; is++)
178+
{
179+
srho.begin(is, *(pelec->charge), pw_rho, GlobalC::ucell.symm);
180+
}
181+
}
182+
183+
// (7) calculate delta energy
184+
this->pelec->f_en.deband = this->pelec->cal_delta_eband();
185+
}
186+
187+
void ESolver_KS_LCAO_TDDFT::iter_finish(const int istep, int& iter)
188+
{
179189
// print occupation of each band
180190
if (iter == 1 && istep <= 2)
181191
{
@@ -201,32 +211,7 @@ void ESolver_KS_LCAO_TDDFT::hamilt2density(const int istep, const int iter, cons
201211
<< std::endl;
202212
}
203213

204-
for (int ik = 0; ik < kv.get_nks(); ++ik)
205-
{
206-
this->pelec_td->print_band(ik, PARAM.inp.printe, iter);
207-
}
208-
209-
// using new charge density.
210-
this->pelec->cal_energies(1);
211-
212-
// symmetrize the charge density only for ground state
213-
if (istep <= 1)
214-
{
215-
Symmetry_rho srho;
216-
for (int is = 0; is < PARAM.inp.nspin; is++)
217-
{
218-
srho.begin(is, *(pelec->charge), pw_rho, GlobalC::ucell.symm);
219-
}
220-
}
221-
222-
// (6) compute magnetization, only for spin==2
223-
GlobalC::ucell.magnet.compute_magnetization(this->pelec->charge->nrxx,
224-
this->pelec->charge->nxyz,
225-
this->pelec->charge->rho,
226-
pelec->nelec_spin.data());
227-
228-
// (7) calculate delta energy
229-
this->pelec->f_en.deband = this->pelec->cal_delta_eband();
214+
ESolver_KS_LCAO<std::complex<double>, double>::iter_finish(istep, iter);
230215
}
231216

232217
void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter)

source/module_esolver/esolver_ks_lcao_tddft.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,12 @@ class ESolver_KS_LCAO_TDDFT : public ESolver_KS_LCAO<std::complex<double>, doubl
3030
int td_htype = 1;
3131

3232
protected:
33-
virtual void hamilt2density(const int istep, const int iter, const double ethr) override;
33+
virtual void hamilt2density_single(const int istep, const int iter, const double ethr) override;
3434

3535
virtual void update_pot(const int istep, const int iter) override;
3636

37+
virtual void iter_finish(const int istep, int& iter) override;
38+
3739
virtual void after_scf(const int istep) override;
3840

3941
void cal_edm_tddft();

0 commit comments

Comments
 (0)