Skip to content

Commit 98c5f03

Browse files
committed
Refactor: add two-level scf for EXX in ESolver_KS_LCAO, fixed some bugs in printing
1 parent e507003 commit 98c5f03

20 files changed

+94
-1304
lines changed

source/Makefile.Objects

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,7 @@ ORB_table_alpha.o\
136136
ORB_gen_tables.o\
137137
local_orbital_wfc.o\
138138
local_orbital_charge.o\
139-
ELEC_cbands_k.o\
140-
ELEC_cbands_gamma.o\
141139
ELEC_evolve.o\
142-
ELEC_scf.o\
143-
ELEC_nscf.o\
144140
LOOP_cell.o\
145141
LOOP_ions.o\
146142
run_md_lcao.o\

source/module_elecstate/elecstate_pw.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
#include "module_base/constants.h"
44
#include "src_parallel/parallel_reduce.h"
55
#include "src_pw/global.h"
6+
#include "module_base/timer.h"
67

78
namespace elecstate
89
{
910

1011
void ElecStatePW::psiToRho(const psi::Psi<std::complex<double>>& psi)
1112
{
13+
ModuleBase::TITLE("ElecStatePW", "psiToRho");
14+
ModuleBase::timer::tick("ElecStatePW", "psiToRho");
1215
this->calculate_weights();
1316

1417
this->calEBand();
@@ -28,6 +31,7 @@ void ElecStatePW::psiToRho(const psi::Psi<std::complex<double>>& psi)
2831
this->updateRhoK(psi);
2932
}
3033
this->parallelK();
34+
ModuleBase::timer::tick("ElecStatePW", "psiToRho");
3135
}
3236

3337
void ElecStatePW::updateRhoK(const psi::Psi<std::complex<double>>& psi)

source/module_esolver/esolver_ks.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,7 @@ namespace ModuleESolver
151151
void ESolver_KS::Run(const int istep, UnitCell_pseudo& ucell)
152152
{
153153
if (!(GlobalV::CALCULATION == "scf" || GlobalV::CALCULATION == "md"
154-
|| GlobalV::CALCULATION == "relax" || GlobalV::CALCULATION == "cell-relax" || GlobalV::CALCULATION.substr(0,3) == "sto")
155-
#ifdef __MPI
156-
|| Exx_Global::Hybrid_Type::No != GlobalC::exx_global.info.hybrid_type
157-
#endif
158-
)
154+
|| GlobalV::CALCULATION == "relax" || GlobalV::CALCULATION == "cell-relax" || GlobalV::CALCULATION.substr(0,3) == "sto"))
159155
{
160156
this->othercalculation(istep);
161157
}
@@ -172,8 +168,11 @@ namespace ModuleESolver
172168
for (int iter = 1; iter <= this->maxniter; ++iter)
173169
{
174170
writehead(GlobalV::ofs_running, istep, iter);
175-
clock_t iterstart, iterend;
176-
iterstart = std::clock();
171+
#ifdef __MPI
172+
auto iterstart = MPI_Wtime();
173+
#else
174+
auto iterstart = std::chrono::system_clock::now();
175+
#endif
177176
set_ethr(istep, iter);
178177
eachiterinit(istep, iter);
179178
this->hamilt2density(istep, iter, this->diag_ethr);
@@ -227,13 +226,17 @@ namespace ModuleESolver
227226
// this->phamilt->update(conv_elec);
228227
updatepot(istep, iter);
229228
eachiterfinish(iter);
230-
iterend = std::clock();
231-
double duration = double(iterend - iterstart) / CLOCKS_PER_SEC;
229+
#ifdef __MPI
230+
double duration = (double)(MPI_Wtime() - iterstart);
231+
#else
232+
double duration = (std::chrono::system_clock::now() - iterstart).count() / CLOCKS_PER_SEC;
233+
#endif
232234
printiter(iter, drho, duration, diag_ethr);
233235
if (this->conv_elec)
234236
{
237+
int stop = this->do_after_converge(iter);
235238
this->niter = iter;
236-
break;
239+
if(stop) break;
237240
}
238241
}
239242
afterscf();

source/module_esolver/esolver_ks.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ namespace ModuleESolver
4949
virtual void afterscf() {};
5050
// <Temporary> It should be replaced by a function in Hamilt Class
5151
virtual void updatepot(const int istep, const int iter) {};
52+
// choose strategy when charge density convergence achieved
53+
virtual bool do_after_converge(int& iter){this->niter = iter; return true;}
5254

5355

5456
//TOOLS:

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
#include "input_update.h"
1010
#include "src_pw/occupy.h"
1111
#include "src_lcao/ELEC_evolve.h"
12-
#include "src_lcao/ELEC_cbands_gamma.h"
13-
#include "src_lcao/ELEC_cbands_k.h"
1412
#include "src_pw/symmetry_rho.h"
1513
#include "src_io/chi0_hilbert.h"
1614
#include "src_pw/threshold_elec.h"
@@ -918,4 +916,30 @@ namespace ModuleESolver
918916

919917
}
920918

919+
bool ESolver_KS_LCAO::do_after_converge(int& iter)
920+
{
921+
#ifdef __MPI
922+
if(Exx_Global::Hybrid_Type::No != GlobalC::exx_global.info.hybrid_type)
923+
{
924+
if (!GlobalC::exx_global.info.separate_loop)
925+
{
926+
GlobalC::exx_global.info.hybrid_step = 1;
927+
}
928+
//exx converged or get max exx steps
929+
if(this->two_level_step == GlobalC::exx_global.info.hybrid_step || (iter==1 && this->two_level_step!=0))
930+
{
931+
return true;
932+
}
933+
//update exx and redo scf
934+
XC_Functional::set_xc_type(GlobalC::ucell.atoms[0].xc_func);
935+
GlobalC::exx_lcao.cal_exx_elec(this->LOC, this->LOWF.wfc_k_grid);
936+
937+
iter = 0;
938+
this->two_level_step++;
939+
return false;
940+
}
941+
#endif // __MPI
942+
return true;
943+
}
944+
921945
}

source/module_esolver/esolver_ks_lcao.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ namespace ModuleESolver
3232
virtual void updatepot(const int istep, const int iter) override;
3333
virtual void eachiterfinish(const int iter) override;
3434
virtual void afterscf() override;
35+
virtual bool do_after_converge(int& iter) override;
36+
int two_level_step = 0;
3537

3638
virtual void othercalculation(const int istep)override;
3739
ORB_control orb_con; //Basis_LCAO

source/module_esolver/esolver_ks_lcao_elec.cpp

Lines changed: 24 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@
1010
#include "../module_neighbor/sltk_atom_arrange.h"
1111
#include "../src_io/istate_charge.h"
1212
#include "../src_io/istate_envelope.h"
13-
#include "src_lcao/ELEC_scf.h"
14-
#include "src_lcao/ELEC_nscf.h"
15-
#include "src_lcao/ELEC_cbands_gamma.h"
16-
#include "src_lcao/ELEC_cbands_k.h"
1713
#include "src_lcao/ELEC_evolve.h"
1814
//
1915
#include "../src_ri/exx_abfs.h"
@@ -205,7 +201,27 @@ namespace ModuleESolver
205201
}
206202
}
207203
#endif
204+
205+
//Peize Lin add 2016-12-03
206+
#ifdef __MPI
207+
if(Exx_Global::Hybrid_Type::No != GlobalC::exx_global.info.hybrid_type)
208+
{
209+
if (Exx_Global::Hybrid_Type::HF == GlobalC::exx_lcao.info.hybrid_type
210+
|| Exx_Global::Hybrid_Type::PBE0 == GlobalC::exx_lcao.info.hybrid_type
211+
|| Exx_Global::Hybrid_Type::HSE == GlobalC::exx_lcao.info.hybrid_type)
212+
{
213+
GlobalC::exx_lcao.cal_exx_ions(*this->LOWF.ParaV);
214+
}
215+
if (Exx_Global::Hybrid_Type::Generate_Matrix == GlobalC::exx_global.info.hybrid_type)
216+
{
217+
Exx_Opt_Orb exx_opt_orb;
218+
exx_opt_orb.generate_matrix();
219+
ModuleBase::timer::tick("LOOP_ions", "opt_ions");
220+
return;
221+
}
222+
}
208223
}
224+
#endif
209225

210226
void ESolver_KS_LCAO::beforescf(int istep)
211227
{
@@ -226,6 +242,9 @@ namespace ModuleESolver
226242

227243
phami->non_first_scf = istep;
228244

245+
// for exx two_level scf
246+
this->two_level_step = 0;
247+
229248
ModuleBase::timer::tick("ESolver_KS_LCAO", "beforescf");
230249
return;
231250
}
@@ -236,53 +255,7 @@ namespace ModuleESolver
236255
ModuleBase::timer::tick("ESolver_KS_LCAO", "othercalculation");
237256
this->beforesolver(istep);
238257
// self consistent calculations for electronic ground state
239-
if (GlobalV::CALCULATION == "scf" || GlobalV::CALCULATION == "md"
240-
|| GlobalV::CALCULATION == "relax" || GlobalV::CALCULATION == "cell-relax") //pengfei 2014-10-13
241-
{
242-
#ifdef __MPI
243-
//Peize Lin add 2016-12-03
244-
if (Exx_Global::Hybrid_Type::HF == GlobalC::exx_lcao.info.hybrid_type
245-
|| Exx_Global::Hybrid_Type::PBE0 == GlobalC::exx_lcao.info.hybrid_type
246-
|| Exx_Global::Hybrid_Type::HSE == GlobalC::exx_lcao.info.hybrid_type)
247-
{
248-
GlobalC::exx_lcao.cal_exx_ions(*this->LOWF.ParaV);
249-
}
250-
if (Exx_Global::Hybrid_Type::Generate_Matrix == GlobalC::exx_global.info.hybrid_type)
251-
{
252-
Exx_Opt_Orb exx_opt_orb;
253-
exx_opt_orb.generate_matrix();
254-
}
255-
else // Peize Lin add 2016-12-03
256-
{
257-
#endif // __MPI
258-
ELEC_scf es;
259-
es.scf(istep, this->LOC, this->LOWF, this->UHM);
260-
#ifdef __MPI
261-
if (GlobalC::exx_global.info.separate_loop)
262-
{
263-
for (size_t hybrid_step = 0; hybrid_step != GlobalC::exx_global.info.hybrid_step; ++hybrid_step)
264-
{
265-
XC_Functional::set_xc_type(GlobalC::ucell.atoms[0].xc_func);
266-
GlobalC::exx_lcao.cal_exx_elec(this->LOC, this->LOWF.wfc_k_grid);
267-
268-
ELEC_scf es;
269-
es.scf(istep, this->LOC, this->LOWF, this->UHM);
270-
if (ELEC_scf::iter == 1) // exx converge
271-
{
272-
break;
273-
}
274-
}
275-
}
276-
else
277-
{
278-
XC_Functional::set_xc_type(GlobalC::ucell.atoms[0].xc_func);
279-
ELEC_scf es;
280-
es.scf(istep, this->LOC, this->LOWF, this->UHM);
281-
}
282-
}
283-
#endif // __MPI
284-
}
285-
else if (GlobalV::CALCULATION == "nscf")
258+
if (GlobalV::CALCULATION == "nscf")
286259
{
287260
this->nscf();
288261
}

source/module_hamilt/hamilt_lcao.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ template <> void HamiltLCAO<double>::matrix(MatrixBlock<double> &hk_in, MatrixBl
5555
template <> void HamiltLCAO<double>::updateHk(const int ik)
5656
{
5757
ModuleBase::TITLE("HamiltLCAO", "updateHk");
58+
ModuleBase::timer::tick("HamiltLCAO", "updateHk");
5859
//this->hk_fixed_mock(ik);
5960
//this->hk_update_mock(ik);
6061

@@ -121,6 +122,7 @@ template <> void HamiltLCAO<double>::updateHk(const int ik)
121122
BlasConnector::copy(this->LM->Sloc.size(), this->LM->Sloc.data(), inc, this->smatrix_k, inc);
122123
hsolver::DiagoElpa::is_already_decomposed = false;
123124
}
125+
ModuleBase::timer::tick("HamiltLCAO", "updateHk");
124126
return;
125127
}
126128

@@ -136,7 +138,7 @@ void HamiltLCAO<double>::constructHamilt()
136138
template <> void HamiltLCAO<std::complex<double>>::updateHk(const int ik)
137139
{
138140
ModuleBase::TITLE("HamiltLCAO", "updateHk");
139-
ModuleBase::timer::tick("Efficience", "each_k");
141+
ModuleBase::timer::tick("HamiltLCAO", "each_k");
140142
//-----------------------------------------
141143
//(1) prepare data for this k point.
142144
// copy the local potential from array.
@@ -200,8 +202,8 @@ template <> void HamiltLCAO<std::complex<double>>::updateHk(const int ik)
200202
//--------------------------------------------
201203

202204
// with k points
203-
ModuleBase::timer::tick("Efficience", "each_k");
204-
ModuleBase::timer::tick("Efficience", "H_k");
205+
ModuleBase::timer::tick("HamiltLCAO", "each_k");
206+
ModuleBase::timer::tick("HamiltLCAO", "H_k");
205207
this->uhm->calculate_Hk(ik);
206208

207209
// Effective potential of DFT+U is added to total Hamiltonian here; Quxin adds on 20201029

source/module_hsolver/hsolver_lcao.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "diago_blas.h"
44
#include "diago_elpa.h"
55
#include "src_io/write_HS.h"
6+
#include "module_base/timer.h"
67

78
namespace hsolver
89
{
@@ -11,6 +12,7 @@ template <typename T>
1112
void HSolverLCAO::solveTemplate(hamilt::Hamilt* pHamilt, psi::Psi<T>& psi, elecstate::ElecState* pes, const std::string method_in, const bool skip_charge)
1213
{
1314
ModuleBase::TITLE("HSolverLCAO", "solve");
15+
ModuleBase::timer::tick("HSolverLCAO", "solve");
1416
// select the method of diagonalization
1517
this->method = method_in;
1618
if (this->method == "genelpa")
@@ -78,11 +80,16 @@ void HSolverLCAO::solveTemplate(hamilt::Hamilt* pHamilt, psi::Psi<T>& psi, elecs
7880
}
7981

8082
//used in nscf calculation
81-
if(skip_charge) return;
83+
if(skip_charge)
84+
{
85+
ModuleBase::timer::tick("HSolverLCAO", "solve");
86+
return;
87+
}
8288

8389
//calculate charge by psi
8490
//called in scf calculation
8591
pes->psiToRho(psi);
92+
ModuleBase::timer::tick("HSolverLCAO", "solve");
8693
}
8794

8895
int HSolverLCAO::out_mat_hs = 0;

source/module_hsolver/hsolver_pw.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "diago_cg.h"
44
#include "diago_david.h"
55
#include "module_base/tool_quit.h"
6+
#include "module_base/timer.h"
67
#include "module_elecstate/elecstate_pw.h"
78
#include "src_pw/global.h"
89

@@ -22,6 +23,8 @@ void HSolverPW::update()
2223

2324
void HSolverPW::solve(hamilt::Hamilt* pHamilt, psi::Psi<std::complex<double>>& psi, elecstate::ElecState* pes, const std::string method_in, const bool skip_charge)
2425
{
26+
ModuleBase::TITLE("HSolverPW", "solve");
27+
ModuleBase::timer::tick("HSolverPW", "solve");
2528
// prepare for the precondition of diagonalization
2629
this->precondition.resize(psi.get_nbasis());
2730

@@ -92,9 +95,14 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, psi::Psi<std::complex<double>>& p
9295
pdiagh = nullptr;
9396
}
9497

95-
if(skip_charge) return;
98+
if(skip_charge)
99+
{
100+
ModuleBase::timer::tick("HSolverPW", "solve");
101+
return;
102+
}
96103
pes->psiToRho(psi);
97104

105+
ModuleBase::timer::tick("HSolverPW", "solve");
98106
return;
99107
}
100108

0 commit comments

Comments
 (0)