Skip to content

Commit d81cb46

Browse files
authored
Refactor: refactor iter_finish and iter_init (#5426)
* Refactor: refactor iter_finish and iter_init * update esolver.h * fix some errors
1 parent 0d455cb commit d81cb46

File tree

4 files changed

+137
-137
lines changed

4 files changed

+137
-137
lines changed

source/module_esolver/esolver_ks.cpp

Lines changed: 78 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,5 @@
11
#include "esolver_ks.h"
22

3-
#include <ctime>
4-
#include <iostream>
5-
#ifdef __MPI
6-
#include <mpi.h>
7-
#else
8-
#include <chrono>
9-
#endif
103
#include "module_base/timer.h"
114
#include "module_cell/cal_atoms_info.h"
125
#include "module_io/json_output/init_info.h"
@@ -15,6 +8,9 @@
158
#include "module_io/print_info.h"
169
#include "module_io/write_istate_info.h"
1710
#include "module_parameter/parameter.h"
11+
12+
#include <ctime>
13+
#include <iostream>
1814
//--------------Temporary----------------
1915
#include "module_base/global_variable.h"
2016
#include "module_hamilt_lcao/module_dftu/dftu.h"
@@ -427,49 +423,11 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)
427423
this->niter = this->maxniter;
428424

429425
// 4) SCF iterations
430-
double diag_ethr = PARAM.inp.pw_diag_thr;
426+
this->diag_ethr = PARAM.inp.pw_diag_thr;
431427

432428
std::cout << " * * * * * *\n << Start SCF iteration." << std::endl;
433429
for (int iter = 1; iter <= this->maxniter; ++iter)
434430
{
435-
// 5) write head
436-
ModuleIO::write_head(GlobalV::ofs_running, istep, iter, this->basisname);
437-
438-
#ifdef __MPI
439-
auto iterstart = MPI_Wtime();
440-
#else
441-
auto iterstart = std::chrono::system_clock::now();
442-
#endif
443-
444-
if (PARAM.inp.esolver_type == "ksdft")
445-
{
446-
diag_ethr = hsolver::set_diagethr_ks(PARAM.inp.basis_type,
447-
PARAM.inp.esolver_type,
448-
PARAM.inp.calculation,
449-
PARAM.inp.init_chg,
450-
PARAM.inp.precision,
451-
istep,
452-
iter,
453-
drho,
454-
PARAM.inp.pw_diag_thr,
455-
diag_ethr,
456-
PARAM.inp.nelec);
457-
}
458-
else if (PARAM.inp.esolver_type == "sdft")
459-
{
460-
diag_ethr = hsolver::set_diagethr_sdft(PARAM.inp.basis_type,
461-
PARAM.inp.esolver_type,
462-
PARAM.inp.calculation,
463-
PARAM.inp.init_chg,
464-
istep,
465-
iter,
466-
drho,
467-
PARAM.inp.pw_diag_thr,
468-
diag_ethr,
469-
PARAM.inp.nbands,
470-
esolver_KS_ne);
471-
}
472-
473431
// 6) initialization of SCF iterations
474432
this->iter_init(istep, iter);
475433

@@ -615,33 +573,6 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)
615573

616574
// 10) finish scf iterations
617575
this->iter_finish(istep, iter);
618-
#ifdef __MPI
619-
double duration = (double)(MPI_Wtime() - iterstart);
620-
#else
621-
double duration
622-
= (std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now() - iterstart))
623-
.count()
624-
/ static_cast<double>(1e6);
625-
#endif
626-
627-
// 11) get mtaGGA related parameters
628-
double dkin = 0.0; // for meta-GGA
629-
if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5)
630-
{
631-
dkin = p_chgmix->get_dkin(pelec->charge, PARAM.inp.nelec);
632-
}
633-
this->pelec->print_etot(this->conv_esolver, iter, drho, dkin, duration, PARAM.inp.printe, diag_ethr);
634-
635-
// 12) Json, need to be moved to somewhere else
636-
#ifdef __RAPIDJSON
637-
// add Json of scf mag
638-
Json::add_output_scf_mag(GlobalC::ucell.magnet.tot_magnetization,
639-
GlobalC::ucell.magnet.abs_magnetization,
640-
this->pelec->f_en.etot * ModuleBase::Ry_to_eV,
641-
this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV,
642-
drho,
643-
duration);
644-
#endif //__RAPIDJSON
645576

646577
// 13) check convergence
647578
if (this->conv_esolver || this->oscillate_esolver)
@@ -653,12 +584,6 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)
653584
}
654585
break;
655586
}
656-
657-
// notice for restart
658-
if (PARAM.inp.mixing_restart > 0 && iter == this->p_chgmix->mixing_restart_step - 1 && iter != PARAM.inp.scf_nmax)
659-
{
660-
std::cout << " SCF restart after this step!" << std::endl;
661-
}
662587
} // end scf iterations
663588
std::cout << " >> Leave SCF iteration.\n * * * * * *" << std::endl;
664589

@@ -671,6 +596,47 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)
671596
return;
672597
};
673598

599+
template <typename T, typename Device>
600+
void ESolver_KS<T, Device>::iter_init(const int istep, const int iter)
601+
{
602+
ModuleIO::write_head(GlobalV::ofs_running, istep, iter, this->basisname);
603+
604+
#ifdef __MPI
605+
iter_time = MPI_Wtime();
606+
#else
607+
iter_time = std::chrono::system_clock::now();
608+
#endif
609+
610+
if (PARAM.inp.esolver_type == "ksdft")
611+
{
612+
diag_ethr = hsolver::set_diagethr_ks(PARAM.inp.basis_type,
613+
PARAM.inp.esolver_type,
614+
PARAM.inp.calculation,
615+
PARAM.inp.init_chg,
616+
PARAM.inp.precision,
617+
istep,
618+
iter,
619+
drho,
620+
PARAM.inp.pw_diag_thr,
621+
diag_ethr,
622+
PARAM.inp.nelec);
623+
}
624+
else if (PARAM.inp.esolver_type == "sdft")
625+
{
626+
diag_ethr = hsolver::set_diagethr_sdft(PARAM.inp.basis_type,
627+
PARAM.inp.esolver_type,
628+
PARAM.inp.calculation,
629+
PARAM.inp.init_chg,
630+
istep,
631+
iter,
632+
drho,
633+
PARAM.inp.pw_diag_thr,
634+
diag_ethr,
635+
PARAM.inp.nbands,
636+
esolver_KS_ne);
637+
}
638+
}
639+
674640
template <typename T, typename Device>
675641
void ESolver_KS<T, Device>::iter_finish(const int istep, int& iter)
676642
{
@@ -684,6 +650,39 @@ void ESolver_KS<T, Device>::iter_finish(const int istep, int& iter)
684650
}
685651
this->pelec->f_en.etot_delta = this->pelec->f_en.etot - this->pelec->f_en.etot_old;
686652
this->pelec->f_en.etot_old = this->pelec->f_en.etot;
653+
654+
#ifdef __MPI
655+
double duration = (double)(MPI_Wtime() - iter_time);
656+
#else
657+
double duration
658+
= (std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now() - iter_time)).count()
659+
/ static_cast<double>(1e6);
660+
#endif
661+
662+
// get mtaGGA related parameters
663+
double dkin = 0.0; // for meta-GGA
664+
if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5)
665+
{
666+
dkin = p_chgmix->get_dkin(pelec->charge, PARAM.inp.nelec);
667+
}
668+
this->pelec->print_etot(this->conv_esolver, iter, drho, dkin, duration, PARAM.inp.printe, diag_ethr);
669+
670+
// Json, need to be moved to somewhere else
671+
#ifdef __RAPIDJSON
672+
// add Json of scf mag
673+
Json::add_output_scf_mag(GlobalC::ucell.magnet.tot_magnetization,
674+
GlobalC::ucell.magnet.abs_magnetization,
675+
this->pelec->f_en.etot * ModuleBase::Ry_to_eV,
676+
this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV,
677+
drho,
678+
duration);
679+
#endif //__RAPIDJSON
680+
681+
// notice for restart
682+
if (PARAM.inp.mixing_restart > 0 && iter == this->p_chgmix->mixing_restart_step - 1 && iter != PARAM.inp.scf_nmax)
683+
{
684+
std::cout << " SCF restart after this step!" << std::endl;
685+
}
687686
}
688687

689688
//! Something to do after SCF iterations when SCF is converged or comes to the max iter step.
@@ -698,13 +697,6 @@ void ESolver_KS<T, Device>::after_scf(const int istep)
698697
{
699698
this->pelec->print_eigenvalue(GlobalV::ofs_running);
700699
}
701-
// #ifdef __RAPIDJSON
702-
// // add Json of efermi energy converge
703-
// Json::add_output_efermi_converge(this->pelec->eferm.ef * ModuleBase::Ry_to_eV, this->conv_esolver);
704-
// // add nkstot,nkstot_ibz to output json
705-
// int Jnkstot = this->pelec->klist->get_nkstot();
706-
// Json::add_nkstot(Jnkstot);
707-
// #endif //__RAPIDJSON
708700
}
709701

710702
//------------------------------------------------------------------------------

source/module_esolver/esolver_ks.h

Lines changed: 53 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -10,79 +10,81 @@
1010
#include "module_io/cal_test.h"
1111
#include "module_psi/psi.h"
1212

13-
#include <fstream>
13+
#ifdef __MPI
14+
#include <mpi.h>
15+
#else
16+
#include <chrono>
17+
#endif
1418
#include <cstring>
19+
#include <fstream>
1520
namespace ModuleESolver
1621
{
1722

1823
template <typename T, typename Device = base_device::DEVICE_CPU>
1924
class ESolver_KS : public ESolver_FP
2025
{
21-
public:
22-
23-
//! Constructor
24-
ESolver_KS();
25-
26-
//! Deconstructor
27-
virtual ~ESolver_KS();
28-
29-
double scf_thr; // scf density threshold
26+
public:
27+
//! Constructor
28+
ESolver_KS();
3029

31-
double scf_ene_thr; // scf energy threshold
30+
//! Deconstructor
31+
virtual ~ESolver_KS();
3232

33-
double drho; // the difference between rho_in (before HSolver) and rho_out (After HSolver)
33+
virtual void before_all_runners(const Input_para& inp, UnitCell& cell) override;
3434

35-
int maxniter; // maximum iter steps for scf
35+
virtual void runner(const int istep, UnitCell& cell) override;
3636

37-
int niter; // iter steps actually used in scf
37+
protected:
38+
//! Something to do before SCF iterations.
39+
virtual void before_scf(const int istep) {};
3840

39-
int out_freq_elec; // frequency for output
41+
virtual void init_after_vc(const Input_para& inp, UnitCell& cell) override; // liuyu add 2023-03-09
4042

41-
virtual void before_all_runners(const Input_para& inp, UnitCell& cell) override;
43+
//! Something to do before hamilt2density function in each iter loop.
44+
virtual void iter_init(const int istep, const int iter);
4245

43-
virtual void init_after_vc(const Input_para& inp, UnitCell& cell) override; // liuyu add 2023-03-09
46+
//! Something to do after hamilt2density function in each iter loop.
47+
virtual void iter_finish(const int istep, int& iter);
4448

45-
virtual void runner(const int istep, UnitCell& cell) override;
49+
// calculate electron density from a specific Hamiltonian
50+
virtual void hamilt2density(const int istep, const int iter, const double ethr);
4651

47-
// calculate electron density from a specific Hamiltonian
48-
virtual void hamilt2density(const int istep, const int iter, const double ethr);
52+
// calculate electron states from a specific Hamiltonian
53+
virtual void hamilt2estates(const double ethr) {};
4954

50-
// calculate electron states from a specific Hamiltonian
51-
virtual void hamilt2estates(const double ethr){};
55+
//! Something to do after SCF iterations when SCF is converged or comes to the max iter step.
56+
virtual void after_scf(const int istep) override;
5257

53-
protected:
54-
//! Something to do before SCF iterations.
55-
virtual void before_scf(const int istep) {};
58+
//! <Temporary> It should be replaced by a function in Hamilt Class
59+
virtual void update_pot(const int istep, const int iter) {};
5660

57-
//! Something to do before hamilt2density function in each iter loop.
58-
virtual void iter_init(const int istep, const int iter) {};
61+
//! Hamiltonian
62+
hamilt::Hamilt<T, Device>* p_hamilt = nullptr;
5963

60-
//! Something to do after hamilt2density function in each iter loop.
61-
virtual void iter_finish(const int istep, int& iter);
64+
ModulePW::PW_Basis_K* pw_wfc = nullptr;
6265

63-
//! Something to do after SCF iterations when SCF is converged or comes to the max iter step.
64-
virtual void after_scf(const int istep) override;
66+
Charge_Mixing* p_chgmix = nullptr;
6567

66-
//! <Temporary> It should be replaced by a function in Hamilt Class
67-
virtual void update_pot(const int istep, const int iter) {};
68+
wavefunc wf;
6869

69-
protected:
70-
//! Hamiltonian
71-
hamilt::Hamilt<T, Device>* p_hamilt = nullptr;
70+
// wavefunction coefficients
71+
psi::Psi<T>* psi = nullptr;
7272

73-
ModulePW::PW_Basis_K* pw_wfc = nullptr;
74-
75-
Charge_Mixing* p_chgmix = nullptr;
76-
77-
wavefunc wf;
78-
79-
// wavefunction coefficients
80-
psi::Psi<T>* psi = nullptr;
81-
82-
protected:
83-
std::string basisname; // PW or LCAO
84-
double esolver_KS_ne = 0.0;
85-
bool oscillate_esolver = false; // whether esolver is oscillated
86-
};
87-
} // end of namespace
73+
std::string basisname; // PW or LCAO
74+
double esolver_KS_ne = 0.0;
75+
bool oscillate_esolver = false; // whether esolver is oscillated
76+
#ifdef __MPI
77+
double iter_time; // the start time of scf iteration
78+
#else
79+
std::chrono::system_clock::time_point iter_time;
80+
#endif
81+
double diag_ethr; // the threshold for diagonalization
82+
double scf_thr; // scf density threshold
83+
double scf_ene_thr; // scf energy threshold
84+
double drho; // the difference between rho_in (before HSolver) and rho_out (After HSolver)
85+
int maxniter; // maximum iter steps for scf
86+
int niter; // iter steps actually used in scf
87+
int out_freq_elec; // frequency for output
88+
};
89+
} // namespace ModuleESolver
8890
#endif

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,9 @@ void ESolver_KS_LCAO<TK, TR>::iter_init(const int istep, const int iter)
519519
{
520520
ModuleBase::TITLE("ESolver_KS_LCAO", "iter_init");
521521

522+
// call iter_init() of ESolver_KS
523+
ESolver_KS<TK>::iter_init(istep, iter);
524+
522525
if (iter == 1)
523526
{
524527
this->p_chgmix->init_mixing(); // init mixing

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,9 @@ void ESolver_KS_PW<T, Device>::before_scf(const int istep)
305305
template <typename T, typename Device>
306306
void ESolver_KS_PW<T, Device>::iter_init(const int istep, const int iter)
307307
{
308+
// call iter_init() of ESolver_KS
309+
ESolver_KS<T, Device>::iter_init(istep, iter);
310+
308311
if (iter == 1)
309312
{
310313
this->p_chgmix->init_mixing();

0 commit comments

Comments
 (0)