Skip to content

Commit 38f89fc

Browse files
committed
Refactor: refactor iter_finish and iter_init
1 parent cc9d759 commit 38f89fc

File tree

4 files changed

+121
-126
lines changed

4 files changed

+121
-126
lines changed

source/module_esolver/esolver_ks.cpp

Lines changed: 75 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -427,49 +427,11 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)
427427
this->niter = this->maxniter;
428428

429429
// 4) SCF iterations
430-
double diag_ethr = PARAM.inp.pw_diag_thr;
430+
this->diag_ethr = PARAM.inp.pw_diag_thr;
431431

432432
std::cout << " * * * * * *\n << Start SCF iteration." << std::endl;
433433
for (int iter = 1; iter <= this->maxniter; ++iter)
434434
{
435-
// 5) write head
436-
ModuleIO::write_head(GlobalV::ofs_running, istep, iter, this->basisname);
437-
438-
#ifdef __MPI
439-
auto iterstart = MPI_Wtime();
440-
#else
441-
auto iterstart = std::chrono::system_clock::now();
442-
#endif
443-
444-
if (PARAM.inp.esolver_type == "ksdft")
445-
{
446-
diag_ethr = hsolver::set_diagethr_ks(PARAM.inp.basis_type,
447-
PARAM.inp.esolver_type,
448-
PARAM.inp.calculation,
449-
PARAM.inp.init_chg,
450-
PARAM.inp.precision,
451-
istep,
452-
iter,
453-
drho,
454-
PARAM.inp.pw_diag_thr,
455-
diag_ethr,
456-
PARAM.inp.nelec);
457-
}
458-
else if (PARAM.inp.esolver_type == "sdft")
459-
{
460-
diag_ethr = hsolver::set_diagethr_sdft(PARAM.inp.basis_type,
461-
PARAM.inp.esolver_type,
462-
PARAM.inp.calculation,
463-
PARAM.inp.init_chg,
464-
istep,
465-
iter,
466-
drho,
467-
PARAM.inp.pw_diag_thr,
468-
diag_ethr,
469-
PARAM.inp.nbands,
470-
esolver_KS_ne);
471-
}
472-
473435
// 6) initialization of SCF iterations
474436
this->iter_init(istep, iter);
475437

@@ -610,46 +572,13 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)
610572

611573
// 10) finish scf iterations
612574
this->iter_finish(istep, iter);
613-
#ifdef __MPI
614-
double duration = (double)(MPI_Wtime() - iterstart);
615-
#else
616-
double duration
617-
= (std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now() - iterstart))
618-
.count()
619-
/ static_cast<double>(1e6);
620-
#endif
621-
622-
// 11) get mtaGGA related parameters
623-
double dkin = 0.0; // for meta-GGA
624-
if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5)
625-
{
626-
dkin = p_chgmix->get_dkin(pelec->charge, PARAM.inp.nelec);
627-
}
628-
this->pelec->print_etot(this->conv_esolver, iter, drho, dkin, duration, PARAM.inp.printe, diag_ethr);
629-
630-
// 12) Json, need to be moved to somewhere else
631-
#ifdef __RAPIDJSON
632-
// add Json of scf mag
633-
Json::add_output_scf_mag(GlobalC::ucell.magnet.tot_magnetization,
634-
GlobalC::ucell.magnet.abs_magnetization,
635-
this->pelec->f_en.etot * ModuleBase::Ry_to_eV,
636-
this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV,
637-
drho,
638-
duration);
639-
#endif //__RAPIDJSON
640575

641576
// 13) check convergence
642577
if (this->conv_esolver)
643578
{
644579
this->niter = iter;
645580
break;
646581
}
647-
648-
// notice for restart
649-
if (PARAM.inp.mixing_restart > 0 && iter == this->p_chgmix->mixing_restart_step - 1 && iter != PARAM.inp.scf_nmax)
650-
{
651-
std::cout << " SCF restart after this step!" << std::endl;
652-
}
653582
} // end scf iterations
654583
std::cout << " >> Leave SCF iteration.\n * * * * * *" << std::endl;
655584

@@ -662,6 +591,47 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)
662591
return;
663592
};
664593

594+
template <typename T, typename Device>
595+
void ESolver_KS<T, Device>::iter_init(const int istep, const int iter)
596+
{
597+
ModuleIO::write_head(GlobalV::ofs_running, istep, iter, this->basisname);
598+
599+
#ifdef __MPI
600+
iter_time = MPI_Wtime();
601+
#else
602+
iter_time = std::chrono::system_clock::now();
603+
#endif
604+
605+
if (PARAM.inp.esolver_type == "ksdft")
606+
{
607+
diag_ethr = hsolver::set_diagethr_ks(PARAM.inp.basis_type,
608+
PARAM.inp.esolver_type,
609+
PARAM.inp.calculation,
610+
PARAM.inp.init_chg,
611+
PARAM.inp.precision,
612+
istep,
613+
iter,
614+
drho,
615+
PARAM.inp.pw_diag_thr,
616+
diag_ethr,
617+
PARAM.inp.nelec);
618+
}
619+
else if (PARAM.inp.esolver_type == "sdft")
620+
{
621+
diag_ethr = hsolver::set_diagethr_sdft(PARAM.inp.basis_type,
622+
PARAM.inp.esolver_type,
623+
PARAM.inp.calculation,
624+
PARAM.inp.init_chg,
625+
istep,
626+
iter,
627+
drho,
628+
PARAM.inp.pw_diag_thr,
629+
diag_ethr,
630+
PARAM.inp.nbands,
631+
esolver_KS_ne);
632+
}
633+
}
634+
665635
template <typename T, typename Device>
666636
void ESolver_KS<T, Device>::iter_finish(const int istep, int& iter)
667637
{
@@ -675,6 +645,39 @@ void ESolver_KS<T, Device>::iter_finish(const int istep, int& iter)
675645
}
676646
this->pelec->f_en.etot_delta = this->pelec->f_en.etot - this->pelec->f_en.etot_old;
677647
this->pelec->f_en.etot_old = this->pelec->f_en.etot;
648+
649+
#ifdef __MPI
650+
double duration = (double)(MPI_Wtime() - iter_time);
651+
#else
652+
double duration
653+
= (std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now() - iter_time)).count()
654+
/ static_cast<double>(1e6);
655+
#endif
656+
657+
// get mtaGGA related parameters
658+
double dkin = 0.0; // for meta-GGA
659+
if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5)
660+
{
661+
dkin = p_chgmix->get_dkin(pelec->charge, PARAM.inp.nelec);
662+
}
663+
this->pelec->print_etot(this->conv_esolver, iter, drho, dkin, duration, PARAM.inp.printe, diag_ethr);
664+
665+
// Json, need to be moved to somewhere else
666+
#ifdef __RAPIDJSON
667+
// add Json of scf mag
668+
Json::add_output_scf_mag(GlobalC::ucell.magnet.tot_magnetization,
669+
GlobalC::ucell.magnet.abs_magnetization,
670+
this->pelec->f_en.etot * ModuleBase::Ry_to_eV,
671+
this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV,
672+
drho,
673+
duration);
674+
#endif //__RAPIDJSON
675+
676+
// notice for restart
677+
if (PARAM.inp.mixing_restart > 0 && iter == this->p_chgmix->mixing_restart_step - 1 && iter != PARAM.inp.scf_nmax)
678+
{
679+
std::cout << " SCF restart after this step!" << std::endl;
680+
}
678681
}
679682

680683
//! Something to do after SCF iterations when SCF is converged or comes to the max iter step.
@@ -689,13 +692,6 @@ void ESolver_KS<T, Device>::after_scf(const int istep)
689692
{
690693
this->pelec->print_eigenvalue(GlobalV::ofs_running);
691694
}
692-
// #ifdef __RAPIDJSON
693-
// // add Json of efermi energy converge
694-
// Json::add_output_efermi_converge(this->pelec->eferm.ef * ModuleBase::Ry_to_eV, this->conv_esolver);
695-
// // add nkstot,nkstot_ibz to output json
696-
// int Jnkstot = this->pelec->klist->get_nkstot();
697-
// Json::add_nkstot(Jnkstot);
698-
// #endif //__RAPIDJSON
699695
}
700696

701697
//------------------------------------------------------------------------------

source/module_esolver/esolver_ks.h

Lines changed: 40 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -18,70 +18,63 @@ namespace ModuleESolver
1818
template <typename T, typename Device = base_device::DEVICE_CPU>
1919
class ESolver_KS : public ESolver_FP
2020
{
21-
public:
21+
public:
22+
//! Constructor
23+
ESolver_KS();
2224

23-
//! Constructor
24-
ESolver_KS();
25+
//! Deconstructor
26+
virtual ~ESolver_KS();
2527

26-
//! Deconstructor
27-
virtual ~ESolver_KS();
28+
virtual void before_all_runners(const Input_para& inp, UnitCell& cell) override;
2829

29-
double scf_thr; // scf density threshold
30+
virtual void init_after_vc(const Input_para& inp, UnitCell& cell) override; // liuyu add 2023-03-09
3031

31-
double scf_ene_thr; // scf energy threshold
32+
virtual void runner(const int istep, UnitCell& cell) override;
3233

33-
double drho; // the difference between rho_in (before HSolver) and rho_out (After HSolver)
34+
// calculate electron density from a specific Hamiltonian
35+
virtual void hamilt2density(const int istep, const int iter, const double ethr);
3436

35-
int maxniter; // maximum iter steps for scf
37+
// calculate electron states from a specific Hamiltonian
38+
virtual void hamilt2estates(const double ethr) {};
3639

37-
int niter; // iter steps actually used in scf
40+
protected:
41+
//! Something to do before SCF iterations.
42+
virtual void before_scf(const int istep) {};
3843

39-
int out_freq_elec; // frequency for output
44+
//! Something to do before hamilt2density function in each iter loop.
45+
virtual void iter_init(const int istep, const int iter);
4046

41-
virtual void before_all_runners(const Input_para& inp, UnitCell& cell) override;
47+
//! Something to do after hamilt2density function in each iter loop.
48+
virtual void iter_finish(const int istep, int& iter);
4249

43-
virtual void init_after_vc(const Input_para& inp, UnitCell& cell) override; // liuyu add 2023-03-09
50+
//! Something to do after SCF iterations when SCF is converged or comes to the max iter step.
51+
virtual void after_scf(const int istep) override;
4452

45-
virtual void runner(const int istep, UnitCell& cell) override;
53+
//! <Temporary> It should be replaced by a function in Hamilt Class
54+
virtual void update_pot(const int istep, const int iter) {};
4655

47-
// calculate electron density from a specific Hamiltonian
48-
virtual void hamilt2density(const int istep, const int iter, const double ethr);
56+
//! Hamiltonian
57+
hamilt::Hamilt<T, Device>* p_hamilt = nullptr;
4958

50-
// calculate electron states from a specific Hamiltonian
51-
virtual void hamilt2estates(const double ethr){};
59+
ModulePW::PW_Basis_K* pw_wfc = nullptr;
5260

53-
protected:
54-
//! Something to do before SCF iterations.
55-
virtual void before_scf(const int istep) {};
61+
Charge_Mixing* p_chgmix = nullptr;
5662

57-
//! Something to do before hamilt2density function in each iter loop.
58-
virtual void iter_init(const int istep, const int iter) {};
63+
wavefunc wf;
5964

60-
//! Something to do after hamilt2density function in each iter loop.
61-
virtual void iter_finish(const int istep, int& iter);
65+
// wavefunction coefficients
66+
psi::Psi<T>* psi = nullptr;
6267

63-
//! Something to do after SCF iterations when SCF is converged or comes to the max iter step.
64-
virtual void after_scf(const int istep) override;
65-
66-
//! <Temporary> It should be replaced by a function in Hamilt Class
67-
virtual void update_pot(const int istep, const int iter) {};
68-
69-
protected:
70-
//! Hamiltonian
71-
hamilt::Hamilt<T, Device>* p_hamilt = nullptr;
72-
73-
ModulePW::PW_Basis_K* pw_wfc = nullptr;
74-
75-
Charge_Mixing* p_chgmix = nullptr;
76-
77-
wavefunc wf;
78-
79-
// wavefunction coefficients
80-
psi::Psi<T>* psi = nullptr;
81-
82-
protected:
83-
std::string basisname; // PW or LCAO
84-
double esolver_KS_ne = 0.0;
68+
std::string basisname; // PW or LCAO
69+
double esolver_KS_ne = 0.0;
70+
double iter_time; // the start time of scf iteration
71+
double diag_ethr; // the threshold for diagonalization
72+
double scf_thr; // scf density threshold
73+
double scf_ene_thr; // scf energy threshold
74+
double drho; // the difference between rho_in (before HSolver) and rho_out (After HSolver)
75+
int maxniter; // maximum iter steps for scf
76+
int niter; // iter steps actually used in scf
77+
int out_freq_elec; // frequency for output
8578
};
8679
} // end of namespace
8780
#endif

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,9 @@ void ESolver_KS_LCAO<TK, TR>::iter_init(const int istep, const int iter)
519519
{
520520
ModuleBase::TITLE("ESolver_KS_LCAO", "iter_init");
521521

522+
// call iter_init() of ESolver_KS
523+
ESolver_KS<TK>::iter_init(istep, iter);
524+
522525
if (iter == 1)
523526
{
524527
this->p_chgmix->init_mixing(); // init mixing

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,9 @@ void ESolver_KS_PW<T, Device>::before_scf(const int istep)
305305
template <typename T, typename Device>
306306
void ESolver_KS_PW<T, Device>::iter_init(const int istep, const int iter)
307307
{
308+
// call iter_init() of ESolver_KS
309+
ESolver_KS<T>::iter_init(istep, iter);
310+
308311
if (iter == 1)
309312
{
310313
this->p_chgmix->init_mixing();

0 commit comments

Comments
 (0)