Skip to content

Commit 342de19

Browse files
mohanchenabacus_fixer
andauthored
Let's try removing the template of ESolver_KS_pw again (#7028)
* small format changes * refactor(esolver): extract charge density symmetrization to Symmetry_rho::symmetrize_rho - Add static method symmetrize_rho() in Symmetry_rho class - Replace 7 duplicate code blocks with single function call - Simplify code from 35 lines to 7 lines (80% reduction) - Improve code readability and maintainability Modified files: - source_estate/module_charge/symmetry_rho.h: add static method declaration - source_estate/module_charge/symmetry_rho.cpp: implement static method - source_esolver/esolver_ks_lcao.cpp: 2 calls updated - source_esolver/esolver_ks_pw.cpp: 1 call updated - source_esolver/esolver_ks_lcao_tddft.cpp: 1 call updated - source_esolver/esolver_ks_lcaopw.cpp: 1 call updated - source_esolver/esolver_of.cpp: 1 call updated - source_esolver/esolver_sdft_pw.cpp: 1 call updated This refactoring follows the ESolver cleanup principle: keep ESolver focused on high-level workflow control. * refactor(esolver): extract DeltaSpin lambda loop to deltaspin_lcao module - Create new files deltaspin_lcao.h/cpp in module_deltaspin - Extract DeltaSpin lambda loop logic from ESolver_KS_LCAO - Simplify code from 18 lines to 1 line in hamilt2rho_single - Separate LCAO and PW implementations for DeltaSpin Modified files: - source_esolver/esolver_ks_lcao.cpp: replace inline code with function call - source_lcao/module_deltaspin/CMakeLists.txt: add new source file New files: - source_lcao/module_deltaspin/deltaspin_lcao.h: function declaration - source_lcao/module_deltaspin/deltaspin_lcao.cpp: function implementation This refactoring follows the ESolver cleanup principle: keep ESolver focused on high-level workflow control. * refactor(esolver): complete DeltaSpin refactoring in LCAO - Add init_deltaspin_lcao() function for DeltaSpin initialization - Add cal_mi_lcao_wrapper() function for magnetic moment calculation - Refactor all DeltaSpin-related code in esolver_ks_lcao.cpp - Simplify code from 29 lines to 3 lines (90% reduction) Modified files: - source_esolver/esolver_ks_lcao.cpp: replace 3 code blocks with function calls - source_lcao/module_deltaspin/deltaspin_lcao.h: add 2 new function declarations - source_lcao/module_deltaspin/deltaspin_lcao.cpp: implement 2 new functions This completes the DeltaSpin refactoring for LCAO method: 1. init_deltaspin_lcao() - initialize DeltaSpin calculation 2. cal_mi_lcao_wrapper() - calculate magnetic moments 3. run_deltaspin_lambda_loop_lcao() - run lambda loop optimization All functions follow the ESolver cleanup principle: keep ESolver focused on high-level workflow control. * refactor(esolver): extract DFT+U code to dftu_lcao module - Create new files dftu_lcao.h/cpp in source_lcao directory - Add init_dftu_lcao() function for DFT+U initialization - Add finish_dftu_lcao() function for DFT+U finalization - Simplify code from 32 lines to 2 lines in esolver_ks_lcao.cpp - Remove conditional checks from ESolver, move them to functions Modified files: - source_esolver/esolver_ks_lcao.cpp: replace 2 code blocks with function calls - source_lcao/CMakeLists.txt: add new source file New files: - source_lcao/dftu_lcao.h: function declarations - source_lcao/dftu_lcao.cpp: function implementations This refactoring prepares for unifying old and new DFT+U implementations: - Old DFT+U: source_lcao/module_dftu/ - New DFT+U: source_lcao/module_operator_lcao/op_dftu_lcao.cpp All functions follow ESolver cleanup principle: keep ESolver focused on high-level workflow control. * refactor(esolver): extract diagonalization parameters setup to hsolver module - Create new files diago_params.h/cpp in source_hsolver directory - Add setup_diago_params_pw() function for PW diagonalization parameters - Simplify code from 11 lines to 1 line in esolver_ks_pw.cpp - Encapsulate diagonalization parameter setup logic Modified files: - source_esolver/esolver_ks_pw.cpp: replace inline code with function call - source_hsolver/CMakeLists.txt: add new source file New files: - source_hsolver/diago_params.h: function declaration - source_hsolver/diago_params.cpp: function implementation This refactoring follows ESolver cleanup principle: keep ESolver focused on high-level workflow control. * fix(deltaspin): add sc_mag_switch check in cal_mi_lcao_wrapper - Add Input_para parameter to cal_mi_lcao_wrapper function - Add sc_mag_switch check to avoid calling cal_mi_lcao when DeltaSpin is disabled - Fix 'atomCounts is not set' error in non-DeltaSpin calculations - Update function call in esolver_ks_lcao.cpp This fix resolves the CI/CD failure caused by commit 2a520e3. The root cause was that cal_mi_lcao_wrapper was called without checking sc_mag_switch, leading to uninitialized atomCounts error. Modified files: - source_esolver/esolver_ks_lcao.cpp: update function call - source_lcao/module_deltaspin/deltaspin_lcao.h: add parameter - source_lcao/module_deltaspin/deltaspin_lcao.cpp: add check This follows the refactoring principle: preserve original condition checks when extracting code to wrapper functions. * fix(deltaspin): add #ifdef __LCAO for conditional compilation - Add #ifdef __LCAO conditional compilation in init_deltaspin_lcao and cal_mi_lcao_wrapper - Fix parameter order in init_sc call for LCAO and non-LCAO builds - Fix undefined reference to cal_mi_lcao in non-LCAO build This fix resolves CI/CD compilation errors in both build_5pt (with __LCAO) and build_1p (without __LCAO) environments. The The root cause was 1. init_sc has different parameter order in LCAO vs non-LCAO builds - LCAO: psi, dm, pelec - non-LCAO: psi, pelec 2. cal_mi_lcao is only defined in LCAO build Modified files: - source_hsolver/diago_params.h: add setup_diago_params_sdft declaration - source_lcao/module_deltaspin/deltaspin_lcao.cpp: add conditional compilation This follows the refactoring principle: handle conditional compilation properly when code has different implementations for different build configurations. * refactor(esolver): extract SDFT diagonalization parameters setup - Add setup_diago_params_sdft() function for SDFT diagonalization parameters - Simplify code from 11 lines to 1 line in esolver_sdft_pw.cpp - Encapsulate diagonalization parameter setup logic for SDFT Modified files: - source_esolver/esolver_sdft_pw.cpp: replace inline code with function call - source_hsolver/diago_params.cpp: add setup_diago_params_sdft implementation This refactoring follows ESolver cleanup principle: keep ESolver focused on high-level workflow control. Note: SDFT has different parameter setup logic compared to PW: - Different need_subspace condition - No SCF_ITER setting - Always set PW_DIAG_NMAX (no nscf check) * refactor(hamilt): introduce HamiltBase non-template base class - Create HamiltBase as a non-template base class for Hamilt<T, Device> - Modify Hamilt<T, Device> to inherit from HamiltBase - Change ESolver_KS::p_hamilt type from Hamilt<T, Device>* to HamiltBase* - Add static_cast where needed when passing p_hamilt to functions expecting Hamilt<T, Device>* This is the first step towards removing template parameters from ESolver. Modified files: - source/source_esolver/esolver_ks.h - source/source_esolver/esolver_ks_lcaopw.cpp - source/source_esolver/esolver_ks_pw.cpp - source/source_esolver/esolver_sdft_pw.cpp - source/source_hamilt/hamilt.h New files: - source/source_hamilt/hamilt_base.h * refactor(esolver): add static_cast for p_hamilt in esolver files - Add static_cast<hamilt::Hamilt<T>*> when passing p_hamilt to functions expecting Hamilt<T, Device>* type - Split long cast statements into multiple lines for better readability - Files modified: - esolver_ks_pw.cpp: setup_pot, stp.init calls - esolver_ks_lcao.cpp: init_chg_hr, hsolver_lcao_obj.solve calls - esolver_ks_lcao_tddft.cpp: solve_psi, cal_edm_tddft, matrix calls - esolver_gets.cpp: ops access, output_SR call This follows the HamiltBase refactoring strategy where p_hamilt is stored as HamiltBase* and cast to Hamilt<T, Device>* when needed. * refactor(esolver): remove psi member from ESolver_KS base class Move psi::Psi<T>* psi from ESolver_KS base class to derived classes to eliminate template parameter dependency and improve code organization. Changes: 1. ESolver_KS base class: - Remove psi::Psi<T>* psi member variable - Remove Setup_Psi<T>::deallocate_psi() call in destructor - Remove unnecessary includes: psi.h and setup_psi.h 2. ESolver_KS_LCAO: - Add psi::Psi<TK>* psi member variable - Add Setup_Psi<TK>::deallocate_psi() in destructor - Add include: setup_psi.h 3. ESolver_KS_LCAO_TDDFT: - Improve psi_laststep deallocation with nullptr check - psi member inherited from ESolver_KS_LCAO 4. ESolver_KS_PW: - Use stp.psi_cpu directly instead of base class psi - Remove unnecessary memory allocation in after_scf() 5. pw_others.cpp (BUG FIX): - Fix gen_bessel: use *(this->stp.psi_cpu) instead of this->psi[0] - Previous code accessed uninitialized base class psi (nullptr) - This was a latent bug that could cause crashes Benefits: - Eliminates template parameter T dependency in ESolver_KS base class - Clearer memory management: each derived class manages its own psi - Reduces compilation dependencies - Fixes potential memory access bug in pw_others.cpp Tested: Compiled successfully in build_5pt and build_1p * refactor(esolver): remove template parameters from ESolver_KS base class This is a major milestone in ESolver refactoring! ESolver_KS no longer needs template parameters because: - All member variables are non-template types - All member functions do not use T or Device parameters - Template parameters were only needed for derived classes Changes: 1. ESolver_KS base class: - Remove template <typename T, typename Device> declaration - Remove all template declarations from member functions - Remove template instantiation code at end of file - Fix Tab indentation to spaces for better readability 2. Derived classes: - ESolver_KS_PW: public ESolver_KS (was ESolver_KS<T, Device>) - ESolver_KS_LCAO: public ESolver_KS (was ESolver_KS<TK>) - ESolver_GetS: public ESolver_KS (was ESolver_KS<std::complex<double>>) - Update base class calls: ESolver_KS:: (was ESolver_KS<T, Device>::) Code reduction: - esolver_ks.h: 78 -> 77 lines (-1 line) - esolver_ks.cpp: 346 -> 317 lines (-29 lines) - Total ESolver code: 424 -> 394 lines (-30 lines) - Overall: 8 files changed, 50 insertions(+), 80 deletions(-), net -30 lines Benefits: - Simpler base class without template complexity - Faster compilation (no template instantiation needed) - Clearer inheritance hierarchy - Easier to extract common code in future refactoring - Sets foundation for further ESolver template removal Tested: Compiled successfully in build_5pt * refactor(device): remove explicit template parameter from get_device_type calls - Move get_device_type implementation to header file using std::is_same - Add DEVICE_DSP support - Remove template specialization declarations and definitions - Update all call sites to use automatic template parameter deduction - The compiler now deduces Device type from the ctx parameter * refactor(esolver): remove device member variable from ESolver_KS_PW - Modify copy_d2h to accept ctx parameter and call get_device_type internally - Remove device parameter from ctrl_scf_pw function - Remove device member variable from ESolver_KS_PW class - Simplify function interfaces by using automatic template deduction * style(esolver): explicitly initialize ctx to nullptr in constructor * feat(device): add runtime device type support to DeviceContext - Add device_type_ member variable to DeviceContext class - Add set_device_type() and get_device_type() methods - Add is_cpu(), is_gpu(), is_dsp() convenience methods - Add get_device_type(const DeviceContext*) overload for runtime device type query - Maintain backward compatibility with existing template-based get_device_type * feat(device): add runtime device context overloads for gradual migration - Add copy_d2h(const DeviceContext*) overload to Setup_Psi_pw - Add ctrl_scf_pw(..., const DeviceContext*, ...) overload - Add ctrl_runner_pw(..., const DeviceContext*, ...) overload - Keep original functions for backward compatibility - Replace tabs with spaces in modified files * refactor(esolver): remove ctx member variable from ESolver_KS_PW - Remove Device* ctx member variable from ESolver_KS_PW class - Remove ctx parameter from ctrl_scf_pw and ctrl_runner_pw functions - Add local ctx variable inside ctrl_scf_pw and ctrl_runner_pw functions - Update all template instantiations to match new function signatures This refactoring simplifies the code by moving the ctx variable from a class member to a local variable within the functions that need it. The ctx variable is only used for template parameter deduction in copy_d2h and get_pchg_pw/get_wf_pw functions, so it doesn't need to be stored as a member variable. * refactor(psi): add runtime type information to Setup_Psi_pw - Add runtime type information (device_type_ and precision_type_) to Setup_Psi_pw - Add accessor functions for basic information (get_nbands, get_nk, get_nbasis, size) - Add accessor functions for runtime type information - Add get_psi_t() function for backward compatibility This is the first step of a gradual refactoring to remove template parameters from Setup_Psi_pw in the future. The current changes are backward compatible and do not affect existing functionality. * refactor(esolver): use get_psi_t() accessor instead of direct psi_t access - Replace all direct access to stp.psi_t with stp.get_psi_t() - Replace stp.psi_t->get_nbands() with stp.get_nbands() - This is the second step of gradual refactoring to prepare for removing template parameters Modified files: - source/source_esolver/esolver_ks_pw.cpp - source/source_esolver/esolver_sdft_pw.cpp - source/source_esolver/esolver_ks_lcaopw.cpp - source/source_io/module_ctrl/ctrl_output_pw.cpp * refactor(psi): change psi_t from template pointer to void* - Change psi_t from psi::Psi<T, Device>* to void* - Add static_cast in get_psi_t() function for type conversion - Update all functions that use psi_t to use get_psi_t() or static_cast - This is the third step of gradual refactoring to remove template parameters Modified functions: - before_runner: use if-else instead of ternary operator for void* assignment - update_psi_d: use get_psi_t() to access psi_t - init: use get_psi_t() to access psi_t - copy_d2h: use get_psi_t() to access psi_t - clean: use get_psi_t() to delete psi_t * style: replace Chinese comments with English in setup_psi_pw.h - Replace '原来的模板版本' with 'Original template version' - Replace '使用 void* 存储指针,运行时类型信息记录实际类型' with 'Use void* to store pointer, runtime type information records actual type' - Follow ABACUS code style guidelines for English-only comments * refactor(psi): change psi_d from template pointer to void* - Change psi_d from psi::Psi<std::complex<double>, Device>* to void* - Add get_psi_d() accessor function for type conversion - Update all functions that use psi_d to use get_psi_d() - This is part of step 1 in phase 4 of gradual refactoring Modified files: - source/source_psi/setup_psi_pw.h - source/source_psi/setup_psi_pw.cpp - source/source_esolver/esolver_ks_pw.cpp - source/source_io/module_ctrl/ctrl_output_pw.cpp * refactor(psi): introduce PSIPrepareBase base class for template removal This is the first step towards removing template parameters from Setup_Psi_pw. Changes: 1. Create PSIPrepareBase base class - Non-template base class for PSIPrepare<T, Device> - Similar approach to HamiltBase for Hamilt<T, Device> 2. Modify PSIPrepare to inherit from PSIPrepareBase - Add #include "source_psi/psi_prepare_base.h" - Change class declaration to inherit from PSIPrepareBase 3. Update Setup_Psi_pw to use PSIPrepareBase* - Change p_psi_init from PSIPrepare<T, Device>* to PSIPrepareBase* - Add static_cast when calling PSIPrepare methods 4. Update all PSIPrepare usage in ESolver files - esolver_ks_pw.cpp: add static_cast before prepare_init call - esolver_ks_lcaopw.cpp: add static_cast before method calls Modified files: - source/source_psi/psi_prepare_base.h (new) - source/source_psi/psi_prepare.h - source/source_psi/setup_psi_pw.h - source/source_psi/setup_psi_pw.cpp - source/source_esolver/esolver_ks_pw.cpp - source/source_esolver/esolver_ks_lcaopw.cpp Benefits: - Eliminates p_psi_init template dependency from Setup_Psi_pw - Paves the way for removing template parameters from Setup_Psi_pw - Maintains type safety through static_cast - Follows the same pattern as HamiltBase refactoring Tested: Compiled successfully in build_5pt and build_1p * refactor(psi): change init() parameter from template to HamiltBase* This is the second step towards removing template parameters from Setup_Psi_pw. Changes: 1. Modify init() function signature - Change parameter from hamilt::Hamilt<T, Device>* to hamilt::HamiltBase* - Eliminates template dependency in function signature 2. Update init() implementation - Add static_cast<hamilt::Hamilt<T, Device>*> inside function - Maintain type safety through explicit cast 3. Update call site in esolver_ks_pw.cpp - Remove static_cast from call site - Directly pass p_hamilt (which is already HamiltBase*) Modified files: - source/source_psi/setup_psi_pw.h - source/source_psi/setup_psi_pw.cpp - source/source_esolver/esolver_ks_pw.cpp Benefits: - init() function no longer depends on template parameters in signature - Simplifies call sites (no cast needed) - Follows the same pattern as p_hamilt storage in ESolver_KS - One step closer to removing template parameters from Setup_Psi_pw Tested: Compiled successfully in build_5pt and build_1p * refactor(psi): remove template version of copy_d2h function This is the third step towards removing template parameters from Setup_Psi_pw. Changes: 1. Remove template version copy_d2h(const Device* ctx) - Delete the template version from setup_psi_pw.h - Delete the implementation from setup_psi_pw.cpp - Keep only the runtime version copy_d2h(const base_device::DeviceContext* ctx) 2. Update call site in ctrl_output_pw.cpp - Create Device* ctx = nullptr for template parameter deduction - Use DeviceContext::instance() for runtime device context - Call copy_d2h with DeviceContext* pointer Modified files: - source/source_psi/setup_psi_pw.h - source/source_psi/setup_psi_pw.cpp - source/source_io/module_ctrl/ctrl_output_pw.cpp Benefits: - Eliminates copy_d2h function's template dependency - All member functions now use runtime device context - One step closer to removing template parameters from Setup_Psi_pw - Maintains backward compatibility with existing code Technical details: - get_pchg_pw and get_wf_pw still need Device* ctx for template deduction - DeviceContext is used for actual device type information - This follows the gradual migration pattern used in ESolver refactoring Tested: Compiled successfully in build_5pt and build_1p * refactor(psi): remove castmem_2d_d2h_op template type alias dependency This is the fourth step towards removing template parameters from Setup_Psi_pw. Changes: 1. Remove castmem_2d_d2h_op type alias from setup_psi_pw.h - The type alias depended on template parameters T and Device - Replaced with overloaded member functions 2. Add castmem_d2h_impl() overloaded functions - One overload for std::complex<double> source - One overload for std::complex<float> source - Each uses the appropriate cast_memory_op internally 3. Update copy_d2h() to use the new overloaded functions - Calls castmem_d2h_impl() instead of castmem_2d_d2h_op() - Compiler selects the correct overload based on T type Modified files: - source/source_psi/setup_psi_pw.h - source/source_psi/setup_psi_pw.cpp Benefits: - All member variables now independent of template parameters - castmem_d2h_impl encapsulates the type-dependent logic - One step closer to removing template parameters from Setup_Psi_pw Tested: Compiled successfully in build_5pt and build_1p * delete useless files * refactor(psi): remove template parameters from Setup_Psi_pw class - Remove template parameters <T, Device> from Setup_Psi_pw class - Convert member functions to template functions - Update all call sites to explicitly specify template parameters - This is a major refactoring step to enable runtime polymorphism Modified files: - source/source_psi/setup_psi_pw.h - source/source_psi/setup_psi_pw.cpp - source/source_esolver/esolver_ks_pw.h - source/source_esolver/esolver_ks_pw.cpp - source/source_esolver/esolver_sdft_pw.cpp - source/source_esolver/esolver_ks_lcaopw.cpp - source/source_io/module_ctrl/ctrl_output_pw.h - source/source_io/module_ctrl/ctrl_output_pw.cpp Key changes: 1. Setup_Psi_pw class no longer has template parameters 2. Member functions like get_psi_t(), get_psi_d(), before_runner(), init(), update_psi_d(), clean() are now template functions 3. All call sites now use stp.template get_psi_t<T, Device>() instead of stp.get_psi_t() 4. Removed template instantiation statements * fix(psi): add explicit template instantiation for Setup_Psi_pw class Fix undefined reference linker errors for template member functions: - update_psi_d<T, Device>() - copy_d2h<T, Device>() - clean<T, Device>() - before_runner<T, Device>() - init<T, Device>() - castmem_d2h_impl<T, Device>() Changes: 1. Add explicit template instantiation for CPU version: - std::complex<float>, DEVICE_CPU - std::complex<double>, DEVICE_CPU 2. Add explicit template instantiation for GPU version (conditional): - std::complex<float>, DEVICE_GPU - std::complex<double>, DEVICE_GPU - Wrapped with #if ((defined __CUDA) || (defined __ROCM)) 3. Fix template argument deduction error: - Add explicit template parameters <T, Device> when calling castmem_d2h_impl() - The template parameter T is not used in function parameters, so it cannot be deduced Root cause: Template functions defined in .cpp files require explicit instantiation for each type combination used by other compilation units. Modified files: - source/source_psi/setup_psi_pw.cpp (+91 lines) * refactor(psi): convert before_runner to non-template function Convert Setup_Psi_pw::before_runner from template function to non-template function with runtime type dispatch. Changes: 1. setup_psi_pw.h: - Remove template parameters from before_runner declaration - Add private template function before_runner_impl for internal use 2. setup_psi_pw.cpp: - Rename original before_runner to before_runner_impl (template) - Add new non-template before_runner that dispatches based on: - inp.device (gpu or cpu) - inp.precision (single or double) - Update template instantiation from before_runner to before_runner_impl 3. esolver_ks_pw.cpp: - Update call from stp.before_runner<T, Device>(...) to stp.before_runner(...) Benefits: - Caller no longer needs to specify template parameters - Type is determined at runtime from input parameters - Simpler API for ESolver * refactor(psi): convert init to non-template function Convert Setup_Psi_pw::init from template function to non-template function with runtime type dispatch based on device_type_ and precision_type_ member variables. Changes: 1. setup_psi_pw.h: - Change init parameter from Hamilt<T,Device>* to HamiltBase* - Add private template function init_impl for internal use 2. setup_psi_pw.cpp: - Rename original init to init_impl (template) - Add new non-template init that dispatches based on: - device_type_ (GpuDevice or CpuDevice) - precision_type_ (ComplexFloat or ComplexDouble) - Update template instantiation from init to init_impl 3. esolver_ks_pw.cpp: - Update call from stp.init<T, Device>(...) to stp.init(...) Design principle: before_runner sets device_type_ and precision_type_, subsequent functions use these member variables for runtime dispatch. * refactor(psi): move private member variables to private section Move the following member variables from public to private: - psi_t: accessible via get_psi_t<T, Device>() - psi_d: accessible via get_psi_d<T, Device>() - already_initpsi: internal use only - device_type_: accessible via get_device_type() - precision_type_: accessible via get_precision_type() Keep the following in public: - psi_cpu: directly accessed by 14 external locations - p_psi_init: directly accessed by 3 external locations - PrecisionType enum: used as return type of get_precision_type() This improves encapsulation while maintaining backward compatibility through accessor functions. * refactor(psi): convert clean to non-template function Convert Setup_Psi_pw::clean from template function to non-template function with runtime type dispatch based on device_type_ and precision_type_ member variables. Changes: 1. setup_psi_pw.h: - Remove template parameters from clean declaration - Add private template function clean_impl for internal use 2. setup_psi_pw.cpp: - Rename original clean to clean_impl (template) - Add new non-template clean that dispatches based on: - device_type_ (GpuDevice or CpuDevice) - precision_type_ (ComplexFloat or ComplexDouble) - Replace PARAM.inp.device/precision checks with member variables - Update template instantiation from clean to clean_impl 3. esolver_ks_pw.cpp: - Update call from stp.clean<T, Device>() to stp.clean() * refactor(psi): convert copy_d2h to non-template function Convert Setup_Psi_pw::copy_d2h from template function to non-template function with runtime type dispatch based on device_type_ and precision_type_ member variables. Changes: 1. setup_psi_pw.h: - Remove template parameters and DeviceContext parameter from copy_d2h - Add private template function copy_d2h_impl for internal use 2. setup_psi_pw.cpp: - Rename original copy_d2h to copy_d2h_impl (template) - Add new non-template copy_d2h that: - Returns early if device_type_ is not GpuDevice - Dispatches based on precision_type_ (ComplexFloat or ComplexDouble) - Update template instantiation from copy_d2h to copy_d2h_impl 3. ctrl_output_pw.cpp: - Simplify call from stp.template copy_d2h<T, Device>(device_ctx) to stp.copy_d2h() - Remove DeviceContext setup code for copy_d2h * refactor(psi): convert update_psi_d to non-template function Convert Setup_Psi_pw::update_psi_d from template function to non-template function with runtime type dispatch based on device_type_ and precision_type_ member variables. Changes: 1. setup_psi_pw.h: - Remove template parameters from update_psi_d declaration - Add private template function update_psi_d_impl for internal use 2. setup_psi_pw.cpp: - Rename original update_psi_d to update_psi_d_impl (template) - Add new non-template update_psi_d that dispatches based on: - device_type_ (GpuDevice or CpuDevice) - precision_type_ (ComplexFloat or ComplexDouble) - Replace PARAM.inp.precision checks with precision_type_ member variable - Update template instantiation from update_psi_d to update_psi_d_impl 3. esolver_ks_pw.cpp: - Update calls from stp.update_psi_d<T, Device>() to stp.update_psi_d() 4. ctrl_output_pw.cpp: - Update calls from stp.template update_psi_d<T, Device>() to stp.update_psi_d() This completes the refactoring of all Setup_Psi_pw member functions to use runtime type dispatch instead of template parameters. * refactor(exx): convert Exx_Helper to runtime polymorphism Convert Exx_Helper from template member to runtime polymorphic pointer using a base class pattern similar to Setup_Psi_pw. Changes: 1. exx_helper_base.h (new): - Create pure virtual base class Exx_HelperBase - Use void* for template-dependent parameters 2. exx_helper.h: - Exx_Helper now inherits from Exx_HelperBase - All public methods marked as override 3. exx_helper.cpp: - Update function signatures to use void* parameters - Add static_cast for type conversion 4. esolver_ks_pw.h: - Change Exx_Helper<T, Device> exx_helper to Exx_HelperBase* exx_helper 5. esolver_ks_pw.cpp: - Create concrete Exx_Helper instance based on inp.device and inp.precision - Delete exx_helper in destructor Benefits: - ESolver_KS_PW no longer requires Exx_Helper template parameters - Type determined at runtime from input parameters - Consistent with Setup_Psi_pw refactoring pattern * style: use static_cast instead of reinterpret_cast in deallocate_hamilt Replace reinterpret_cast with static_cast when deleting HamiltPW pointer. Since HamiltPW inherits from HamiltBase, static_cast is safer and more appropriate for downcasting in inheritance hierarchies. * refactor(estate): convert setup_estate_pw to non-template function Convert setup_estate_pw and teardown_estate_pw from template functions to non-template functions with runtime type dispatch based on inp.device and inp.precision. Changes: 1. setup_estate_pw.h: - Remove template parameters from setup_estate_pw and teardown_estate_pw - Add template implementation functions setup_estate_pw_impl and teardown_estate_pw_impl 2. setup_estate_pw.cpp: - Add non-template setup_estate_pw that dispatches based on: - inp.device (gpu or cpu) - inp.precision (single or double) - Rename original implementations to *_impl - Update template instantiation 3. esolver_ks_pw.cpp: - Update calls from setup_estate_pw<T, Device>(...) to setup_estate_pw(...) - Update calls from teardown_estate_pw<T, Device>(...) to teardown_estate_pw(...) This simplifies the calling code in ESolver_KS_PW by removing template parameters while maintaining the same functionality through runtime dispatch. --------- Co-authored-by: abacus_fixer <mohanchen@pku.eud.cn>
1 parent 49a142a commit 342de19

13 files changed

Lines changed: 631 additions & 291 deletions

File tree

source/source_esolver/esolver_ks_lcaopw.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ namespace ModuleESolver
107107
ucell.symm,
108108
&this->kv,
109109
this->psi_local,
110-
this->stp.get_psi_t(),
110+
this->stp.template get_psi_t<T, base_device::DEVICE_CPU>(),
111111
this->pw_wfc,
112112
this->pw_rho,
113113
this->sf,
@@ -148,7 +148,7 @@ namespace ModuleESolver
148148
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;
149149

150150
hsolver::HSolverLIP<T> hsolver_lip_obj(this->pw_wfc);
151-
hsolver_lip_obj.solve(static_cast<hamilt::Hamilt<T>*>(this->p_hamilt), *this->stp.get_psi_t(), this->pelec,
151+
hsolver_lip_obj.solve(static_cast<hamilt::Hamilt<T>*>(this->p_hamilt), *this->stp.template get_psi_t<T, base_device::DEVICE_CPU>(), this->pelec,
152152
*this->psi_local, skip_charge,ucell.tpiba,ucell.nat);
153153

154154
// add exx
@@ -242,7 +242,7 @@ namespace ModuleESolver
242242
ModuleIO::write_Vxc(PARAM.inp.nspin,
243243
PARAM.globalv.nlocal,
244244
GlobalV::DRANK,
245-
*this->stp.get_psi_t(),
245+
*this->stp.template get_psi_t<T, base_device::DEVICE_CPU>(),
246246
ucell,
247247
this->sf,
248248
this->solvent,

source/source_esolver/esolver_ks_pw.cpp

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
5454
// delete Hamilt
5555
this->deallocate_hamilt();
5656

57+
// delete exx_helper
58+
if (this->exx_helper != nullptr)
59+
{
60+
delete this->exx_helper;
61+
this->exx_helper = nullptr;
62+
}
63+
5764
// mohan add 2025-10-12
5865
this->stp.clean();
5966
}
@@ -75,7 +82,7 @@ void ESolver_KS_PW<T, Device>::deallocate_hamilt()
7582
{
7683
if (this->p_hamilt != nullptr)
7784
{
78-
delete reinterpret_cast<hamilt::HamiltPW<T, Device>*>(this->p_hamilt);
85+
delete static_cast<hamilt::HamiltPW<T, Device>*>(this->p_hamilt);
7986
this->p_hamilt = nullptr;
8087
}
8188
}
@@ -86,16 +93,45 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
8693
ESolver_KS::before_all_runners(ucell, inp);
8794

8895
//! setup and allocation for pelec, potentials, etc.
89-
elecstate::setup_estate_pw<T, Device>(ucell, this->kv, this->sf, this->pelec, this->chr,
96+
elecstate::setup_estate_pw(ucell, this->kv, this->sf, this->pelec, this->chr,
9097
this->locpp, this->ppcell, this->vsep_cell, this->pw_wfc, this->pw_rho,
9198
this->pw_rhod, this->pw_big, this->solvent, inp);
9299

93100
this->stp.before_runner(ucell, this->kv, this->sf, *this->pw_wfc, this->ppcell, PARAM.inp);
94101

95102
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS");
96103

104+
//! Create exx_helper based on device and precision
105+
const bool is_gpu = (inp.device == "gpu");
106+
const bool is_single = (inp.precision == "single");
107+
108+
#if ((defined __CUDA) || (defined __ROCM))
109+
if (is_gpu)
110+
{
111+
if (is_single)
112+
{
113+
this->exx_helper = new Exx_Helper<std::complex<float>, base_device::DEVICE_GPU>();
114+
}
115+
else
116+
{
117+
this->exx_helper = new Exx_Helper<std::complex<double>, base_device::DEVICE_GPU>();
118+
}
119+
}
120+
else
121+
#endif
122+
{
123+
if (is_single)
124+
{
125+
this->exx_helper = new Exx_Helper<std::complex<float>, base_device::DEVICE_CPU>();
126+
}
127+
else
128+
{
129+
this->exx_helper = new Exx_Helper<std::complex<double>, base_device::DEVICE_CPU>();
130+
}
131+
}
132+
97133
//! Initialize exx pw
98-
this->exx_helper.init(ucell, inp, this->pelec->wg);
134+
this->exx_helper->init(ucell, inp, this->pelec->wg);
99135
}
100136

101137
template <typename T, typename Device>
@@ -128,13 +164,13 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
128164
// init DFT+U is done in "before_all_runners" in LCAO basis. This should be refactored, mohan note 2025-11-06
129165
pw::setup_pot(istep, ucell, this->kv, this->sf, this->pelec, this->Pgrid,
130166
this->chr, this->locpp, this->ppcell, this->dftu, this->vsep_cell,
131-
this->stp.get_psi_t(), static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt), this->pw_wfc, this->pw_rhod, PARAM.inp);
167+
this->stp.template get_psi_t<T, Device>(), static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt), this->pw_wfc, this->pw_rhod, PARAM.inp);
132168

133169
// setup psi (electronic wave functions)
134170
this->stp.init(this->p_hamilt);
135171

136172
//! Setup EXX helper for Hamiltonian and psi
137-
exx_helper.before_scf(this->p_hamilt, this->stp.get_psi_t(), PARAM.inp);
173+
exx_helper->before_scf(this->p_hamilt, this->stp.template get_psi_t<T, Device>(), PARAM.inp);
138174

139175
ModuleBase::timer::tick("ESolver_KS_PW", "before_scf");
140176
}
@@ -152,7 +188,7 @@ void ESolver_KS_PW<T, Device>::iter_init(UnitCell& ucell, const int istep, const
152188

153189
// update local occupations for DFT+U
154190
// should before lambda loop in DeltaSpin
155-
pw::iter_init_dftu_pw(iter, istep, this->dftu, this->stp.get_psi_t(), this->pelec->wg, ucell, PARAM.inp);
191+
pw::iter_init_dftu_pw(iter, istep, this->dftu, this->stp.template get_psi_t<T, Device>(), this->pelec->wg, ucell, PARAM.inp);
156192
}
157193

158194
// Temporary, it should be replaced by hsolver later.
@@ -188,7 +224,7 @@ void ESolver_KS_PW<T, Device>::hamilt2rho_single(UnitCell& ucell, const int iste
188224
hsolver::DiagoIterAssist<T, Device>::need_subspace,
189225
PARAM.inp.use_k_continuity);
190226

191-
hsolver_pw_obj.solve(static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt), *this->stp.get_psi_t(), this->pelec, this->pelec->ekb.c,
227+
hsolver_pw_obj.solve(static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt), *this->stp.template get_psi_t<T, Device>(), this->pelec, this->pelec->ekb.c,
192228
GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL, skip_charge, ucell.tpiba, ucell.nat);
193229
}
194230

@@ -203,9 +239,9 @@ template <typename T, typename Device>
203239
void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int& iter, bool& conv_esolver)
204240
{
205241
// Related to EXX
206-
if (GlobalC::exx_info.info_global.cal_exx && !exx_helper.op_exx->first_iter)
242+
if (GlobalC::exx_info.info_global.cal_exx && !exx_helper->get_op_first_iter())
207243
{
208-
this->pelec->set_exx(exx_helper.cal_exx_energy(this->stp.get_psi_t()));
244+
this->pelec->set_exx(exx_helper->cal_exx_energy(this->stp.template get_psi_t<T, Device>()));
209245
}
210246

211247
// deband is calculated from "output" charge density
@@ -224,7 +260,7 @@ void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int
224260
}
225261

226262
// Handle EXX-related operations after SCF iteration
227-
exx_helper.iter_finish(this->pelec, &this->chr, this->stp.get_psi_t(), ucell, PARAM.inp, conv_esolver, iter);
263+
exx_helper->iter_finish(this->pelec, &this->chr, this->stp.template get_psi_t<T, Device>(), ucell, PARAM.inp, conv_esolver, iter);
228264

229265
// check if oscillate for delta_spin method
230266
pw::check_deltaspin_oscillation(iter, this->drho, this->p_chgmix, PARAM.inp);
@@ -273,7 +309,7 @@ void ESolver_KS_PW<T, Device>::cal_force(UnitCell& ucell, ModuleBase::matrix& fo
273309
// Calculate forces
274310
ff.cal_force(ucell, force, *this->pelec, this->pw_rhod, &ucell.symm,
275311
&this->sf, this->solvent, &this->dftu, &this->locpp, &this->ppcell,
276-
&this->kv, this->pw_wfc, this->stp.get_psi_d());
312+
&this->kv, this->pw_wfc, this->stp.template get_psi_d<T, Device>());
277313
}
278314

279315
template <typename T, typename Device>
@@ -285,7 +321,7 @@ void ESolver_KS_PW<T, Device>::cal_stress(UnitCell& ucell, ModuleBase::matrix& s
285321
this->stp.update_psi_d();
286322

287323
ss.cal_stress(stress, ucell, this->dftu, this->locpp, this->ppcell, this->pw_rhod,
288-
&ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->stp.get_psi_d());
324+
&ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->stp.template get_psi_d<T, Device>());
289325

290326
// external stress
291327
double unit_transform = 0.0;
@@ -306,7 +342,7 @@ void ESolver_KS_PW<T, Device>::after_all_runners(UnitCell& ucell)
306342
this->pw_rho, this->pw_rhod, this->chr, this->kv, this->stp,
307343
this->sf, this->ppcell, this->solvent, this->Pgrid, PARAM.inp);
308344

309-
elecstate::teardown_estate_pw<T, Device>(this->pelec, this->vsep_cell);
345+
elecstate::teardown_estate_pw(this->pelec, this->vsep_cell);
310346

311347
}
312348

source/source_esolver/esolver_ks_pw.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include "./esolver_ks.h"
44
#include "source_psi/setup_psi_pw.h" // mohan add 20251012
55
#include "source_pw/module_pwdft/vsep_pw.h"
6-
#include "source_pw/module_pwdft/exx_helper.h"
6+
#include "source_pw/module_pwdft/exx_helper_base.h"
77
#include "source_pw/module_pwdft/op_pw_vel.h"
88

99
#include <memory>
@@ -33,7 +33,7 @@ class ESolver_KS_PW : public ESolver_KS
3333

3434
void after_all_runners(UnitCell& ucell) override;
3535

36-
Exx_Helper<T, Device> exx_helper;
36+
Exx_HelperBase* exx_helper = nullptr;
3737

3838
protected:
3939
virtual void before_scf(UnitCell& ucell, const int istep) override;
@@ -52,7 +52,7 @@ class ESolver_KS_PW : public ESolver_KS
5252
virtual void deallocate_hamilt();
5353

5454
// Electronic wave function psi
55-
Setup_Psi_pw<T, Device> stp;
55+
Setup_Psi_pw stp;
5656

5757
// DFT-1/2 method
5858
VSep* vsep_cell = nullptr;

source/source_esolver/esolver_sdft_pw.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ void ESolver_SDFT_PW<T, Device>::hamilt2rho_single(UnitCell& ucell, int istep, i
168168

169169
hsolver_pw_sdft_obj.solve(ucell,
170170
static_cast<hamilt::Hamilt<T, Device>*>(this->p_hamilt),
171-
*this->stp.get_psi_t(),
171+
*this->stp.template get_psi_t<T, Device>(),
172172
this->stp.psi_cpu[0],
173173
this->pelec,
174174
this->pw_wfc,
@@ -221,7 +221,7 @@ void ESolver_SDFT_PW<T, Device>::cal_force(UnitCell& ucell, ModuleBase::matrix&
221221
this->locpp,
222222
this->ppcell,
223223
ucell,
224-
*this->stp.get_psi_t(),
224+
*this->stp.template get_psi_t<T, Device>(),
225225
this->stowf);
226226
}
227227

@@ -236,7 +236,7 @@ void ESolver_SDFT_PW<T, Device>::cal_stress(UnitCell& ucell, ModuleBase::matrix&
236236
&this->sf,
237237
&this->kv,
238238
this->pw_wfc,
239-
*this->stp.get_psi_t(),
239+
*this->stp.template get_psi_t<T, Device>(),
240240
this->stowf,
241241
&this->chr,
242242
&this->locpp,
@@ -289,7 +289,7 @@ void ESolver_SDFT_PW<T, Device>::after_all_runners(UnitCell& ucell)
289289
&this->kv,
290290
this->pelec,
291291
this->pw_wfc,
292-
this->stp.get_psi_t(),
292+
this->stp.template get_psi_t<T, Device>(),
293293
&this->ppcell,
294294
static_cast<hamilt::Hamilt<std::complex<double>, Device>*>(this->p_hamilt),
295295
this->stoche,

0 commit comments

Comments
 (0)