Commit e94d753

mohanchen and abacus_fixer authored
Remove template of ESolver_KS_pw (#7024)
* small format changes

* refactor(esolver): extract charge density symmetrization to Symmetry_rho::symmetrize_rho

- Add static method symmetrize_rho() in Symmetry_rho class
- Replace 7 duplicate code blocks with single function call
- Simplify code from 35 lines to 7 lines (80% reduction)
- Improve code readability and maintainability

Modified files:
- source_estate/module_charge/symmetry_rho.h: add static method declaration
- source_estate/module_charge/symmetry_rho.cpp: implement static method
- source_esolver/esolver_ks_lcao.cpp: 2 calls updated
- source_esolver/esolver_ks_pw.cpp: 1 call updated
- source_esolver/esolver_ks_lcao_tddft.cpp: 1 call updated
- source_esolver/esolver_ks_lcaopw.cpp: 1 call updated
- source_esolver/esolver_of.cpp: 1 call updated
- source_esolver/esolver_sdft_pw.cpp: 1 call updated

This refactoring follows the ESolver cleanup principle: keep ESolver focused on high-level workflow control.

* refactor(esolver): extract DeltaSpin lambda loop to deltaspin_lcao module

- Create new files deltaspin_lcao.h/cpp in module_deltaspin
- Extract DeltaSpin lambda loop logic from ESolver_KS_LCAO
- Simplify code from 18 lines to 1 line in hamilt2rho_single
- Separate LCAO and PW implementations for DeltaSpin

Modified files:
- source_esolver/esolver_ks_lcao.cpp: replace inline code with function call
- source_lcao/module_deltaspin/CMakeLists.txt: add new source file

New files:
- source_lcao/module_deltaspin/deltaspin_lcao.h: function declaration
- source_lcao/module_deltaspin/deltaspin_lcao.cpp: function implementation

This refactoring follows the ESolver cleanup principle: keep ESolver focused on high-level workflow control.

* refactor(esolver): complete DeltaSpin refactoring in LCAO

- Add init_deltaspin_lcao() function for DeltaSpin initialization
- Add cal_mi_lcao_wrapper() function for magnetic moment calculation
- Refactor all DeltaSpin-related code in esolver_ks_lcao.cpp
- Simplify code from 29 lines to 3 lines (90% reduction)

Modified files:
- source_esolver/esolver_ks_lcao.cpp: replace 3 code blocks with function calls
- source_lcao/module_deltaspin/deltaspin_lcao.h: add 2 new function declarations
- source_lcao/module_deltaspin/deltaspin_lcao.cpp: implement 2 new functions

This completes the DeltaSpin refactoring for LCAO method:
1. init_deltaspin_lcao() - initialize DeltaSpin calculation
2. cal_mi_lcao_wrapper() - calculate magnetic moments
3. run_deltaspin_lambda_loop_lcao() - run lambda loop optimization

All functions follow the ESolver cleanup principle: keep ESolver focused on high-level workflow control.

* refactor(esolver): extract DFT+U code to dftu_lcao module

- Create new files dftu_lcao.h/cpp in source_lcao directory
- Add init_dftu_lcao() function for DFT+U initialization
- Add finish_dftu_lcao() function for DFT+U finalization
- Simplify code from 32 lines to 2 lines in esolver_ks_lcao.cpp
- Remove conditional checks from ESolver, move them to functions

Modified files:
- source_esolver/esolver_ks_lcao.cpp: replace 2 code blocks with function calls
- source_lcao/CMakeLists.txt: add new source file

New files:
- source_lcao/dftu_lcao.h: function declarations
- source_lcao/dftu_lcao.cpp: function implementations

This refactoring prepares for unifying old and new DFT+U implementations:
- Old DFT+U: source_lcao/module_dftu/
- New DFT+U: source_lcao/module_operator_lcao/op_dftu_lcao.cpp

All functions follow ESolver cleanup principle: keep ESolver focused on high-level workflow control.

* refactor(esolver): extract diagonalization parameters setup to hsolver module

- Create new files diago_params.h/cpp in source_hsolver directory
- Add setup_diago_params_pw() function for PW diagonalization parameters
- Simplify code from 11 lines to 1 line in esolver_ks_pw.cpp
- Encapsulate diagonalization parameter setup logic

Modified files:
- source_esolver/esolver_ks_pw.cpp: replace inline code with function call
- source_hsolver/CMakeLists.txt: add new source file

New files:
- source_hsolver/diago_params.h: function declaration
- source_hsolver/diago_params.cpp: function implementation

This refactoring follows ESolver cleanup principle: keep ESolver focused on high-level workflow control.

* fix(deltaspin): add sc_mag_switch check in cal_mi_lcao_wrapper

- Add Input_para parameter to cal_mi_lcao_wrapper function
- Add sc_mag_switch check to avoid calling cal_mi_lcao when DeltaSpin is disabled
- Fix 'atomCounts is not set' error in non-DeltaSpin calculations
- Update function call in esolver_ks_lcao.cpp

This fix resolves the CI/CD failure caused by commit 2a520e3. The root cause was that cal_mi_lcao_wrapper was called without checking sc_mag_switch, leading to an uninitialized atomCounts error.

Modified files:
- source_esolver/esolver_ks_lcao.cpp: update function call
- source_lcao/module_deltaspin/deltaspin_lcao.h: add parameter
- source_lcao/module_deltaspin/deltaspin_lcao.cpp: add check

This follows the refactoring principle: preserve original condition checks when extracting code to wrapper functions.

* fix(deltaspin): add #ifdef __LCAO for conditional compilation

- Add #ifdef __LCAO conditional compilation in init_deltaspin_lcao and cal_mi_lcao_wrapper
- Fix parameter order in init_sc call for LCAO and non-LCAO builds
- Fix undefined reference to cal_mi_lcao in non-LCAO build

This fix resolves CI/CD compilation errors in both build_5pt (with __LCAO) and build_1p (without __LCAO) environments. The root causes were:
1. init_sc has different parameter order in LCAO vs non-LCAO builds
   - LCAO: psi, dm, pelec
   - non-LCAO: psi, pelec
2. cal_mi_lcao is only defined in LCAO build

Modified files:
- source_hsolver/diago_params.h: add setup_diago_params_sdft declaration
- source_lcao/module_deltaspin/deltaspin_lcao.cpp: add conditional compilation

This follows the refactoring principle: handle conditional compilation properly when code has different implementations for different build configurations.

* refactor(esolver): extract SDFT diagonalization parameters setup

- Add setup_diago_params_sdft() function for SDFT diagonalization parameters
- Simplify code from 11 lines to 1 line in esolver_sdft_pw.cpp
- Encapsulate diagonalization parameter setup logic for SDFT

Modified files:
- source_esolver/esolver_sdft_pw.cpp: replace inline code with function call
- source_hsolver/diago_params.cpp: add setup_diago_params_sdft implementation

This refactoring follows ESolver cleanup principle: keep ESolver focused on high-level workflow control.

Note: SDFT has different parameter setup logic compared to PW:
- Different need_subspace condition
- No SCF_ITER setting
- Always set PW_DIAG_NMAX (no nscf check)

* refactor(hamilt): introduce HamiltBase non-template base class

- Create HamiltBase as a non-template base class for Hamilt<T, Device>
- Modify Hamilt<T, Device> to inherit from HamiltBase
- Change ESolver_KS::p_hamilt type from Hamilt<T, Device>* to HamiltBase*
- Add static_cast where needed when passing p_hamilt to functions expecting Hamilt<T, Device>*

This is the first step towards removing template parameters from ESolver.

Modified files:
- source/source_esolver/esolver_ks.h
- source/source_esolver/esolver_ks_lcaopw.cpp
- source/source_esolver/esolver_ks_pw.cpp
- source/source_esolver/esolver_sdft_pw.cpp
- source/source_hamilt/hamilt.h

New files:
- source/source_hamilt/hamilt_base.h

* refactor(esolver): add static_cast for p_hamilt in esolver files

- Add static_cast<hamilt::Hamilt<T>*> when passing p_hamilt to functions expecting Hamilt<T, Device>* type
- Split long cast statements into multiple lines for better readability
- Files modified:
  - esolver_ks_pw.cpp: setup_pot, stp.init calls
  - esolver_ks_lcao.cpp: init_chg_hr, hsolver_lcao_obj.solve calls
  - esolver_ks_lcao_tddft.cpp: solve_psi, cal_edm_tddft, matrix calls
  - esolver_gets.cpp: ops access, output_SR call

This follows the HamiltBase refactoring strategy where p_hamilt is stored as HamiltBase* and cast to Hamilt<T, Device>* when needed.

* refactor(esolver): remove psi member from ESolver_KS base class

Move psi::Psi<T>* psi from ESolver_KS base class to derived classes to eliminate template parameter dependency and improve code organization.

Changes:
1. ESolver_KS base class:
   - Remove psi::Psi<T>* psi member variable
   - Remove Setup_Psi<T>::deallocate_psi() call in destructor
   - Remove unnecessary includes: psi.h and setup_psi.h
2. ESolver_KS_LCAO:
   - Add psi::Psi<TK>* psi member variable
   - Add Setup_Psi<TK>::deallocate_psi() in destructor
   - Add include: setup_psi.h
3. ESolver_KS_LCAO_TDDFT:
   - Improve psi_laststep deallocation with nullptr check
   - psi member inherited from ESolver_KS_LCAO
4. ESolver_KS_PW:
   - Use stp.psi_cpu directly instead of base class psi
   - Remove unnecessary memory allocation in after_scf()
5. pw_others.cpp (BUG FIX):
   - Fix gen_bessel: use *(this->stp.psi_cpu) instead of this->psi[0]
   - Previous code accessed uninitialized base class psi (nullptr)
   - This was a latent bug that could cause crashes

Benefits:
- Eliminates template parameter T dependency in ESolver_KS base class
- Clearer memory management: each derived class manages its own psi
- Reduces compilation dependencies
- Fixes potential memory access bug in pw_others.cpp

Tested: Compiled successfully in build_5pt and build_1p

* refactor(esolver): remove template parameters from ESolver_KS base class

This is a major milestone in ESolver refactoring!

ESolver_KS no longer needs template parameters because:
- All member variables are non-template types
- All member functions do not use T or Device parameters
- Template parameters were only needed for derived classes

Changes:
1. ESolver_KS base class:
   - Remove template <typename T, typename Device> declaration
   - Remove all template declarations from member functions
   - Remove template instantiation code at end of file
   - Fix Tab indentation to spaces for better readability
2. Derived classes:
   - ESolver_KS_PW: public ESolver_KS (was ESolver_KS<T, Device>)
   - ESolver_KS_LCAO: public ESolver_KS (was ESolver_KS<TK>)
   - ESolver_GetS: public ESolver_KS (was ESolver_KS<std::complex<double>>)
   - Update base class calls: ESolver_KS:: (was ESolver_KS<T, Device>::)

Code reduction:
- esolver_ks.h: 78 -> 77 lines (-1 line)
- esolver_ks.cpp: 346 -> 317 lines (-29 lines)
- Total ESolver code: 424 -> 394 lines (-30 lines)
- Overall: 8 files changed, 50 insertions(+), 80 deletions(-), net -30 lines

Benefits:
- Simpler base class without template complexity
- Faster compilation (no template instantiation needed)
- Clearer inheritance hierarchy
- Easier to extract common code in future refactoring
- Sets foundation for further ESolver template removal

Tested: Compiled successfully in build_5pt

* refactor(device): remove explicit template parameter from get_device_type calls

- Move get_device_type implementation to header file using std::is_same
- Add DEVICE_DSP support
- Remove template specialization declarations and definitions
- Update all call sites to use automatic template parameter deduction
- The compiler now deduces Device type from the ctx parameter

* refactor(esolver): remove device member variable from ESolver_KS_PW

- Modify copy_d2h to accept ctx parameter and call get_device_type internally
- Remove device parameter from ctrl_scf_pw function
- Remove device member variable from ESolver_KS_PW class
- Simplify function interfaces by using automatic template deduction

* style(esolver): explicitly initialize ctx to nullptr in constructor

* feat(device): add runtime device type support to DeviceContext

- Add device_type_ member variable to DeviceContext class
- Add set_device_type() and get_device_type() methods
- Add is_cpu(), is_gpu(), is_dsp() convenience methods
- Add get_device_type(const DeviceContext*) overload for runtime device type query
- Maintain backward compatibility with existing template-based get_device_type

* feat(device): add runtime device context overloads for gradual migration

- Add copy_d2h(const DeviceContext*) overload to Setup_Psi_pw
- Add ctrl_scf_pw(..., const DeviceContext*, ...) overload
- Add ctrl_runner_pw(..., const DeviceContext*, ...) overload
- Keep original functions for backward compatibility
- Replace tabs with spaces in modified files

* refactor(esolver): remove ctx member variable from ESolver_KS_PW

- Remove Device* ctx member variable from ESolver_KS_PW class
- Remove ctx parameter from ctrl_scf_pw and ctrl_runner_pw functions
- Add local ctx variable inside ctrl_scf_pw and ctrl_runner_pw functions
- Update all template instantiations to match new function signatures

This refactoring simplifies the code by moving the ctx variable from a class member to a local variable within the functions that need it. The ctx variable is only used for template parameter deduction in copy_d2h and get_pchg_pw/get_wf_pw functions, so it doesn't need to be stored as a member variable.

* refactor(psi): add runtime type information to Setup_Psi_pw

- Add runtime type information (device_type_ and precision_type_) to Setup_Psi_pw
- Add accessor functions for basic information (get_nbands, get_nk, get_nbasis, size)
- Add accessor functions for runtime type information
- Add get_psi_t() function for backward compatibility

This is the first step of a gradual refactoring to remove template parameters from Setup_Psi_pw in the future. The current changes are backward compatible and do not affect existing functionality.

* refactor(esolver): use get_psi_t() accessor instead of direct psi_t access

- Replace all direct access to stp.psi_t with stp.get_psi_t()
- Replace stp.psi_t->get_nbands() with stp.get_nbands()
- This is the second step of gradual refactoring to prepare for removing template parameters

Modified files:
- source/source_esolver/esolver_ks_pw.cpp
- source/source_esolver/esolver_sdft_pw.cpp
- source/source_esolver/esolver_ks_lcaopw.cpp
- source/source_io/module_ctrl/ctrl_output_pw.cpp

* refactor(psi): change psi_t from template pointer to void*

- Change psi_t from psi::Psi<T, Device>* to void*
- Add static_cast in get_psi_t() function for type conversion
- Update all functions that use psi_t to use get_psi_t() or static_cast
- This is the third step of gradual refactoring to remove template parameters

Modified functions:
- before_runner: use if-else instead of ternary operator for void* assignment
- update_psi_d: use get_psi_t() to access psi_t
- init: use get_psi_t() to access psi_t
- copy_d2h: use get_psi_t() to access psi_t
- clean: use get_psi_t() to delete psi_t

* style: replace Chinese comments with English in setup_psi_pw.h

- Replace '原来的模板版本' with 'Original template version'
- Replace '使用 void* 存储指针,运行时类型信息记录实际类型' with 'Use void* to store pointer, runtime type information records actual type'
- Follow ABACUS code style guidelines for English-only comments

* refactor(psi): change psi_d from template pointer to void*

- Change psi_d from psi::Psi<std::complex<double>, Device>* to void*
- Add get_psi_d() accessor function for type conversion
- Update all functions that use psi_d to use get_psi_d()
- This is part of step 1 in phase 4 of gradual refactoring

Modified files:
- source/source_psi/setup_psi_pw.h
- source/source_psi/setup_psi_pw.cpp
- source/source_esolver/esolver_ks_pw.cpp
- source/source_io/module_ctrl/ctrl_output_pw.cpp

* refactor(psi): introduce PSIPrepareBase base class for template removal

This is the first step towards removing template parameters from Setup_Psi_pw.

Changes:
1. Create PSIPrepareBase base class
   - Non-template base class for PSIPrepare<T, Device>
   - Similar approach to HamiltBase for Hamilt<T, Device>
2. Modify PSIPrepare to inherit from PSIPrepareBase
   - Add #include "source_psi/psi_prepare_base.h"
   - Change class declaration to inherit from PSIPrepareBase
3. Update Setup_Psi_pw to use PSIPrepareBase*
   - Change p_psi_init from PSIPrepare<T, Device>* to PSIPrepareBase*
   - Add static_cast when calling PSIPrepare methods
4. Update all PSIPrepare usage in ESolver files
   - esolver_ks_pw.cpp: add static_cast before prepare_init call
   - esolver_ks_lcaopw.cpp: add static_cast before method calls

Modified files:
- source/source_psi/psi_prepare_base.h (new)
- source/source_psi/psi_prepare.h
- source/source_psi/setup_psi_pw.h
- source/source_psi/setup_psi_pw.cpp
- source/source_esolver/esolver_ks_pw.cpp
- source/source_esolver/esolver_ks_lcaopw.cpp

Benefits:
- Eliminates p_psi_init template dependency from Setup_Psi_pw
- Paves the way for removing template parameters from Setup_Psi_pw
- Maintains type safety through static_cast
- Follows the same pattern as HamiltBase refactoring

Tested: Compiled successfully in build_5pt and build_1p

* refactor(psi): change init() parameter from template to HamiltBase*

This is the second step towards removing template parameters from Setup_Psi_pw.

Changes:
1. Modify init() function signature
   - Change parameter from hamilt::Hamilt<T, Device>* to hamilt::HamiltBase*
   - Eliminates template dependency in function signature
2. Update init() implementation
   - Add static_cast<hamilt::Hamilt<T, Device>*> inside function
   - Maintain type safety through explicit cast
3. Update call site in esolver_ks_pw.cpp
   - Remove static_cast from call site
   - Directly pass p_hamilt (which is already HamiltBase*)

Modified files:
- source/source_psi/setup_psi_pw.h
- source/source_psi/setup_psi_pw.cpp
- source/source_esolver/esolver_ks_pw.cpp

Benefits:
- init() function no longer depends on template parameters in signature
- Simplifies call sites (no cast needed)
- Follows the same pattern as p_hamilt storage in ESolver_KS
- One step closer to removing template parameters from Setup_Psi_pw

Tested: Compiled successfully in build_5pt and build_1p

* refactor(psi): remove template version of copy_d2h function

This is the third step towards removing template parameters from Setup_Psi_pw.

Changes:
1. Remove template version copy_d2h(const Device* ctx)
   - Delete the template version from setup_psi_pw.h
   - Delete the implementation from setup_psi_pw.cpp
   - Keep only the runtime version copy_d2h(const base_device::DeviceContext* ctx)
2. Update call site in ctrl_output_pw.cpp
   - Create Device* ctx = nullptr for template parameter deduction
   - Use DeviceContext::instance() for runtime device context
   - Call copy_d2h with DeviceContext* pointer

Modified files:
- source/source_psi/setup_psi_pw.h
- source/source_psi/setup_psi_pw.cpp
- source/source_io/module_ctrl/ctrl_output_pw.cpp

Benefits:
- Eliminates copy_d2h function's template dependency
- All member functions now use runtime device context
- One step closer to removing template parameters from Setup_Psi_pw
- Maintains backward compatibility with existing code

Technical details:
- get_pchg_pw and get_wf_pw still need Device* ctx for template deduction
- DeviceContext is used for actual device type information
- This follows the gradual migration pattern used in ESolver refactoring

Tested: Compiled successfully in build_5pt and build_1p

* refactor(psi): remove castmem_2d_d2h_op template type alias dependency

This is the fourth step towards removing template parameters from Setup_Psi_pw.

Changes:
1. Remove castmem_2d_d2h_op type alias from setup_psi_pw.h
   - The type alias depended on template parameters T and Device
   - Replaced with overloaded member functions
2. Add castmem_d2h_impl() overloaded functions
   - One overload for std::complex<double> source
   - One overload for std::complex<float> source
   - Each uses the appropriate cast_memory_op internally
3. Update copy_d2h() to use the new overloaded functions
   - Calls castmem_d2h_impl() instead of castmem_2d_d2h_op()
   - Compiler selects the correct overload based on T type

Modified files:
- source/source_psi/setup_psi_pw.h
- source/source_psi/setup_psi_pw.cpp

Benefits:
- All member variables now independent of template parameters
- castmem_d2h_impl encapsulates the type-dependent logic
- One step closer to removing template parameters from Setup_Psi_pw

Tested: Compiled successfully in build_5pt and build_1p

* delete useless files

---------

Co-authored-by: abacus_fixer <mohanchen@pku.eud.cn>
1 parent 8efe9f5 commit e94d753
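
The commit message describes replacing a template-dependent type alias with two overloads of castmem_d2h_impl(), so that the compiler picks the conversion from the source element type. The sketch below illustrates only that overload-selection idea. The signatures, the HostBuffer alias, and the free function copy_d2h are assumptions made for this illustration and are not the ABACUS declarations, which the commit says use cast_memory_op internally.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <vector>

// Hypothetical host-side buffer type used only in this sketch.
using HostBuffer = std::vector<std::complex<double>>;

// Overload for a double-precision source: a plain element-wise copy.
inline void castmem_d2h_impl(HostBuffer& dst, const std::complex<double>* src, std::size_t n)
{
    assert(src != nullptr || n == 0);
    dst.assign(src, src + n);
}

// Overload for a single-precision source: widen each element to double.
inline void castmem_d2h_impl(HostBuffer& dst, const std::complex<float>* src, std::size_t n)
{
    assert(src != nullptr || n == 0);
    dst.resize(n);
    for (std::size_t i = 0; i < n; ++i)
    {
        dst[i] = std::complex<double>(src[i].real(), src[i].imag());
    }
}

// The caller stays generic: overload resolution on T replaces the type alias.
template <typename T>
HostBuffer copy_d2h(const std::vector<T>& source)
{
    HostBuffer host;
    castmem_d2h_impl(host, source.data(), source.size());
    return host;
}
```

With this shape, no member or alias of the owning class has to mention T; only the call inside copy_d2h depends on it.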

10 files changed

Lines changed: 171 additions & 109 deletions

source/source_esolver/esolver_ks_lcaopw.cpp

Lines changed: 7 additions & 5 deletions
@@ -81,16 +81,18 @@ namespace ModuleESolver
 void ESolver_KS_LIP<T>::before_scf(UnitCell& ucell, const int istep)
 {
     ESolver_KS_PW<T>::before_scf(ucell, istep);
-    this->stp.p_psi_init->initialize_lcao_in_pw(this->psi_local, GlobalV::ofs_running);
+    auto* p_psi_init = static_cast<psi::PSIPrepare<T>*>(this->stp.p_psi_init);
+    p_psi_init->initialize_lcao_in_pw(this->psi_local, GlobalV::ofs_running);
 }

 template <typename T>
 void ESolver_KS_LIP<T>::before_all_runners(UnitCell& ucell, const Input_para& inp)
 {
     ESolver_KS_PW<T>::before_all_runners(ucell, inp);
+    auto* p_psi_init = static_cast<psi::PSIPrepare<T>*>(this->stp.p_psi_init);
     delete this->psi_local;
     this->psi_local = new psi::Psi<T>(this->stp.psi_cpu->get_nk(),
-        this->stp.p_psi_init->psi_initer->nbands_start(),
+        p_psi_init->psi_initer->nbands_start(),
         this->stp.psi_cpu->get_nbasis(),
         this->kv.ngk,
         true);
@@ -105,7 +107,7 @@ namespace ModuleESolver
         ucell.symm,
         &this->kv,
         this->psi_local,
-        this->stp.psi_t,
+        this->stp.get_psi_t(),
         this->pw_wfc,
         this->pw_rho,
         this->sf,
@@ -146,7 +148,7 @@ namespace ModuleESolver
     bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;

     hsolver::HSolverLIP<T> hsolver_lip_obj(this->pw_wfc);
-    hsolver_lip_obj.solve(static_cast<hamilt::Hamilt<T>*>(this->p_hamilt), this->stp.psi_t[0], this->pelec,
+    hsolver_lip_obj.solve(static_cast<hamilt::Hamilt<T>*>(this->p_hamilt), *this->stp.get_psi_t(), this->pelec,
         *this->psi_local, skip_charge,ucell.tpiba,ucell.nat);

     // add exx
@@ -240,7 +242,7 @@ namespace ModuleESolver
     ModuleIO::write_Vxc(PARAM.inp.nspin,
         PARAM.globalv.nlocal,
         GlobalV::DRANK,
-        *this->stp.psi_t,
+        *this->stp.get_psi_t(),
         ucell,
         this->sf,
         this->solvent,

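The static_cast calls on p_psi_init above follow the pattern the commit message describes for PSIPrepareBase and HamiltBase: the owner stores a pointer to a non-template base, and templated callers restore the concrete type before use. A minimal sketch of that pattern follows. PrepareBase, Prepare<T>, Setup and run_prepare are hypothetical stand-ins invented for this example, not the ABACUS classes.

```cpp
#include <cassert>

// Non-template base: plays the role PSIPrepareBase has in the commit.
struct PrepareBase
{
    virtual ~PrepareBase() = default;
};

// Templated implementation: plays the role of PSIPrepare<T, Device>.
template <typename T>
struct Prepare : PrepareBase
{
    explicit Prepare(T seed) : seed_(seed) {}
    T prepare_init() const { return seed_ * 2; }
    T seed_;
};

// The owner holds only the base pointer, so its declaration needs no template parameter.
struct Setup
{
    PrepareBase* p_init = nullptr;
};

// A templated caller knows T and casts back before calling typed methods.
template <typename T>
T run_prepare(const Setup& stp)
{
    assert(stp.p_init != nullptr);
    // static_cast does no runtime check: this is only valid when p_init
    // really points at a Prepare<T>.
    auto* p = static_cast<Prepare<T>*>(stp.p_init);
    return p->prepare_init();
}
```

The trade-off in this design is that type correctness moves from the compiler to the caller: a mismatched T in the cast compiles but is undefined behaviour, which is why the commit keeps the casts next to code that already carries the template parameters.
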
source/source_esolver/esolver_ks_pw.cpp

Lines changed: 13 additions & 13 deletions
@@ -43,7 +43,6 @@ ESolver_KS_PW<T, Device>::ESolver_KS_PW()
 {
     this->classname = "ESolver_KS_PW";
     this->basisname = "PW";
-    this->ctx = nullptr;
 }

 template <typename T, typename Device>
@@ -112,7 +111,8 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)

     if (ucell.cell_parameter_updated)
     {
-        this->stp.p_psi_init->prepare_init(PARAM.inp.pw_seed);
+        auto* p_psi_init = static_cast<psi::PSIPrepare<T, Device>*>(this->stp.p_psi_init);
+        p_psi_init->prepare_init(PARAM.inp.pw_seed);
     }

     //! Init Hamiltonian (cell changed)
@@ -128,13 +128,13 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
     // init DFT+U is done in "before_all_runners" in LCAO basis. This should be refactored, mohan note 2025-11-06
     pw::setup_pot(istep, ucell, this->kv, this->sf, this->pelec, this->Pgrid,
         this->chr, this->locpp, this->ppcell, this->dftu, this->vsep_cell,
-        this->stp.psi_t, static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt), this->pw_wfc, this->pw_rhod, PARAM.inp);
+        this->stp.get_psi_t(), static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt), this->pw_wfc, this->pw_rhod, PARAM.inp);

     // setup psi (electronic wave functions)
-    this->stp.init(static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt));
+    this->stp.init(this->p_hamilt);

     //! Setup EXX helper for Hamiltonian and psi
-    exx_helper.before_scf(this->p_hamilt, this->stp.psi_t, PARAM.inp);
+    exx_helper.before_scf(this->p_hamilt, this->stp.get_psi_t(), PARAM.inp);

     ModuleBase::timer::tick("ESolver_KS_PW", "before_scf");
 }
@@ -152,7 +152,7 @@ void ESolver_KS_PW<T, Device>::iter_init(UnitCell& ucell, const int istep, const

     // update local occupations for DFT+U
     // should before lambda loop in DeltaSpin
-    pw::iter_init_dftu_pw(iter, istep, this->dftu, this->stp.psi_t, this->pelec->wg, ucell, PARAM.inp);
+    pw::iter_init_dftu_pw(iter, istep, this->dftu, this->stp.get_psi_t(), this->pelec->wg, ucell, PARAM.inp);
 }

 // Temporary, it should be replaced by hsolver later.
@@ -188,7 +188,7 @@ void ESolver_KS_PW<T, Device>::hamilt2rho_single(UnitCell& ucell, const int iste
         hsolver::DiagoIterAssist<T, Device>::need_subspace,
         PARAM.inp.use_k_continuity);

-    hsolver_pw_obj.solve(static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt), this->stp.psi_t[0], this->pelec, this->pelec->ekb.c,
+    hsolver_pw_obj.solve(static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt), *this->stp.get_psi_t(), this->pelec, this->pelec->ekb.c,
         GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL, skip_charge, ucell.tpiba, ucell.nat);
 }

@@ -205,7 +205,7 @@ void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int
     // Related to EXX
     if (GlobalC::exx_info.info_global.cal_exx && !exx_helper.op_exx->first_iter)
     {
-        this->pelec->set_exx(exx_helper.cal_exx_energy(this->stp.psi_t));
+        this->pelec->set_exx(exx_helper.cal_exx_energy(this->stp.get_psi_t()));
     }

     // deband is calculated from "output" charge density
@@ -224,7 +224,7 @@ void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int
     }

     // Handle EXX-related operations after SCF iteration
-    exx_helper.iter_finish(this->pelec, &this->chr, this->stp.psi_t, ucell, PARAM.inp, conv_esolver, iter);
+    exx_helper.iter_finish(this->pelec, &this->chr, this->stp.get_psi_t(), ucell, PARAM.inp, conv_esolver, iter);

     // check if oscillate for delta_spin method
     pw::check_deltaspin_oscillation(iter, this->drho, this->p_chgmix, PARAM.inp);
@@ -251,7 +251,7 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
     // Output quantities
     ModuleIO::ctrl_scf_pw<T, Device>(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc,
         this->pw_rho, this->pw_rhod, this->pw_big, this->stp,
-        this->ctx, this->Pgrid, PARAM.inp);
+        this->Pgrid, PARAM.inp);

     ModuleBase::timer::tick("ESolver_KS_PW", "after_scf");
 }
@@ -273,7 +273,7 @@ void ESolver_KS_PW<T, Device>::cal_force(UnitCell& ucell, ModuleBase::matrix& fo
     // Calculate forces
     ff.cal_force(ucell, force, *this->pelec, this->pw_rhod, &ucell.symm,
         &this->sf, this->solvent, &this->dftu, &this->locpp, &this->ppcell,
-        &this->kv, this->pw_wfc, this->stp.psi_d);
+        &this->kv, this->pw_wfc, this->stp.get_psi_d());
 }

 template <typename T, typename Device>
@@ -285,7 +285,7 @@ void ESolver_KS_PW<T, Device>::cal_stress(UnitCell& ucell, ModuleBase::matrix& s
     this->stp.update_psi_d();

     ss.cal_stress(stress, ucell, this->dftu, this->locpp, this->ppcell, this->pw_rhod,
-        &ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->stp.psi_d);
+        &ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->stp.get_psi_d());

     // external stress
     double unit_transform = 0.0;
@@ -304,7 +304,7 @@ void ESolver_KS_PW<T, Device>::after_all_runners(UnitCell& ucell)

     ModuleIO::ctrl_runner_pw<T, Device>(ucell, this->pelec, this->pw_wfc,
         this->pw_rho, this->pw_rhod, this->chr, this->kv, this->stp,
-        this->sf, this->ppcell, this->solvent, this->ctx, this->Pgrid, PARAM.inp);
+        this->sf, this->ppcell, this->solvent, this->Pgrid, PARAM.inp);

     elecstate::teardown_estate_pw<T, Device>(this->pelec, this->vsep_cell);

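Most hunks in this file replace direct access to stp.psi_t with stp.get_psi_t(). According to the commit message, psi_t is now stored as void* and the accessor applies the static_cast. The following is a minimal sketch of that storage scheme under stated assumptions: SetupPsi is a hypothetical class, and std::vector<T> stands in for psi::Psi<T, Device>.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for Setup_Psi_pw<T, Device>.
template <typename T>
class SetupPsi
{
  public:
    explicit SetupPsi(std::size_t nbands) : psi_t_(new std::vector<T>(nbands)) {}

    // Delete through the typed pointer; deleting a void* is not valid C++.
    ~SetupPsi() { delete get_psi_t(); }

    SetupPsi(const SetupPsi&) = delete;
    SetupPsi& operator=(const SetupPsi&) = delete;

    // Typed accessor: the only place that knows how to interpret psi_t_.
    std::vector<T>* get_psi_t() const
    {
        assert(psi_t_ != nullptr);
        return static_cast<std::vector<T>*>(psi_t_);
    }

    // Convenience accessor so callers need not dereference the pointer at all.
    int get_nbands() const { return static_cast<int>(get_psi_t()->size()); }

  private:
    void* psi_t_ = nullptr; // type-erased storage, independent of T
};
```

Because every caller goes through the accessor, the member's declared type can change later without touching call sites, which appears to be the motivation for the mechanical psi_t to get_psi_t() substitutions in this diff.
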
source/source_esolver/esolver_ks_pw.h

Lines changed: 0 additions & 3 deletions
@@ -57,9 +57,6 @@ class ESolver_KS_PW : public ESolver_KS
     // DFT-1/2 method
     VSep* vsep_cell = nullptr;

-    // for get_pchg and get_wf, use ctx as input of fft
-    Device* ctx = {};
-
 };
 } // namespace ModuleESolver
 #endif

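The ctx member removed above was, per the commit message, only used for template parameter deduction, and get_device_type is now implemented in a header with std::is_same. A minimal sketch of both ideas is given below. The DeviceType enum, the empty tag structs and report_device are assumptions made for illustration and do not reproduce the base_device declarations.

```cpp
#include <type_traits>

// Hypothetical device tags and runtime enum for this sketch only.
struct DEVICE_CPU {};
struct DEVICE_GPU {};
enum class DeviceType { Unknown, Cpu, Gpu };

// Device is deduced from the pointer argument; the pointer is never dereferenced.
template <typename Device>
DeviceType get_device_type(const Device* /*ctx*/)
{
    if (std::is_same<Device, DEVICE_CPU>::value)
    {
        return DeviceType::Cpu;
    }
    if (std::is_same<Device, DEVICE_GPU>::value)
    {
        return DeviceType::Gpu;
    }
    return DeviceType::Unknown;
}

// A function that needs the deduction can create the null pointer locally,
// so the owning class does not have to store it as a member.
template <typename Device>
DeviceType report_device()
{
    const Device* ctx = nullptr; // local, used for deduction only
    return get_device_type(ctx);
}
```

Since the pointer carries no state, keeping it as a local variable costs nothing and removes one template-dependent member from the class.
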
source/source_esolver/esolver_sdft_pw.cpp

Lines changed: 4 additions & 4 deletions
@@ -168,7 +168,7 @@ void ESolver_SDFT_PW<T, Device>::hamilt2rho_single(UnitCell& ucell, int istep, i

     hsolver_pw_sdft_obj.solve(ucell,
         static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt),
-        this->stp.psi_t[0],
+        *this->stp.get_psi_t(),
         this->stp.psi_cpu[0],
         this->pelec,
         this->pw_wfc,
@@ -221,7 +221,7 @@ void ESolver_SDFT_PW<T, Device>::cal_force(UnitCell& ucell, ModuleBase::matrix&
         this->locpp,
         this->ppcell,
         ucell,
-        *this->stp.psi_t,
+        *this->stp.get_psi_t(),
         this->stowf);
 }

@@ -236,7 +236,7 @@ void ESolver_SDFT_PW<T, Device>::cal_stress(UnitCell& ucell, ModuleBase::matrix&
         &this->sf,
         &this->kv,
         this->pw_wfc,
-        *this->stp.psi_t,
+        *this->stp.get_psi_t(),
         this->stowf,
         &this->chr,
         &this->locpp,
@@ -289,7 +289,7 @@ void ESolver_SDFT_PW<T, Device>::after_all_runners(UnitCell& ucell)
         &this->kv,
         this->pelec,
         this->pw_wfc,
-        this->stp.psi_t,
+        this->stp.get_psi_t(),
         &this->ppcell,
         static_cast<hamilt::Hamilt<std::complex<double>, Device>*>(this->p_hamilt),
         this->stoche,

source/source_io/module_ctrl/ctrl_output_pw.cpp

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,19 @@ void ModuleIO::ctrl_scf_pw(const int istep,
9191
const ModulePW::PW_Basis *pw_rhod,
9292
const ModulePW::PW_Basis_Big *pw_big,
9393
Setup_Psi_pw<T, Device> &stp,
94-
const Device* ctx,
9594
const Parallel_Grid &para_grid,
9695
const Input_para& inp)
9796
{
9897
ModuleBase::TITLE("ModuleIO", "ctrl_scf_pw");
9998
ModuleBase::timer::tick("ModuleIO", "ctrl_scf_pw");
10099

100+
// Create local ctx for device type deduction
101+
Device* ctx = nullptr;
102+
101103
// Transfer data from device (GPU) to host (CPU) in pw basis
102-
stp.copy_d2h(ctx);
104+
base_device::DeviceContext* device_ctx = &base_device::DeviceContext::instance();
105+
device_ctx->set_device_type(stp.get_device_type());
106+
stp.copy_d2h(device_ctx);
103107

104108
//----------------------------------------------------------
105109
//! 4) Compute density of states (DOS)
@@ -164,7 +168,7 @@ void ModuleIO::ctrl_scf_pw(const int istep,
164168
// update psi_d
165169
stp.update_psi_d();
166170

167-
const int nbands = stp.psi_t->get_nbands();
171+
const int nbands = stp.get_nbands();
168172
const int ngmc = chr.ngmc;
169173

170174
ModuleIO::get_pchg_pw(inp.out_pchg,
@@ -173,7 +177,7 @@ void ModuleIO::ctrl_scf_pw(const int istep,
173177
pw_rhod->nxyz,
174178
ngmc,
175179
&ucell,
176-
stp.psi_d,
180+
stp.get_psi_d(),
177181
pw_rhod,
178182
pw_wfc,
179183
ctx,
@@ -235,7 +239,7 @@ void ModuleIO::ctrl_scf_pw(const int istep,
235239
if (inp.onsite_radius > 0)
236240
{ // float type has not been implemented
237241
auto* onsite_p = projectors::OnsiteProjector<double, Device>::get_instance();
238-
onsite_p->cal_occupations(reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(stp.psi_t),
242+
onsite_p->cal_occupations(reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(stp.get_psi_t()),
239243
pelec->wg);
240244
}
241245

@@ -255,13 +259,15 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell,
  Structure_Factor &sf,
  pseudopot_cell_vnl &ppcell,
  surchem &solvent,
- const Device* ctx,
  Parallel_Grid &para_grid,
  const Input_para& inp)
  {
  ModuleBase::TITLE("ModuleIO", "ctrl_runner_pw");
  ModuleBase::timer::tick("ModuleIO", "ctrl_runner_pw");

+ // Create local ctx for device type deduction
+ Device* ctx = nullptr;
+
  //----------------------------------------------------------
  //! 1) Compute LDOS
  //----------------------------------------------------------
@@ -303,11 +309,11 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell,

  ModuleIO::get_wf_pw(inp.out_wfc_norm,
  inp.out_wfc_re_im,
- stp.psi_t->get_nbands(),
+ stp.get_nbands(),
  inp.nspin,
  pw_rhod->nxyz,
  &ucell,
- stp.psi_d,
+ stp.get_psi_d(),
  pw_wfc,
  ctx,
  para_grid,
@@ -323,7 +329,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell,
  if (inp.cal_cond)
  {
  using Real = typename GetTypeReal<T>::type;
- EleCond<Real, Device> elec_cond(&ucell, &kv, pelec, pw_wfc, stp.psi_t, &ppcell);
+ EleCond<Real, Device> elec_cond(&ucell, &kv, pelec, pw_wfc, stp.get_psi_t(), &ppcell);
  elec_cond.KG(inp.cond_smear,
  inp.cond_fwhm,
  inp.cond_wcut,
@@ -360,7 +366,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell,
  pw_rho);

  write_mlkedf_desc.generateTrainData_KS(PARAM.globalv.global_mlkedf_descriptor_dir,
- stp.psi_t,
+ stp.get_psi_t(),
  pelec,
  pw_wfc,
  pw_rho,
@@ -384,7 +390,6 @@ template void ModuleIO::ctrl_scf_pw<std::complex<float>, base_device::DEVICE_CPU
  const ModulePW::PW_Basis *pw_rhod,
  const ModulePW::PW_Basis_Big *pw_big,
  Setup_Psi_pw<std::complex<float>, base_device::DEVICE_CPU> &stp,
- const base_device::DEVICE_CPU* ctx,
  const Parallel_Grid &para_grid,
  const Input_para& inp);

@@ -400,7 +405,6 @@ template void ModuleIO::ctrl_scf_pw<std::complex<double>, base_device::DEVICE_CP
  const ModulePW::PW_Basis *pw_rhod,
  const ModulePW::PW_Basis_Big *pw_big,
  Setup_Psi_pw<std::complex<double>, base_device::DEVICE_CPU> &stp,
- const base_device::DEVICE_CPU* ctx,
  const Parallel_Grid &para_grid,
  const Input_para& inp);

@@ -417,13 +421,12 @@ template void ModuleIO::ctrl_scf_pw<std::complex<float>, base_device::DEVICE_GPU
  const ModulePW::PW_Basis *pw_rhod,
  const ModulePW::PW_Basis_Big *pw_big,
  Setup_Psi_pw<std::complex<float>, base_device::DEVICE_GPU> &stp,
- const base_device::DEVICE_GPU* ctx,
  const Parallel_Grid &para_grid,
  const Input_para& inp);

  // complex<double> + GPU
  template void ModuleIO::ctrl_scf_pw<std::complex<double>, base_device::DEVICE_GPU>(
- const int nstep,
+ const int nstep,
  UnitCell& ucell,
  elecstate::ElecState* pelec,
  const Charge &chr,
@@ -433,7 +436,6 @@ template void ModuleIO::ctrl_scf_pw<std::complex<double>, base_device::DEVICE_GP
  const ModulePW::PW_Basis *pw_rhod,
  const ModulePW::PW_Basis_Big *pw_big,
  Setup_Psi_pw<std::complex<double>, base_device::DEVICE_GPU> &stp,
- const base_device::DEVICE_GPU* ctx,
  const Parallel_Grid &para_grid,
  const Input_para& inp);
  #endif
@@ -444,14 +446,13 @@ template void ModuleIO::ctrl_runner_pw<std::complex<float>, base_device::DEVICE_
  elecstate::ElecState* pelec,
  ModulePW::PW_Basis_K* pw_wfc,
  ModulePW::PW_Basis* pw_rho,
- ModulePW::PW_Basis* pw_rhod,
+ ModulePW::PW_Basis* pw_rhod,
  Charge &chr,
- K_Vectors &kv,
+ K_Vectors &kv,
  Setup_Psi_pw<std::complex<float>, base_device::DEVICE_CPU> &stp,
  Structure_Factor &sf,
  pseudopot_cell_vnl &ppcell,
  surchem &solvent,
- const base_device::DEVICE_CPU* ctx,
  Parallel_Grid &para_grid,
  const Input_para& inp);

@@ -461,14 +462,13 @@ template void ModuleIO::ctrl_runner_pw<std::complex<double>, base_device::DEVICE
  elecstate::ElecState* pelec,
  ModulePW::PW_Basis_K* pw_wfc,
  ModulePW::PW_Basis* pw_rho,
- ModulePW::PW_Basis* pw_rhod,
+ ModulePW::PW_Basis* pw_rhod,
  Charge &chr,
- K_Vectors &kv,
+ K_Vectors &kv,
  Setup_Psi_pw<std::complex<double>, base_device::DEVICE_CPU> &stp,
  Structure_Factor &sf,
  pseudopot_cell_vnl &ppcell,
  surchem &solvent,
- const base_device::DEVICE_CPU* ctx,
  Parallel_Grid &para_grid,
  const Input_para& inp);

@@ -481,12 +481,11 @@ template void ModuleIO::ctrl_runner_pw<std::complex<float>, base_device::DEVICE_
  ModulePW::PW_Basis* pw_rho,
  ModulePW::PW_Basis* pw_rhod,
  Charge &chr,
- K_Vectors &kv,
+ K_Vectors &kv,
  Setup_Psi_pw<std::complex<float>, base_device::DEVICE_GPU> &stp,
  Structure_Factor &sf,
  pseudopot_cell_vnl &ppcell,
  surchem &solvent,
- const base_device::DEVICE_GPU* ctx,
  Parallel_Grid &para_grid,
  const Input_para& inp);

@@ -498,12 +497,11 @@ template void ModuleIO::ctrl_runner_pw<std::complex<double>, base_device::DEVICE
  ModulePW::PW_Basis* pw_rho,
  ModulePW::PW_Basis* pw_rhod,
  Charge &chr,
- K_Vectors &kv,
+ K_Vectors &kv,
  Setup_Psi_pw<std::complex<double>, base_device::DEVICE_GPU> &stp,
  Structure_Factor &sf,
  pseudopot_cell_vnl &ppcell,
  surchem &solvent,
- const base_device::DEVICE_GPU* ctx,
  Parallel_Grid &para_grid,
  const Input_para& inp);
  #endif
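
The hunks above drop the `const Device* ctx` parameter from `ctrl_scf_pw` and `ctrl_runner_pw`, declare a local `Device* ctx = nullptr;` in its place, and switch from direct member access (`stp.psi_t`, `stp.psi_d`) to accessors. The following is a minimal sketch of why a null pointer can suffice when `ctx` is used only to select an overload by its static type. All names here (`DEVICE_CPU`, `DEVICE_GPU`, `device_name`, `SetupPsi`, `ctrl`) are hypothetical stand-ins written for illustration, not the ABACUS definitions, and the sketch assumes callees never dereference the pointer.

```cpp
#include <string>

// Hypothetical device tag types (stand-ins, not the ABACUS ones).
struct DEVICE_CPU {};
struct DEVICE_GPU {};

// Overloads chosen by the static type of the pointer argument.
// The pointer value is never dereferenced, so nullptr is acceptable.
inline std::string device_name(const DEVICE_CPU*) { return "cpu"; }
inline std::string device_name(const DEVICE_GPU*) { return "gpu"; }

// Hypothetical holder that exposes state through an accessor
// instead of a public data member.
template <typename Device>
class SetupPsi
{
  public:
    explicit SetupPsi(int nbands) : nbands_(nbands) {}
    int get_nbands() const { return nbands_; }

  private:
    int nbands_;
};

// The caller no longer passes ctx; the function creates it locally,
// purely so that overload resolution can see the Device type.
template <typename Device>
std::string ctrl(const SetupPsi<Device>& stp)
{
    Device* ctx = nullptr; // type carrier only, assumed never dereferenced
    return device_name(ctx) + ":" + std::to_string(stp.get_nbands());
}
```

Under these assumptions, `ctrl(SetupPsi<DEVICE_CPU>(4))` selects the CPU overload without the caller supplying any context object. This works only while every callee treats `ctx` as a type tag; a callee that reads through the pointer would need a real object.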
