Skip to content

Commit 233254a

Browse files
committed
Replace hsolver::heevx_op with ct::kernels::lapack_heevx in diago_david
1 parent 93da750 commit 233254a

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

source/source_hsolver/diago_david.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using namespace hsolver;
1212

1313

1414
template <typename T, typename Device>
15-
DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
15+
DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
1616
const int nband_in,
1717
const int dim_in,
1818
const int david_ndim_in,
@@ -80,7 +80,7 @@ DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
8080
resmem_complex_op()(this->vcc, nbase_x * nbase_x, "DAV::vcc");
8181
setmem_complex_op()(this->vcc, 0, nbase_x * nbase_x);
8282
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
83-
83+
8484
// lagrange_matrix(nband, nband); // for orthogonalization
8585
resmem_complex_op()(this->lagrange_matrix, nband * nband);
8686
setmem_complex_op()(this->lagrange_matrix, 0, nband * nband);
@@ -409,7 +409,7 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
409409
// basis[nbase] = basis[nbase] - spsi * vc_ev_vector
410410
// = hpsi - spsi * lambda * vcc
411411
// = (H - lambda * S) * psi * vcc
412-
// = (H - lambda * S) * psi_new
412+
// = (H - lambda * S) * psi_new
413413
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
414414
ModuleBase::gemm_op<T, Device>()('N',
415415
'N',
@@ -622,15 +622,17 @@ void DiagoDavid<T, Device>::diag_zhegvx(const int& nbase,
622622
resmem_var_op()(eigenvalue_gpu, nbase_x);
623623
syncmem_var_h2d_op()(eigenvalue_gpu, this->eigenvalue, nbase_x);
624624

625-
heevx_op<T, Device>()(this->ctx, nbase, nbase_x, hcc, nband, eigenvalue_gpu, vcc);
625+
// heevx_op<T, Device>()(this->ctx, nbase, nbase_x, hcc, nband, eigenvalue_gpu, vcc);
626+
ct::kernels::lapack_heevx<T, ct_Device>()(nbase, nbase_x, hcc, nband, eigenvalue_gpu, vcc);
626627

627628
syncmem_var_d2h_op()(this->eigenvalue, eigenvalue_gpu, nbase_x);
628629
delmem_var_op()(eigenvalue_gpu);
629630
#endif
630631
}
631632
else
632633
{
633-
heevx_op<T, Device>()(this->ctx, nbase, nbase_x, hcc, nband, this->eigenvalue, vcc);
634+
//heevx_op<T, Device>()(this->ctx, nbase, nbase_x, hcc, nband, this->eigenvalue, vcc);
635+
ct::kernels::lapack_heevx<T, ct_Device>()(nbase, nbase_x, hcc, nband, this->eigenvalue, vcc);
634636
}
635637
}
636638

source/source_hsolver/diago_david.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include "source_base/module_device/device.h" // base_device
66
#include "source_base/module_device/memory_op.h"// base_device::memory
77

8-
// #include "source_base/module_container/ATen/kernels/lapack.h" // container::kernels
8+
#include "source_base/module_container/ATen/kernels/lapack.h" // container::kernels
99

1010
#include "source_hsolver/diag_comm_info.h"
1111
#include "source_hsolver/kernels/hegvd_op.h"
@@ -341,6 +341,8 @@ class DiagoDavid
341341
using syncmem_h2d_op = base_device::memory::synchronize_memory_op<T, Device, base_device::DEVICE_CPU>;
342342
using syncmem_d2h_op = base_device::memory::synchronize_memory_op<T, base_device::DEVICE_CPU, Device>;
343343

344+
// Note that ct_Device is different from base_device!
345+
using ct_Device = typename ct::PsiToContainer<Device>::type;
344346
// using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info; // Dependence of hpsi removed
345347

346348
const T *one = nullptr, *zero = nullptr, *neg_one = nullptr;

0 commit comments

Comments
 (0)