Skip to content

Commit 590ea5e

Browse files
Remove Base class DiagH in LCAO code (#5239)
* Remove Base DiagH class of DiagoLapack * Remove DiagH Base class of DiagoScalapack and lcao test * Remove DiagH Base class of DiagoCusolver and lcao_cusolver_test * Update docs in hsolver * Remove Base DiagH pointer in diago_lapack_test * Remove Base DiagH of DiagoElpa * Remove Base DiagH of DiagoElpaNative * Remove Base DiagH of DiagoCusolverMP --------- Co-authored-by: Haozhi Han <[email protected]>
1 parent 04def35 commit 590ea5e

File tree

11 files changed

+62
-34
lines changed

11 files changed

+62
-34
lines changed

source/module_hsolver/diago_cusolver.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
namespace hsolver
1010
{
1111

12-
// DiagoCusolver class, derived from DiagH, for diagonalization using CUSOLVER
12+
// DiagoCusolver class for diagonalization using CUSOLVER
1313
template <typename T>
14-
class DiagoCusolver : public DiagH<T>
14+
class DiagoCusolver
1515
{
1616
private:
1717
// Real is the real part of the complex type T
@@ -24,7 +24,7 @@ class DiagoCusolver : public DiagH<T>
2424
~DiagoCusolver();
2525

2626
// Override the diag function for CUSOLVER diagonalization
27-
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in) override;
27+
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in);
2828

2929
// Static variable to keep track of the decomposition state
3030
static int DecomposedState;

source/module_hsolver/diago_cusolvermp.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace hsolver
1010
{
1111
// DiagoCusolverMP class, derived from DiagH, for diagonalization using CUSOLVERMP
1212
template <typename T>
13-
class DiagoCusolverMP : public DiagH<T>
13+
class DiagoCusolverMP
1414
{
1515
private:
1616
using Real = typename GetTypeReal<T>::type;
@@ -19,8 +19,8 @@ class DiagoCusolverMP : public DiagH<T>
1919
DiagoCusolverMP()
2020
{
2121
}
22-
// Override the diag function for CUSOLVERMP diagonalization
23-
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in) override;
22+
// the diag function for CUSOLVERMP diagonalization
23+
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in);
2424
};
2525
} // namespace hsolver
2626
#endif // __CUSOLVERMP

source/module_hsolver/diago_elpa.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@ namespace hsolver
88
{
99

1010
template <typename T>
11-
class DiagoElpa : public DiagH<T>
11+
class DiagoElpa
1212
{
1313
private:
1414
using Real = typename GetTypeReal<T>::type;
1515

1616
public:
17-
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in) override;
17+
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in);
1818
#ifdef __MPI
1919
// diagnolization used in parallel-k case
20-
void diag_pool(hamilt::MatrixBlock<T>& h_mat, hamilt::MatrixBlock<T>& s_mat, psi::Psi<T>& psi, Real* eigenvalue_in, MPI_Comm& comm) override;
20+
void diag_pool(hamilt::MatrixBlock<T>& h_mat, hamilt::MatrixBlock<T>& s_mat, psi::Psi<T>& psi, Real* eigenvalue_in, MPI_Comm& comm);
2121
MPI_Comm setmpicomm(); // set mpi comm;
2222
static int elpa_num_thread; // need to set mpi_comm or not,-1 not,else the number of mpi needed
2323
#endif

source/module_hsolver/diago_elpa_native.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@ namespace hsolver
88
{
99

1010
template <typename T>
11-
class DiagoElpaNative : public DiagH<T>
11+
class DiagoElpaNative
1212
{
1313
private:
1414
using Real = typename GetTypeReal<T>::type;
1515

1616
public:
17-
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in) override;
17+
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in);
1818
#ifdef __MPI
1919
// diagnolization used in parallel-k case
20-
void diag_pool(hamilt::MatrixBlock<T>& h_mat, hamilt::MatrixBlock<T>& s_mat, psi::Psi<T>& psi, Real* eigenvalue_in, MPI_Comm& comm) override;
20+
void diag_pool(hamilt::MatrixBlock<T>& h_mat, hamilt::MatrixBlock<T>& s_mat, psi::Psi<T>& psi, Real* eigenvalue_in, MPI_Comm& comm);
2121
MPI_Comm setmpicomm(); // set mpi comm;
2222
static int elpa_num_thread; // need to set mpi_comm or not,-1 not,else the number of mpi needed
2323
static int lastmpinum; // last using mpi;

source/module_hsolver/diago_lapack.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@
2020
namespace hsolver
2121
{
2222
template <typename T>
23-
class DiagoLapack : public DiagH<T>
23+
class DiagoLapack
2424
{
2525
private:
2626
using Real = typename GetTypeReal<T>::type;
2727

2828
public:
29-
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in) override;
29+
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in);
3030

3131
void dsygvx_diag(const int ncol,
3232
const int nrow,

source/module_hsolver/diago_scalapack.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@
2020
namespace hsolver
2121
{
2222
template<typename T>
23-
class DiagoScalapack : public DiagH<T>
23+
class DiagoScalapack
2424
{
2525
private:
2626
using Real = typename GetTypeReal<T>::type;
2727
public:
28-
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in) override;
28+
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in);
2929
#ifdef __MPI
3030
// diagnolization used in parallel-k case
31-
void diag_pool(hamilt::MatrixBlock<T>& h_mat, hamilt::MatrixBlock<T>& s_mat, psi::Psi<T>& psi, Real* eigenvalue_in, MPI_Comm& comm) override;
31+
void diag_pool(hamilt::MatrixBlock<T>& h_mat, hamilt::MatrixBlock<T>& s_mat, psi::Psi<T>& psi, Real* eigenvalue_in, MPI_Comm& comm);
3232
#endif
3333

3434
private:

source/module_hsolver/hsolver_lcao.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ void HSolverLCAO<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T>* hm, psi::Psi<T>&
169169
DiagoLapack<T> la;
170170
la.diag(hm, psi, eigenvalue);
171171
#else
172-
ModuleBase::WARNING_QUIT("HSolverLCAO::solve", "This method of DiagH is not supported!");
172+
ModuleBase::WARNING_QUIT("HSolverLCAO::solve", "This type of eigensolver is not supported!");
173173
#endif
174174
}
175175
else
@@ -379,7 +379,7 @@ void HSolverLCAO<T, Device>::parakSolve(hamilt::Hamilt<T>* pHamilt,
379379
else
380380
{
381381
ModuleBase::WARNING_QUIT("HSolverLCAO::solve",
382-
"This method of DiagH for k-parallelism diagnolization is not supported!");
382+
"This type of eigensolver for k-parallelism diagnolization is not supported!");
383383
}
384384
}
385385
MPI_Barrier(MPI_COMM_WORLD);

source/module_hsolver/hsolver_pw_sdft.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt<std::complex<double>>* pHamilt,
3333
const std::initializer_list<std::string> _methods = {"cg", "dav", "dav_subspace", "bpcg"};
3434
if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods))
3535
{
36-
ModuleBase::WARNING_QUIT("HSolverPW::solve", "This method of DiagH is not supported!");
36+
ModuleBase::WARNING_QUIT("HSolverPW::solve", "This type of eigensolver is not supported!");
3737
}
3838

3939
// part of KSDFT to get KS orbitals

source/module_hsolver/test/diago_lapack_test.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,15 @@ class DiagoLapackPrepare
119119
: nlocal(nlocal), nbands(nbands), nb2d(nb2d), sparsity(sparsity), hfname(hfname),
120120
sfname(sfname), solutionfname(solutionfname)
121121
{
122-
dh = new hsolver::DiagoLapack<T>;
122+
// dh = new hsolver::DiagoLapack<T>;
123123
}
124124

125125
int nlocal, nbands, nb2d, sparsity;
126126
std::string sfname, hfname, solutionfname;
127127
std::vector<T> h;
128128
std::vector<T> s;
129129
HamiltTEST<T> hmtest;
130-
hsolver::DiagH<T>* dh = nullptr;
130+
// hsolver::DiagH<T>* dh = nullptr;
131131
psi::Psi<T> psi;
132132
std::vector<double> e_solver;
133133
std::vector<double> e_lapack;
@@ -200,9 +200,11 @@ class DiagoLapackPrepare
200200

201201
for (int i = 0; i < REPEATRUN; i++)
202202
{
203-
dh->diag(&hmtest, psi, e_solver.data());
203+
hsolver::DiagoLapack<T> dh;
204+
dh.diag(&hmtest, psi, e_solver.data());
205+
// dh->diag(&hmtest, psi, e_solver.data());
204206
}
205-
delete dh;
207+
// delete dh;
206208
}
207209

208210
void read_SOLUTION()

source/module_hsolver/test/diago_lcao_cusolver_test.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,10 @@ class DiagoPrepare
7474
MPI_Comm_rank(MPI_COMM_WORLD, &myrank);
7575

7676
if (ks_solver == "scalapack_gvx")
77-
dh = new hsolver::DiagoScalapack<T>;
77+
;//dh = new hsolver::DiagoScalapack<T>;
7878
#ifdef __CUDA
7979
else if (ks_solver == "cusolver")
80-
dh = new hsolver::DiagoCusolver<T>;
80+
;//dh = new hsolver::DiagoCusolver<T>;
8181
#endif
8282
else
8383
{
@@ -96,7 +96,7 @@ class DiagoPrepare
9696
std::vector<T> s;
9797
std::vector<T> h_local;
9898
std::vector<T> s_local;
99-
hsolver::DiagH<T>* dh = 0;
99+
// hsolver::DiagH<T>* dh = 0;
100100
psi::Psi<T> psi;
101101
std::vector<double> e_solver;
102102
std::vector<double> e_lapack;
@@ -222,11 +222,23 @@ class DiagoPrepare
222222
{
223223
hmtest.h_local = this->h_local;
224224
hmtest.s_local = this->s_local;
225-
dh->diag(&hmtest, psi, e_solver.data());
225+
if (ks_solver == "scalapack_gvx")
226+
{
227+
hsolver::DiagoScalapack<T> dh;
228+
dh.diag(&hmtest, psi, e_solver.data());
229+
}
230+
#ifdef __CUDA
231+
else if (ks_solver == "cusolver")
232+
{
233+
hsolver::DiagoCusolver<T> dh;
234+
dh.diag(&hmtest, psi, e_solver.data());
235+
}
236+
#endif
237+
// dh->diag(&hmtest, psi, e_solver.data());
226238
}
227239
endtime = MPI_Wtime();
228240
hsolver_time = (endtime - starttime) / REPEATRUN;
229-
delete dh;
241+
// delete dh;
230242
}
231243

232244
void diago_lapack()

0 commit comments

Comments
 (0)