Skip to content

Commit 8a09b8a

Browse files
mohanchenabacus_fixer
andauthored
Try removing template dependency of ESolver (#7014)
* small format changes * refactor(esolver): extract charge density symmetrization to Symmetry_rho::symmetrize_rho - Add static method symmetrize_rho() in Symmetry_rho class - Replace 7 duplicate code blocks with single function call - Simplify code from 35 lines to 7 lines (80% reduction) - Improve code readability and maintainability Modified files: - source_estate/module_charge/symmetry_rho.h: add static method declaration - source_estate/module_charge/symmetry_rho.cpp: implement static method - source_esolver/esolver_ks_lcao.cpp: 2 calls updated - source_esolver/esolver_ks_pw.cpp: 1 call updated - source_esolver/esolver_ks_lcao_tddft.cpp: 1 call updated - source_esolver/esolver_ks_lcaopw.cpp: 1 call updated - source_esolver/esolver_of.cpp: 1 call updated - source_esolver/esolver_sdft_pw.cpp: 1 call updated This refactoring follows the ESolver cleanup principle: keep ESolver focused on high-level workflow control. * refactor(esolver): extract DeltaSpin lambda loop to deltaspin_lcao module - Create new files deltaspin_lcao.h/cpp in module_deltaspin - Extract DeltaSpin lambda loop logic from ESolver_KS_LCAO - Simplify code from 18 lines to 1 line in hamilt2rho_single - Separate LCAO and PW implementations for DeltaSpin Modified files: - source_esolver/esolver_ks_lcao.cpp: replace inline code with function call - source_lcao/module_deltaspin/CMakeLists.txt: add new source file New files: - source_lcao/module_deltaspin/deltaspin_lcao.h: function declaration - source_lcao/module_deltaspin/deltaspin_lcao.cpp: function implementation This refactoring follows the ESolver cleanup principle: keep ESolver focused on high-level workflow control. * refactor(esolver): complete DeltaSpin refactoring in LCAO - Add init_deltaspin_lcao() function for DeltaSpin initialization - Add cal_mi_lcao_wrapper() function for magnetic moment calculation - Refactor all DeltaSpin-related code in esolver_ks_lcao.cpp - Simplify code from 29 lines to 3 lines (90% reduction) Modified files: - source_esolver/esolver_ks_lcao.cpp: replace 3 code blocks with function calls - source_lcao/module_deltaspin/deltaspin_lcao.h: add 2 new function declarations - source_lcao/module_deltaspin/deltaspin_lcao.cpp: implement 2 new functions This completes the DeltaSpin refactoring for LCAO method: 1. init_deltaspin_lcao() - initialize DeltaSpin calculation 2. cal_mi_lcao_wrapper() - calculate magnetic moments 3. run_deltaspin_lambda_loop_lcao() - run lambda loop optimization All functions follow the ESolver cleanup principle: keep ESolver focused on high-level workflow control. * refactor(esolver): extract DFT+U code to dftu_lcao module - Create new files dftu_lcao.h/cpp in source_lcao directory - Add init_dftu_lcao() function for DFT+U initialization - Add finish_dftu_lcao() function for DFT+U finalization - Simplify code from 32 lines to 2 lines in esolver_ks_lcao.cpp - Remove conditional checks from ESolver, move them to functions Modified files: - source_esolver/esolver_ks_lcao.cpp: replace 2 code blocks with function calls - source_lcao/CMakeLists.txt: add new source file New files: - source_lcao/dftu_lcao.h: function declarations - source_lcao/dftu_lcao.cpp: function implementations This refactoring prepares for unifying old and new DFT+U implementations: - Old DFT+U: source_lcao/module_dftu/ - New DFT+U: source_lcao/module_operator_lcao/op_dftu_lcao.cpp All functions follow ESolver cleanup principle: keep ESolver focused on high-level workflow control. * refactor(esolver): extract diagonalization parameters setup to hsolver module - Create new files diago_params.h/cpp in source_hsolver directory - Add setup_diago_params_pw() function for PW diagonalization parameters - Simplify code from 11 lines to 1 line in esolver_ks_pw.cpp - Encapsulate diagonalization parameter setup logic Modified files: - source_esolver/esolver_ks_pw.cpp: replace inline code with function call - source_hsolver/CMakeLists.txt: add new source file New files: - source_hsolver/diago_params.h: function declaration - source_hsolver/diago_params.cpp: function implementation This refactoring follows ESolver cleanup principle: keep ESolver focused on high-level workflow control. * fix(deltaspin): add sc_mag_switch check in cal_mi_lcao_wrapper - Add Input_para parameter to cal_mi_lcao_wrapper function - Add sc_mag_switch check to avoid calling cal_mi_lcao when DeltaSpin is disabled - Fix 'atomCounts is not set' error in non-DeltaSpin calculations - Update function call in esolver_ks_lcao.cpp This fix resolves the CI/CD failure caused by commit 2a520e3. The root cause was that cal_mi_lcao_wrapper was called without checking sc_mag_switch, leading to uninitialized atomCounts error. Modified files: - source_esolver/esolver_ks_lcao.cpp: update function call - source_lcao/module_deltaspin/deltaspin_lcao.h: add parameter - source_lcao/module_deltaspin/deltaspin_lcao.cpp: add check This follows the refactoring principle: preserve original condition checks when extracting code to wrapper functions. * fix(deltaspin): add #ifdef __LCAO for conditional compilation - Add #ifdef __LCAO conditional compilation in init_deltaspin_lcao and cal_mi_lcao_wrapper - Fix parameter order in init_sc call for LCAO and non-LCAO builds - Fix undefined reference to cal_mi_lcao in non-LCAO build This fix resolves CI/CD compilation errors in both build_5pt (with __LCAO) and build_1p (without __LCAO) environments. The The root cause was 1. init_sc has different parameter order in LCAO vs non-LCAO builds - LCAO: psi, dm, pelec - non-LCAO: psi, pelec 2. cal_mi_lcao is only defined in LCAO build Modified files: - source_hsolver/diago_params.h: add setup_diago_params_sdft declaration - source_lcao/module_deltaspin/deltaspin_lcao.cpp: add conditional compilation This follows the refactoring principle: handle conditional compilation properly when code has different implementations for different build configurations. * refactor(esolver): extract SDFT diagonalization parameters setup - Add setup_diago_params_sdft() function for SDFT diagonalization parameters - Simplify code from 11 lines to 1 line in esolver_sdft_pw.cpp - Encapsulate diagonalization parameter setup logic for SDFT Modified files: - source_esolver/esolver_sdft_pw.cpp: replace inline code with function call - source_hsolver/diago_params.cpp: add setup_diago_params_sdft implementation This refactoring follows ESolver cleanup principle: keep ESolver focused on high-level workflow control. Note: SDFT has different parameter setup logic compared to PW: - Different need_subspace condition - No SCF_ITER setting - Always set PW_DIAG_NMAX (no nscf check) * refactor(hamilt): introduce HamiltBase non-template base class - Create HamiltBase as a non-template base class for Hamilt<T, Device> - Modify Hamilt<T, Device> to inherit from HamiltBase - Change ESolver_KS::p_hamilt type from Hamilt<T, Device>* to HamiltBase* - Add static_cast where needed when passing p_hamilt to functions expecting Hamilt<T, Device>* This is the first step towards removing template parameters from ESolver. Modified files: - source/source_esolver/esolver_ks.h - source/source_esolver/esolver_ks_lcaopw.cpp - source/source_esolver/esolver_ks_pw.cpp - source/source_esolver/esolver_sdft_pw.cpp - source/source_hamilt/hamilt.h New files: - source/source_hamilt/hamilt_base.h * refactor(esolver): add static_cast for p_hamilt in esolver files - Add static_cast<hamilt::Hamilt<T>*> when passing p_hamilt to functions expecting Hamilt<T, Device>* type - Split long cast statements into multiple lines for better readability - Files modified: - esolver_ks_pw.cpp: setup_pot, stp.init calls - esolver_ks_lcao.cpp: init_chg_hr, hsolver_lcao_obj.solve calls - esolver_ks_lcao_tddft.cpp: solve_psi, cal_edm_tddft, matrix calls - esolver_gets.cpp: ops access, output_SR call This follows the HamiltBase refactoring strategy where p_hamilt is stored as HamiltBase* and cast to Hamilt<T, Device>* when needed. --------- Co-authored-by: abacus_fixer <mohanchen@pku.eud.cn>
1 parent 58cc834 commit 8a09b8a

9 files changed

Lines changed: 87 additions & 23 deletions

File tree

source/source_esolver/esolver_gets.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,9 @@ void ESolver_GetS::runner(UnitCell& ucell, const int istep)
108108
this->kv,
109109
*(two_center_bundle_.overlap_orb),
110110
orb_.cutoffs());
111-
dynamic_cast<hamilt::OperatorLCAO<std::complex<double>, std::complex<double>>*>(this->p_hamilt->ops)
112-
->contributeHR();
111+
auto* hamilt_ptr = static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt);
112+
auto* ops_ptr = dynamic_cast<hamilt::OperatorLCAO<std::complex<double>, std::complex<double>>*>(hamilt_ptr->ops);
113+
ops_ptr->contributeHR();
113114
}
114115
else
115116
{
@@ -119,13 +120,16 @@ void ESolver_GetS::runner(UnitCell& ucell, const int istep)
119120
this->kv,
120121
*(two_center_bundle_.overlap_orb),
121122
orb_.cutoffs());
122-
dynamic_cast<hamilt::OperatorLCAO<std::complex<double>, double>*>(this->p_hamilt->ops)->contributeHR();
123+
auto* hamilt_ptr = static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt);
124+
auto* ops_ptr = dynamic_cast<hamilt::OperatorLCAO<std::complex<double>, double>*>(hamilt_ptr->ops);
125+
ops_ptr->contributeHR();
123126
}
124127
}
125128

126129
const std::string fn = PARAM.globalv.global_out_dir + "sr_nao.csr";
127130

128-
ModuleIO::output_SR(pv, gd, this->p_hamilt, fn);
131+
auto* hamilt_ptr = static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt);
132+
ModuleIO::output_SR(pv, gd, hamilt_ptr, fn);
129133

130134
if (PARAM.inp.out_mat_r)
131135
{

source/source_esolver/esolver_ks.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "source_estate/module_charge/charge_mixing.h" // use charge mixing
88
#include "source_psi/psi.h" // use electronic wave functions
99
#include "source_hamilt/hamilt.h" // use Hamiltonian
10+
#include "source_hamilt/hamilt_base.h" // use Hamiltonian base class
1011
#include "source_lcao/module_dftu/dftu.h" // mohan add 20251107
1112

1213
namespace ModuleESolver
@@ -47,8 +48,8 @@ class ESolver_KS : public ESolver_FP
4748
//! Something to do after SCF iterations when SCF is converged or comes to the max iter step.
4849
virtual void after_scf(UnitCell& ucell, const int istep, const bool conv_esolver) override;
4950

50-
//! Hamiltonian
51-
hamilt::Hamilt<T, Device>* p_hamilt = nullptr;
51+
//! Hamiltonian (base class pointer, actual type determined at runtime)
52+
hamilt::HamiltBase* p_hamilt = nullptr;
5253

5354
//! PW for wave functions, only used in KSDFT, not in OFDFT
5455
ModulePW::PW_Basis_K* pw_wfc = nullptr;

source/source_esolver/esolver_ks_lcao.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ void ESolver_KS_LCAO<TK, TR>::before_scf(UnitCell& ucell, const int istep)
179179
{
180180
//! 13.1.2) init charge density from Hamiltonian matrix file
181181
LCAO_domain::init_chg_hr<TK, TR>(PARAM.globalv.global_readin_dir, PARAM.inp.nspin,
182-
this->p_hamilt, ucell, &(this->pv), this->psi[0], this->pelec, *this->dmat.dm,
182+
static_cast<hamilt::Hamilt<TK>*>(this->p_hamilt), ucell, &(this->pv), this->psi[0], this->pelec, *this->dmat.dm,
183183
this->chr, PARAM.inp.ks_solver);
184184
}
185185
}
@@ -382,7 +382,7 @@ void ESolver_KS_LCAO<TK, TR>::hamilt2rho_single(UnitCell& ucell, int istep, int
382382
if (!skip_solve)
383383
{
384384
hsolver::HSolverLCAO<TK> hsolver_lcao_obj(&(this->pv), PARAM.inp.ks_solver);
385-
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, *this->dmat.dm,
385+
hsolver_lcao_obj.solve(static_cast<hamilt::Hamilt<TK>*>(this->p_hamilt), this->psi[0], this->pelec, *this->dmat.dm,
386386
this->chr, PARAM.inp.nspin, skip_charge);
387387
}
388388

source/source_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::hamilt2rho_single(UnitCell& ucell,
235235
PARAM.inp.nbands,
236236
PARAM.globalv.nlocal,
237237
this->kv.get_nks(),
238-
this->p_hamilt,
238+
static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt),
239239
this->pv,
240240
this->psi,
241241
this->psi_laststep,
@@ -255,7 +255,7 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::hamilt2rho_single(UnitCell& ucell,
255255
PARAM.inp.nbands,
256256
PARAM.globalv.nlocal,
257257
this->kv.get_nks(),
258-
this->p_hamilt,
258+
static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt),
259259
this->pv,
260260
this->psi,
261261
this->psi_laststep,
@@ -277,7 +277,7 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::hamilt2rho_single(UnitCell& ucell,
277277
{
278278
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;
279279
hsolver::HSolverLCAO<std::complex<double>> hsolver_lcao_obj(&this->pv, PARAM.inp.ks_solver);
280-
hsolver_lcao_obj.solve(this->p_hamilt,
280+
hsolver_lcao_obj.solve(static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt),
281281
this->psi[0],
282282
this->pelec,
283283
*this->dmat.dm,
@@ -342,11 +342,11 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::iter_finish(UnitCell& ucell,
342342
{
343343
if (use_tensor && use_lapack)
344344
{
345-
elecstate::cal_edm_tddft_tensor_lapack<Device>(this->pv, this->dmat, this->kv, this->p_hamilt);
345+
elecstate::cal_edm_tddft_tensor_lapack<Device>(this->pv, this->dmat, this->kv, static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt));
346346
}
347347
else
348348
{
349-
elecstate::cal_edm_tddft(this->pv, this->dmat, this->kv, this->p_hamilt);
349+
elecstate::cal_edm_tddft(this->pv, this->dmat, this->kv, static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt));
350350
}
351351
}
352352
}
@@ -416,7 +416,7 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::store_h_s_psi(UnitCell& ucell,
416416
this->p_hamilt->updateHk(ik);
417417
hamilt::MatrixBlock<std::complex<double>> h_mat;
418418
hamilt::MatrixBlock<std::complex<double>> s_mat;
419-
this->p_hamilt->matrix(h_mat, s_mat);
419+
static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt)->matrix(h_mat, s_mat);
420420

421421
// Store H and S matrices to Hk_laststep and Sk_laststep
422422
if (use_tensor && use_lapack)

source/source_esolver/esolver_ks_lcaopw.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ namespace ModuleESolver
146146
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;
147147

148148
hsolver::HSolverLIP<T> hsolver_lip_obj(this->pw_wfc);
149-
hsolver_lip_obj.solve(this->p_hamilt, this->stp.psi_t[0], this->pelec,
149+
hsolver_lip_obj.solve(static_cast<hamilt::Hamilt<T>*>(this->p_hamilt), this->stp.psi_t[0], this->pelec,
150150
*this->psi_local, skip_charge,ucell.tpiba,ucell.nat);
151151

152152
// add exx

source/source_esolver/esolver_ks_pw.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,10 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
128128
// init DFT+U is done in "before_all_runners" in LCAO basis. This should be refactored, mohan note 2025-11-06
129129
pw::setup_pot(istep, ucell, this->kv, this->sf, this->pelec, this->Pgrid,
130130
this->chr, this->locpp, this->ppcell, this->dftu, this->vsep_cell,
131-
this->stp.psi_t, this->p_hamilt, this->pw_wfc, this->pw_rhod, PARAM.inp);
131+
this->stp.psi_t, static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt), this->pw_wfc, this->pw_rhod, PARAM.inp);
132132

133133
// setup psi (electronic wave functions)
134-
this->stp.init(this->p_hamilt);
134+
this->stp.init(static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt));
135135

136136
//! Setup EXX helper for Hamiltonian and psi
137137
exx_helper.before_scf(this->p_hamilt, this->stp.psi_t, PARAM.inp);
@@ -188,7 +188,7 @@ void ESolver_KS_PW<T, Device>::hamilt2rho_single(UnitCell& ucell, const int iste
188188
hsolver::DiagoIterAssist<T, Device>::need_subspace,
189189
PARAM.inp.use_k_continuity);
190190

191-
hsolver_pw_obj.solve(this->p_hamilt, this->stp.psi_t[0], this->pelec, this->pelec->ekb.c,
191+
hsolver_pw_obj.solve(static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt), this->stp.psi_t[0], this->pelec, this->pelec->ekb.c,
192192
GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL, skip_charge, ucell.tpiba, ucell.nat);
193193
}
194194

source/source_esolver/esolver_sdft_pw.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ void ESolver_SDFT_PW<T, Device>::hamilt2rho_single(UnitCell& ucell, int istep, i
167167
hsolver::DiagoIterAssist<T, Device>::need_subspace);
168168

169169
hsolver_pw_sdft_obj.solve(ucell,
170-
this->p_hamilt,
170+
static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt),
171171
this->stp.psi_t[0],
172172
this->stp.psi_cpu[0],
173173
this->pelec,
@@ -291,7 +291,7 @@ void ESolver_SDFT_PW<T, Device>::after_all_runners(UnitCell& ucell)
291291
this->pw_wfc,
292292
this->stp.psi_t,
293293
&this->ppcell,
294-
this->p_hamilt,
294+
static_cast<hamilt::Hamilt<std::complex<double>, Device>*>(this->p_hamilt),
295295
this->stoche,
296296
&stowf);
297297
sto_elecond.decide_nche(PARAM.inp.cond_dt, 1e-8, this->nche_sto, PARAM.inp.emin_sto, PARAM.inp.emax_sto);

source/source_hamilt/hamilt.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,28 @@
77
#include "matrixblock.h"
88
#include "source_psi/psi.h"
99
#include "operator.h"
10+
#include "hamilt_base.h"
1011

1112
namespace hamilt
1213
{
1314

1415
template <typename T, typename Device = base_device::DEVICE_CPU>
15-
class Hamilt
16+
class Hamilt : public HamiltBase
1617
{
1718
public:
1819
virtual ~Hamilt(){};
1920

2021
/// for target K point, update consequence of hPsi() and matrix()
21-
virtual void updateHk(const int ik){return;}
22+
void updateHk(const int ik) override { return; }
2223

2324
/// refresh status of Hamiltonian, for example, refresh H(R) and S(R) in LCAO case
24-
virtual void refresh(bool yes = true){return;}
25+
void refresh(bool yes = true) override { return; }
26+
27+
/// get the class name
28+
std::string get_classname() const override { return classname; }
29+
30+
/// get the operator chain
31+
void* get_ops() override { return static_cast<void*>(ops); }
2532

2633
/// core function: for solving eigenvalues of Hamiltonian with iterative method
2734
virtual void hPsi(

source/source_hamilt/hamilt_base.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#ifndef HAMILT_BASE_H
2+
#define HAMILT_BASE_H
3+
4+
#include <string>
5+
6+
namespace hamilt
7+
{
8+
9+
/**
10+
* @brief Base class for Hamiltonian
11+
*
12+
* This is a non-template base class for Hamilt<T, Device>.
13+
* It provides a common interface for all Hamiltonian types,
14+
* allowing ESolver to manage Hamiltonian without template parameters.
15+
*/
16+
class HamiltBase
17+
{
18+
public:
19+
virtual ~HamiltBase() {}
20+
21+
/**
22+
* @brief Update Hamiltonian for a specific k-point
23+
*
24+
* @param ik k-point index
25+
*/
26+
virtual void updateHk(const int ik) { return; }
27+
28+
/**
29+
* @brief Refresh the status of Hamiltonian
30+
*
31+
* @param yes whether to refresh
32+
*/
33+
virtual void refresh(bool yes = true) { return; }
34+
35+
/**
36+
* @brief Get the class name
37+
*
38+
* @return class name
39+
*/
40+
virtual std::string get_classname() const { return "none"; }
41+
42+
/**
43+
* @brief Get the operator chain (as void* to avoid template)
44+
*
45+
* @return pointer to operator chain
46+
*/
47+
virtual void* get_ops() { return nullptr; }
48+
};
49+
50+
} // namespace hamilt
51+
52+
#endif // HAMILT_BASE_H

0 commit comments

Comments
 (0)