Skip to content

Commit e52ff4d

Browse files
committed
Remove DiagH Base class of DiagoCusolver and lcao_cusolver_test
1 parent 47557b7 commit e52ff4d

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

source/module_hsolver/diago_cusolver.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
namespace hsolver
1010
{
1111

12-
// DiagoCusolver class, derived from DiagH, for diagonalization using CUSOLVER
12+
// DiagoCusolver class for diagonalization using CUSOLVER
1313
template <typename T>
14-
class DiagoCusolver : public DiagH<T>
14+
class DiagoCusolver
1515
{
1616
private:
1717
// Real is the real part of the complex type T
@@ -24,7 +24,7 @@ class DiagoCusolver : public DiagH<T>
2424
~DiagoCusolver();
2525

2626
// Override the diag function for CUSOLVER diagonalization
27-
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in) override;
27+
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in);
2828

2929
// Static variable to keep track of the decomposition state
3030
static int DecomposedState;

source/module_hsolver/test/diago_lcao_cusolver_test.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,10 @@ class DiagoPrepare
7474
MPI_Comm_rank(MPI_COMM_WORLD, &myrank);
7575

7676
if (ks_solver == "scalapack_gvx")
77-
dh = new hsolver::DiagoScalapack<T>;
77+
;//dh = new hsolver::DiagoScalapack<T>;
7878
#ifdef __CUDA
7979
else if (ks_solver == "cusolver")
80-
dh = new hsolver::DiagoCusolver<T>;
80+
;//dh = new hsolver::DiagoCusolver<T>;
8181
#endif
8282
else
8383
{
@@ -96,7 +96,7 @@ class DiagoPrepare
9696
std::vector<T> s;
9797
std::vector<T> h_local;
9898
std::vector<T> s_local;
99-
hsolver::DiagH<T>* dh = 0;
99+
// hsolver::DiagH<T>* dh = 0;
100100
psi::Psi<T> psi;
101101
std::vector<double> e_solver;
102102
std::vector<double> e_lapack;
@@ -222,11 +222,23 @@ class DiagoPrepare
222222
{
223223
hmtest.h_local = this->h_local;
224224
hmtest.s_local = this->s_local;
225-
dh->diag(&hmtest, psi, e_solver.data());
225+
if (ks_solver == "scalapack_gvx")
226+
{
227+
hsolver::DiagoScalapack<T> dh;
228+
dh.diag(&hmtest, psi, e_solver.data());
229+
}
230+
#ifdef __CUDA
231+
else if (ks_solver == "cusolver")
232+
{
233+
hsolver::DiagoCusolver<T> dh;
234+
dh.diag(&hmtest, psi, e_solver.data());
235+
}
236+
#endif
237+
// dh->diag(&hmtest, psi, e_solver.data());
226238
}
227239
endtime = MPI_Wtime();
228240
hsolver_time = (endtime - starttime) / REPEATRUN;
229-
delete dh;
241+
// delete dh;
230242
}
231243

232244
void diago_lapack()

0 commit comments

Comments
 (0)