Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions source/module_hsolver/diago_cusolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
namespace hsolver
{

// DiagoCusolver class, derived from DiagH, for diagonalization using CUSOLVER
// DiagoCusolver class for diagonalization using CUSOLVER
template <typename T>
class DiagoCusolver : public DiagH<T>
class DiagoCusolver
{
private:
// Real is the real part of the complex type T
Expand All @@ -24,7 +24,7 @@ class DiagoCusolver : public DiagH<T>
~DiagoCusolver();

// Override the diag function for CUSOLVER diagonalization
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in) override;
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in);

// Static variable to keep track of the decomposition state
static int DecomposedState;
Expand Down
6 changes: 3 additions & 3 deletions source/module_hsolver/diago_cusolvermp.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace hsolver
{
// DiagoCusolverMP class, derived from DiagH, for diagonalization using CUSOLVERMP
template <typename T>
class DiagoCusolverMP : public DiagH<T>
class DiagoCusolverMP
{
private:
using Real = typename GetTypeReal<T>::type;
Expand All @@ -19,8 +19,8 @@ class DiagoCusolverMP : public DiagH<T>
DiagoCusolverMP()
{
}
// Override the diag function for CUSOLVERMP diagonalization
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in) override;
// the diag function for CUSOLVERMP diagonalization
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in);
};
} // namespace hsolver
#endif // __CUSOLVERMP
Expand Down
6 changes: 3 additions & 3 deletions source/module_hsolver/diago_elpa.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ namespace hsolver
{

template <typename T>
class DiagoElpa : public DiagH<T>
class DiagoElpa
{
private:
using Real = typename GetTypeReal<T>::type;

public:
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in) override;
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in);
#ifdef __MPI
// diagnolization used in parallel-k case
void diag_pool(hamilt::MatrixBlock<T>& h_mat, hamilt::MatrixBlock<T>& s_mat, psi::Psi<T>& psi, Real* eigenvalue_in, MPI_Comm& comm) override;
void diag_pool(hamilt::MatrixBlock<T>& h_mat, hamilt::MatrixBlock<T>& s_mat, psi::Psi<T>& psi, Real* eigenvalue_in, MPI_Comm& comm);
MPI_Comm setmpicomm(); // set mpi comm;
static int elpa_num_thread; // need to set mpi_comm or not,-1 not,else the number of mpi needed
#endif
Expand Down
6 changes: 3 additions & 3 deletions source/module_hsolver/diago_elpa_native.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ namespace hsolver
{

template <typename T>
class DiagoElpaNative : public DiagH<T>
class DiagoElpaNative
{
private:
using Real = typename GetTypeReal<T>::type;

public:
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in) override;
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in);
#ifdef __MPI
// diagnolization used in parallel-k case
void diag_pool(hamilt::MatrixBlock<T>& h_mat, hamilt::MatrixBlock<T>& s_mat, psi::Psi<T>& psi, Real* eigenvalue_in, MPI_Comm& comm) override;
void diag_pool(hamilt::MatrixBlock<T>& h_mat, hamilt::MatrixBlock<T>& s_mat, psi::Psi<T>& psi, Real* eigenvalue_in, MPI_Comm& comm);
MPI_Comm setmpicomm(); // set mpi comm;
static int elpa_num_thread; // need to set mpi_comm or not,-1 not,else the number of mpi needed
static int lastmpinum; // last using mpi;
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/diago_lapack.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
namespace hsolver
{
template <typename T>
class DiagoLapack : public DiagH<T>
class DiagoLapack
{
private:
using Real = typename GetTypeReal<T>::type;

public:
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in) override;
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in);

void dsygvx_diag(const int ncol,
const int nrow,
Expand Down
6 changes: 3 additions & 3 deletions source/module_hsolver/diago_scalapack.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
namespace hsolver
{
template<typename T>
class DiagoScalapack : public DiagH<T>
class DiagoScalapack
{
private:
using Real = typename GetTypeReal<T>::type;
public:
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in) override;
void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in);
#ifdef __MPI
// diagnolization used in parallel-k case
void diag_pool(hamilt::MatrixBlock<T>& h_mat, hamilt::MatrixBlock<T>& s_mat, psi::Psi<T>& psi, Real* eigenvalue_in, MPI_Comm& comm) override;
void diag_pool(hamilt::MatrixBlock<T>& h_mat, hamilt::MatrixBlock<T>& s_mat, psi::Psi<T>& psi, Real* eigenvalue_in, MPI_Comm& comm);
#endif

private:
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/hsolver_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ void HSolverLCAO<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T>* hm, psi::Psi<T>&
DiagoLapack<T> la;
la.diag(hm, psi, eigenvalue);
#else
ModuleBase::WARNING_QUIT("HSolverLCAO::solve", "This method of DiagH is not supported!");
ModuleBase::WARNING_QUIT("HSolverLCAO::solve", "This type of eigensolver is not supported!");
#endif
}
else
Expand Down Expand Up @@ -379,7 +379,7 @@ void HSolverLCAO<T, Device>::parakSolve(hamilt::Hamilt<T>* pHamilt,
else
{
ModuleBase::WARNING_QUIT("HSolverLCAO::solve",
"This method of DiagH for k-parallelism diagnolization is not supported!");
"This type of eigensolver for k-parallelism diagnolization is not supported!");
}
}
MPI_Barrier(MPI_COMM_WORLD);
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/hsolver_pw_sdft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt<std::complex<double>>* pHamilt,
const std::initializer_list<std::string> _methods = {"cg", "dav", "dav_subspace", "bpcg"};
if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods))
{
ModuleBase::WARNING_QUIT("HSolverPW::solve", "This method of DiagH is not supported!");
ModuleBase::WARNING_QUIT("HSolverPW::solve", "This type of eigensolver is not supported!");
}

// part of KSDFT to get KS orbitals
Expand Down
10 changes: 6 additions & 4 deletions source/module_hsolver/test/diago_lapack_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,15 @@ class DiagoLapackPrepare
: nlocal(nlocal), nbands(nbands), nb2d(nb2d), sparsity(sparsity), hfname(hfname),
sfname(sfname), solutionfname(solutionfname)
{
dh = new hsolver::DiagoLapack<T>;
// dh = new hsolver::DiagoLapack<T>;
}

int nlocal, nbands, nb2d, sparsity;
std::string sfname, hfname, solutionfname;
std::vector<T> h;
std::vector<T> s;
HamiltTEST<T> hmtest;
hsolver::DiagH<T>* dh = nullptr;
// hsolver::DiagH<T>* dh = nullptr;
psi::Psi<T> psi;
std::vector<double> e_solver;
std::vector<double> e_lapack;
Expand Down Expand Up @@ -200,9 +200,11 @@ class DiagoLapackPrepare

for (int i = 0; i < REPEATRUN; i++)
{
dh->diag(&hmtest, psi, e_solver.data());
hsolver::DiagoLapack<T> dh;
dh.diag(&hmtest, psi, e_solver.data());
// dh->diag(&hmtest, psi, e_solver.data());
}
delete dh;
// delete dh;
}

void read_SOLUTION()
Expand Down
22 changes: 17 additions & 5 deletions source/module_hsolver/test/diago_lcao_cusolver_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ class DiagoPrepare
MPI_Comm_rank(MPI_COMM_WORLD, &myrank);

if (ks_solver == "scalapack_gvx")
dh = new hsolver::DiagoScalapack<T>;
;//dh = new hsolver::DiagoScalapack<T>;
#ifdef __CUDA
else if (ks_solver == "cusolver")
dh = new hsolver::DiagoCusolver<T>;
;//dh = new hsolver::DiagoCusolver<T>;
#endif
else
{
Expand All @@ -96,7 +96,7 @@ class DiagoPrepare
std::vector<T> s;
std::vector<T> h_local;
std::vector<T> s_local;
hsolver::DiagH<T>* dh = 0;
// hsolver::DiagH<T>* dh = 0;
psi::Psi<T> psi;
std::vector<double> e_solver;
std::vector<double> e_lapack;
Expand Down Expand Up @@ -222,11 +222,23 @@ class DiagoPrepare
{
hmtest.h_local = this->h_local;
hmtest.s_local = this->s_local;
dh->diag(&hmtest, psi, e_solver.data());
if (ks_solver == "scalapack_gvx")
{
hsolver::DiagoScalapack<T> dh;
dh.diag(&hmtest, psi, e_solver.data());
}
#ifdef __CUDA
else if (ks_solver == "cusolver")
{
hsolver::DiagoCusolver<T> dh;
dh.diag(&hmtest, psi, e_solver.data());
}
#endif
// dh->diag(&hmtest, psi, e_solver.data());
}
endtime = MPI_Wtime();
hsolver_time = (endtime - starttime) / REPEATRUN;
delete dh;
// delete dh;
}

void diago_lapack()
Expand Down
24 changes: 19 additions & 5 deletions source/module_hsolver/test/diago_lcao_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,12 @@ class DiagoPrepare
MPI_Comm_rank(MPI_COMM_WORLD, &myrank);

if (ks_solver == "scalapack_gvx")
dh = new hsolver::DiagoScalapack<T>;
;
// dh = new hsolver::DiagoScalapack<T>;
#ifdef __ELPA
else if (ks_solver == "genelpa")
dh = new hsolver::DiagoElpa<T>;
;
// dh = new hsolver::DiagoElpa<T>;
#endif
else
{
Expand All @@ -93,7 +95,7 @@ class DiagoPrepare
std::vector<T> s;
std::vector<T> h_local;
std::vector<T> s_local;
hsolver::DiagH<T>* dh = 0;
// hsolver::DiagH<T>* dh = 0;
psi::Psi<T> psi;
std::vector<double> e_solver;
std::vector<double> e_lapack;
Expand Down Expand Up @@ -219,11 +221,23 @@ class DiagoPrepare
{
hmtest.h_local = this->h_local;
hmtest.s_local = this->s_local;
dh->diag(&hmtest, psi, e_solver.data());
if (ks_solver == "scalapack_gvx")
{
hsolver::DiagoScalapack<T> dh;
dh.diag(&hmtest, psi, e_solver.data());
}
#ifdef __ELPA
else if (ks_solver == "genelpa")
{
hsolver::DiagoElpa<T> dh;
dh.diag(&hmtest, psi, e_solver.data());
}
#endif
// dh.diag(&hmtest, psi, e_solver.data());
}
endtime = MPI_Wtime();
hsolver_time = (endtime - starttime) / REPEATRUN;
delete dh;
// delete dh;
}

void diago_lapack()
Expand Down
Loading