Skip to content

Commit a10e537

Browse files
mohanchenabacus_fixer
andauthored
Refactor device parameter in esolver_ks_pw (#7016)
* small format changes * refactor(esolver): extract charge density symmetrization to Symmetry_rho::symmetrize_rho - Add static method symmetrize_rho() in Symmetry_rho class - Replace 7 duplicate code blocks with single function call - Simplify code from 35 lines to 7 lines (80% reduction) - Improve code readability and maintainability Modified files: - source_estate/module_charge/symmetry_rho.h: add static method declaration - source_estate/module_charge/symmetry_rho.cpp: implement static method - source_esolver/esolver_ks_lcao.cpp: 2 calls updated - source_esolver/esolver_ks_pw.cpp: 1 call updated - source_esolver/esolver_ks_lcao_tddft.cpp: 1 call updated - source_esolver/esolver_ks_lcaopw.cpp: 1 call updated - source_esolver/esolver_of.cpp: 1 call updated - source_esolver/esolver_sdft_pw.cpp: 1 call updated This refactoring follows the ESolver cleanup principle: keep ESolver focused on high-level workflow control. * refactor(esolver): extract DeltaSpin lambda loop to deltaspin_lcao module - Create new files deltaspin_lcao.h/cpp in module_deltaspin - Extract DeltaSpin lambda loop logic from ESolver_KS_LCAO - Simplify code from 18 lines to 1 line in hamilt2rho_single - Separate LCAO and PW implementations for DeltaSpin Modified files: - source_esolver/esolver_ks_lcao.cpp: replace inline code with function call - source_lcao/module_deltaspin/CMakeLists.txt: add new source file New files: - source_lcao/module_deltaspin/deltaspin_lcao.h: function declaration - source_lcao/module_deltaspin/deltaspin_lcao.cpp: function implementation This refactoring follows the ESolver cleanup principle: keep ESolver focused on high-level workflow control. * refactor(esolver): complete DeltaSpin refactoring in LCAO - Add init_deltaspin_lcao() function for DeltaSpin initialization - Add cal_mi_lcao_wrapper() function for magnetic moment calculation - Refactor all DeltaSpin-related code in esolver_ks_lcao.cpp - Simplify code from 29 lines to 3 lines (90% reduction) Modified files: - source_esolver/esolver_ks_lcao.cpp: replace 3 code blocks with function calls - source_lcao/module_deltaspin/deltaspin_lcao.h: add 2 new function declarations - source_lcao/module_deltaspin/deltaspin_lcao.cpp: implement 2 new functions This completes the DeltaSpin refactoring for LCAO method: 1. init_deltaspin_lcao() - initialize DeltaSpin calculation 2. cal_mi_lcao_wrapper() - calculate magnetic moments 3. run_deltaspin_lambda_loop_lcao() - run lambda loop optimization All functions follow the ESolver cleanup principle: keep ESolver focused on high-level workflow control. * refactor(esolver): extract DFT+U code to dftu_lcao module - Create new files dftu_lcao.h/cpp in source_lcao directory - Add init_dftu_lcao() function for DFT+U initialization - Add finish_dftu_lcao() function for DFT+U finalization - Simplify code from 32 lines to 2 lines in esolver_ks_lcao.cpp - Remove conditional checks from ESolver, move them to functions Modified files: - source_esolver/esolver_ks_lcao.cpp: replace 2 code blocks with function calls - source_lcao/CMakeLists.txt: add new source file New files: - source_lcao/dftu_lcao.h: function declarations - source_lcao/dftu_lcao.cpp: function implementations This refactoring prepares for unifying old and new DFT+U implementations: - Old DFT+U: source_lcao/module_dftu/ - New DFT+U: source_lcao/module_operator_lcao/op_dftu_lcao.cpp All functions follow ESolver cleanup principle: keep ESolver focused on high-level workflow control. * refactor(esolver): extract diagonalization parameters setup to hsolver module - Create new files diago_params.h/cpp in source_hsolver directory - Add setup_diago_params_pw() function for PW diagonalization parameters - Simplify code from 11 lines to 1 line in esolver_ks_pw.cpp - Encapsulate diagonalization parameter setup logic Modified files: - source_esolver/esolver_ks_pw.cpp: replace inline code with function call - source_hsolver/CMakeLists.txt: add new source file New files: - source_hsolver/diago_params.h: function declaration - source_hsolver/diago_params.cpp: function implementation This refactoring follows ESolver cleanup principle: keep ESolver focused on high-level workflow control. * fix(deltaspin): add sc_mag_switch check in cal_mi_lcao_wrapper - Add Input_para parameter to cal_mi_lcao_wrapper function - Add sc_mag_switch check to avoid calling cal_mi_lcao when DeltaSpin is disabled - Fix 'atomCounts is not set' error in non-DeltaSpin calculations - Update function call in esolver_ks_lcao.cpp This fix resolves the CI/CD failure caused by commit 2a520e3. The root cause was that cal_mi_lcao_wrapper was called without checking sc_mag_switch, leading to uninitialized atomCounts error. Modified files: - source_esolver/esolver_ks_lcao.cpp: update function call - source_lcao/module_deltaspin/deltaspin_lcao.h: add parameter - source_lcao/module_deltaspin/deltaspin_lcao.cpp: add check This follows the refactoring principle: preserve original condition checks when extracting code to wrapper functions. * fix(deltaspin): add #ifdef __LCAO for conditional compilation - Add #ifdef __LCAO conditional compilation in init_deltaspin_lcao and cal_mi_lcao_wrapper - Fix parameter order in init_sc call for LCAO and non-LCAO builds - Fix undefined reference to cal_mi_lcao in non-LCAO build This fix resolves CI/CD compilation errors in both build_5pt (with __LCAO) and build_1p (without __LCAO) environments. The The root cause was 1. init_sc has different parameter order in LCAO vs non-LCAO builds - LCAO: psi, dm, pelec - non-LCAO: psi, pelec 2. cal_mi_lcao is only defined in LCAO build Modified files: - source_hsolver/diago_params.h: add setup_diago_params_sdft declaration - source_lcao/module_deltaspin/deltaspin_lcao.cpp: add conditional compilation This follows the refactoring principle: handle conditional compilation properly when code has different implementations for different build configurations. * refactor(esolver): extract SDFT diagonalization parameters setup - Add setup_diago_params_sdft() function for SDFT diagonalization parameters - Simplify code from 11 lines to 1 line in esolver_sdft_pw.cpp - Encapsulate diagonalization parameter setup logic for SDFT Modified files: - source_esolver/esolver_sdft_pw.cpp: replace inline code with function call - source_hsolver/diago_params.cpp: add setup_diago_params_sdft implementation This refactoring follows ESolver cleanup principle: keep ESolver focused on high-level workflow control. Note: SDFT has different parameter setup logic compared to PW: - Different need_subspace condition - No SCF_ITER setting - Always set PW_DIAG_NMAX (no nscf check) * refactor(hamilt): introduce HamiltBase non-template base class - Create HamiltBase as a non-template base class for Hamilt<T, Device> - Modify Hamilt<T, Device> to inherit from HamiltBase - Change ESolver_KS::p_hamilt type from Hamilt<T, Device>* to HamiltBase* - Add static_cast where needed when passing p_hamilt to functions expecting Hamilt<T, Device>* This is the first step towards removing template parameters from ESolver. Modified files: - source/source_esolver/esolver_ks.h - source/source_esolver/esolver_ks_lcaopw.cpp - source/source_esolver/esolver_ks_pw.cpp - source/source_esolver/esolver_sdft_pw.cpp - source/source_hamilt/hamilt.h New files: - source/source_hamilt/hamilt_base.h * refactor(esolver): add static_cast for p_hamilt in esolver files - Add static_cast<hamilt::Hamilt<T>*> when passing p_hamilt to functions expecting Hamilt<T, Device>* type - Split long cast statements into multiple lines for better readability - Files modified: - esolver_ks_pw.cpp: setup_pot, stp.init calls - esolver_ks_lcao.cpp: init_chg_hr, hsolver_lcao_obj.solve calls - esolver_ks_lcao_tddft.cpp: solve_psi, cal_edm_tddft, matrix calls - esolver_gets.cpp: ops access, output_SR call This follows the HamiltBase refactoring strategy where p_hamilt is stored as HamiltBase* and cast to Hamilt<T, Device>* when needed. * refactor(esolver): remove psi member from ESolver_KS base class Move psi::Psi<T>* psi from ESolver_KS base class to derived classes to eliminate template parameter dependency and improve code organization. Changes: 1. ESolver_KS base class: - Remove psi::Psi<T>* psi member variable - Remove Setup_Psi<T>::deallocate_psi() call in destructor - Remove unnecessary includes: psi.h and setup_psi.h 2. ESolver_KS_LCAO: - Add psi::Psi<TK>* psi member variable - Add Setup_Psi<TK>::deallocate_psi() in destructor - Add include: setup_psi.h 3. ESolver_KS_LCAO_TDDFT: - Improve psi_laststep deallocation with nullptr check - psi member inherited from ESolver_KS_LCAO 4. ESolver_KS_PW: - Use stp.psi_cpu directly instead of base class psi - Remove unnecessary memory allocation in after_scf() 5. pw_others.cpp (BUG FIX): - Fix gen_bessel: use *(this->stp.psi_cpu) instead of this->psi[0] - Previous code accessed uninitialized base class psi (nullptr) - This was a latent bug that could cause crashes Benefits: - Eliminates template parameter T dependency in ESolver_KS base class - Clearer memory management: each derived class manages its own psi - Reduces compilation dependencies - Fixes potential memory access bug in pw_others.cpp Tested: Compiled successfully in build_5pt and build_1p * refactor(esolver): remove template parameters from ESolver_KS base class This is a major milestone in ESolver refactoring! ESolver_KS no longer needs template parameters because: - All member variables are non-template types - All member functions do not use T or Device parameters - Template parameters were only needed for derived classes Changes: 1. ESolver_KS base class: - Remove template <typename T, typename Device> declaration - Remove all template declarations from member functions - Remove template instantiation code at end of file - Fix Tab indentation to spaces for better readability 2. Derived classes: - ESolver_KS_PW: public ESolver_KS (was ESolver_KS<T, Device>) - ESolver_KS_LCAO: public ESolver_KS (was ESolver_KS<TK>) - ESolver_GetS: public ESolver_KS (was ESolver_KS<std::complex<double>>) - Update base class calls: ESolver_KS:: (was ESolver_KS<T, Device>::) Code reduction: - esolver_ks.h: 78 -> 77 lines (-1 line) - esolver_ks.cpp: 346 -> 317 lines (-29 lines) - Total ESolver code: 424 -> 394 lines (-30 lines) - Overall: 8 files changed, 50 insertions(+), 80 deletions(-), net -30 lines Benefits: - Simpler base class without template complexity - Faster compilation (no template instantiation needed) - Clearer inheritance hierarchy - Easier to extract common code in future refactoring - Sets foundation for further ESolver template removal Tested: Compiled successfully in build_5pt * refactor(device): remove explicit template parameter from get_device_type calls - Move get_device_type implementation to header file using std::is_same - Add DEVICE_DSP support - Remove template specialization declarations and definitions - Update all call sites to use automatic template parameter deduction - The compiler now deduces Device type from the ctx parameter * refactor(esolver): remove device member variable from ESolver_KS_PW - Modify copy_d2h to accept ctx parameter and call get_device_type internally - Remove device parameter from ctrl_scf_pw function - Remove device member variable from ESolver_KS_PW class - Simplify function interfaces by using automatic template deduction * style(esolver): explicitly initialize ctx to nullptr in constructor * feat(device): add runtime device type support to DeviceContext - Add device_type_ member variable to DeviceContext class - Add set_device_type() and get_device_type() methods - Add is_cpu(), is_gpu(), is_dsp() convenience methods - Add get_device_type(const DeviceContext*) overload for runtime device type query - Maintain backward compatibility with existing template-based get_device_type * feat(device): add runtime device context overloads for gradual migration - Add copy_d2h(const DeviceContext*) overload to Setup_Psi_pw - Add ctrl_scf_pw(..., const DeviceContext*, ...) overload - Add ctrl_runner_pw(..., const DeviceContext*, ...) overload - Keep original functions for backward compatibility - Replace tabs with spaces in modified files --------- Co-authored-by: abacus_fixer <mohanchen@pku.eud.cn>
1 parent 19f8df0 commit a10e537

29 files changed

Lines changed: 187 additions & 105 deletions

source/source_base/math_chebyshev.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ Chebyshev<REAL, Device>::Chebyshev(const int norder_in) : fftw(2 * EXTEND * nord
6161
}
6262
coefr_cpu = new REAL[norder];
6363
coefc_cpu = new std::complex<REAL>[norder];
64-
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
64+
if (base_device::get_device_type(this->ctx) == base_device::GpuDevice)
6565
{
6666
resmem_var_op()(this->coef_real, norder);
6767
resmem_complex_op()(this->coef_complex, norder);
@@ -82,7 +82,7 @@ template <typename REAL, typename Device>
8282
Chebyshev<REAL, Device>::~Chebyshev()
8383
{
8484
delete[] polytrace;
85-
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
85+
if (base_device::get_device_type(this->ctx) == base_device::GpuDevice)
8686
{
8787
delmem_var_op()(this->coef_real);
8888
delmem_complex_op()(this->coef_complex);
@@ -209,7 +209,7 @@ void Chebyshev<REAL, Device>::calcoef_real(std::function<REAL(REAL)> fun)
209209
}
210210
}
211211

212-
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
212+
if (base_device::get_device_type(this->ctx) == base_device::GpuDevice)
213213
{
214214
syncmem_var_h2d_op()(coef_real, coefr_cpu, norder);
215215
}
@@ -299,7 +299,7 @@ void Chebyshev<REAL, Device>::calcoef_complex(std::function<std::complex<REAL>(s
299299
coefc_cpu[i].imag(imag(coefc_cpu[i]) + real(pcoef[i]) / norder2 * 2 / 3);
300300
}
301301
}
302-
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
302+
if (base_device::get_device_type(this->ctx) == base_device::GpuDevice)
303303
{
304304
syncmem_complex_h2d_op()(coef_complex, coefc_cpu, norder);
305305
}
@@ -390,7 +390,7 @@ void Chebyshev<REAL, Device>::calcoef_pair(std::function<REAL(REAL)> fun1, std::
390390
}
391391
}
392392

393-
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
393+
if (base_device::get_device_type(this->ctx) == base_device::GpuDevice)
394394
{
395395
syncmem_complex_h2d_op()(coef_complex, coefc_cpu, norder);
396396
}
@@ -684,7 +684,7 @@ bool Chebyshev<REAL, Device>::checkconverge(
684684
funA(arrayn_1, arrayn, 1);
685685
REAL sum1, sum2;
686686
REAL t;
687-
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
687+
if (base_device::get_device_type(this->ctx) == base_device::GpuDevice)
688688
{
689689
sum1 = this->ddot_real(arrayn_1, arrayn_1, N);
690690
sum2 = this->ddot_real(arrayn_1, arrayn, N);
@@ -714,7 +714,7 @@ bool Chebyshev<REAL, Device>::checkconverge(
714714
for (int ior = 2; ior < norder; ++ior)
715715
{
716716
funA(arrayn, arraynp1, 1);
717-
if (base_device::get_device_type<Device>(this->ctx) == base_device::GpuDevice)
717+
if (base_device::get_device_type(this->ctx) == base_device::GpuDevice)
718718
{
719719
sum1 = this->ddot_real(arrayn, arrayn, N);
720720
sum2 = this->ddot_real(arrayn, arraynp1, N);

source/source_base/module_device/device.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,36 @@ class DeviceContext {
145145
*/
146146
int get_local_rank() const { return local_rank_; }
147147

148+
/**
149+
* @brief Set the device type (CpuDevice, GpuDevice, or DspDevice)
150+
* @param type The device type
151+
*/
152+
void set_device_type(AbacusDevice_t type) { device_type_ = type; }
153+
154+
/**
155+
* @brief Get the device type
156+
* @return AbacusDevice_t The device type
157+
*/
158+
AbacusDevice_t get_device_type() const { return device_type_; }
159+
160+
/**
161+
* @brief Check if the device is CPU
162+
* @return true if the device is CPU
163+
*/
164+
bool is_cpu() const { return device_type_ == CpuDevice; }
165+
166+
/**
167+
* @brief Check if the device is GPU
168+
* @return true if the device is GPU
169+
*/
170+
bool is_gpu() const { return device_type_ == GpuDevice; }
171+
172+
/**
173+
* @brief Check if the device is DSP
174+
* @return true if the device is DSP
175+
*/
176+
bool is_dsp() const { return device_type_ == DspDevice; }
177+
148178
// Disable copy and assignment
149179
DeviceContext(const DeviceContext&) = delete;
150180
DeviceContext& operator=(const DeviceContext&) = delete;
@@ -158,10 +188,21 @@ class DeviceContext {
158188
int device_id_ = -1;
159189
int device_count_ = 0;
160190
int local_rank_ = 0;
191+
AbacusDevice_t device_type_ = CpuDevice;
161192

162193
std::mutex init_mutex_;
163194
};
164195

196+
/**
197+
* @brief Get the device type enum from DeviceContext (runtime version).
198+
* @param ctx Pointer to DeviceContext
199+
* @return AbacusDevice_t enum value
200+
*/
201+
inline AbacusDevice_t get_device_type(const DeviceContext* ctx)
202+
{
203+
return ctx->get_device_type();
204+
}
205+
165206
} // end of namespace base_device
166207

167208
#endif // MODULE_DEVICE_H_

source/source_base/module_device/device_helpers.cpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,6 @@
33
namespace base_device
44
{
55

6-
// Device type specializations
7-
template <>
8-
AbacusDevice_t get_device_type<DEVICE_CPU>(const DEVICE_CPU* dev)
9-
{
10-
return CpuDevice;
11-
}
12-
13-
template <>
14-
AbacusDevice_t get_device_type<DEVICE_GPU>(const DEVICE_GPU* dev)
15-
{
16-
return GpuDevice;
17-
}
18-
196
// Precision specializations
207
template <>
218
std::string get_current_precision<float>(const float* var)

source/source_base/module_device/device_helpers.h

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,35 @@
1313
#include "types.h"
1414
#include <complex>
1515
#include <string>
16+
#include <type_traits>
1617

1718
namespace base_device
1819
{
1920

21+
// Forward declaration
22+
class DeviceContext;
23+
24+
/**
25+
* @brief Get the device type enum from DeviceContext (runtime version).
26+
* @param ctx Pointer to DeviceContext
27+
* @return AbacusDevice_t enum value
28+
*/
29+
inline AbacusDevice_t get_device_type(const DeviceContext* ctx);
30+
2031
/**
21-
* @brief Get the device type enum for a given device type.
32+
* @brief Get the device type enum for a given device type (compile-time version).
2233
* @tparam Device The device type (DEVICE_CPU or DEVICE_GPU)
2334
* @param dev Pointer to device (used for template deduction)
2435
* @return AbacusDevice_t enum value
2536
*/
2637
template <typename Device>
27-
AbacusDevice_t get_device_type(const Device* dev);
28-
29-
// Template specialization declarations
30-
template <>
31-
AbacusDevice_t get_device_type<DEVICE_CPU>(const DEVICE_CPU* dev);
32-
33-
template <>
34-
AbacusDevice_t get_device_type<DEVICE_GPU>(const DEVICE_GPU* dev);
38+
AbacusDevice_t get_device_type(const Device* dev)
39+
{
40+
if (std::is_same<Device, DEVICE_CPU>::value) return CpuDevice;
41+
else if (std::is_same<Device, DEVICE_GPU>::value) return GpuDevice;
42+
else if (std::is_same<Device, DEVICE_DSP>::value) return DspDevice;
43+
else return UnKnown;
44+
}
3545

3646
/**
3747
* @brief Get the precision string for a given numeric type.

source/source_base/module_device/test/device_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ class TestModulePsiDevice : public ::testing::Test
2020

2121
TEST_F(TestModulePsiDevice, get_device_type_cpu)
2222
{
23-
base_device::AbacusDevice_t device = base_device::get_device_type<base_device::DEVICE_CPU>(cpu_ctx);
23+
base_device::AbacusDevice_t device = base_device::get_device_type(cpu_ctx);
2424
EXPECT_EQ(device, base_device::CpuDevice);
2525
}
2626

2727
#if __UT_USE_CUDA || __UT_USE_ROCM
2828
TEST_F(TestModulePsiDevice, get_device_type_gpu)
2929
{
30-
base_device::AbacusDevice_t device = base_device::get_device_type<base_device::DEVICE_GPU>(gpu_ctx);
30+
base_device::AbacusDevice_t device = base_device::get_device_type(gpu_ctx);
3131
EXPECT_EQ(device, base_device::GpuDevice);
3232
}
3333
#endif // __UT_USE_CUDA || __UT_USE_ROCM

source/source_esolver/esolver_ks_pw.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ ESolver_KS_PW<T, Device>::ESolver_KS_PW()
4343
{
4444
this->classname = "ESolver_KS_PW";
4545
this->basisname = "PW";
46-
this->device = base_device::get_device_type<Device>(this->ctx);
46+
this->ctx = nullptr;
4747
}
4848

4949
template <typename T, typename Device>
@@ -251,7 +251,7 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
251251
// Output quantities
252252
ModuleIO::ctrl_scf_pw<T, Device>(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc,
253253
this->pw_rho, this->pw_rhod, this->pw_big, this->stp,
254-
this->ctx, this->device, this->Pgrid, PARAM.inp);
254+
this->ctx, this->Pgrid, PARAM.inp);
255255

256256
ModuleBase::timer::tick("ESolver_KS_PW", "after_scf");
257257
}

source/source_esolver/esolver_ks_pw.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,6 @@ class ESolver_KS_PW : public ESolver_KS
6060
// for get_pchg and get_wf, use ctx as input of fft
6161
Device* ctx = {};
6262

63-
// for device to host data transformation
64-
base_device::AbacusDevice_t device = {};
65-
6663
};
6764
} // namespace ModuleESolver
6865
#endif

source/source_hsolver/diago_dav_subspace.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Diago_DavSubspace<T, Device>::Diago_DavSubspace(const std::vector<Real>& precond
3636
diag_thr(diag_thr_in), iter_nmax(diag_nmax_in), diag_comm(diag_comm_in),
3737
diag_subspace(diag_subspace_in), diago_subspace_bs(diago_subspace_bs_in)
3838
{
39-
this->device = base_device::get_device_type<Device>(this->ctx);
39+
this->device = base_device::get_device_type(this->ctx);
4040

4141
this->one = &one_;
4242
this->zero = &zero_;

source/source_hsolver/diago_david.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
2020
const diag_comm_info& diag_comm_in)
2121
: nband(nband_in), dim(dim_in), nbase_x(david_ndim_in * nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in)
2222
{
23-
this->device = base_device::get_device_type<Device>(this->ctx);
23+
this->device = base_device::get_device_type(this->ctx);
2424
this->precondition = precondition_in;
2525

2626
this->one = &one_;

source/source_hsolver/diago_iter_assist.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -400,14 +400,14 @@ void DiagoIterAssist<T, Device>::diag_heevx(const int matrix_size,
400400
// (const Device *d, const int matrix_size, const int lda, const T *A, const int num_eigenpairs, Real *eigenvalues, T *eigenvectors);
401401
heevx_op<T, Device>()(ctx, matrix_size, ldh, h, num_eigenpairs, eigenvalues, v);
402402

403-
if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
403+
if (base_device::get_device_type(ctx) == base_device::GpuDevice)
404404
{
405405
#if ((defined __CUDA) || (defined __ROCM))
406406
// eigenvalues to e, from device to host
407407
syncmem_var_d2h_op()(e, eigenvalues, num_eigenpairs);
408408
#endif
409409
}
410-
else if (base_device::get_device_type<Device>(ctx) == base_device::CpuDevice)
410+
else if (base_device::get_device_type(ctx) == base_device::CpuDevice)
411411
{
412412
// eigenvalues to e
413413
syncmem_var_op()(e, eigenvalues, num_eigenpairs);
@@ -436,14 +436,14 @@ void DiagoIterAssist<T, Device>::diag_hegvd(const int nstart,
436436

437437
hegvd_op<T, Device>()(ctx, nstart, ldh, hcc, scc, eigenvalues, vcc);
438438

439-
if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
439+
if (base_device::get_device_type(ctx) == base_device::GpuDevice)
440440
{
441441
#if ((defined __CUDA) || (defined __ROCM))
442442
// set eigenvalues in GPU to e in CPU
443443
syncmem_var_d2h_op()(e, eigenvalues, nbands);
444444
#endif
445445
}
446-
else if (base_device::get_device_type<Device>(ctx) == base_device::CpuDevice)
446+
else if (base_device::get_device_type(ctx) == base_device::CpuDevice)
447447
{
448448
// set eigenvalues in CPU to e in CPU
449449
syncmem_var_op()(e, eigenvalues, nbands);

0 commit comments

Comments
 (0)