diff --git a/source/module_hsolver/diago_bpcg.h b/source/module_hsolver/diago_bpcg.h index 940a3aeb5a..71040960c3 100644 --- a/source/module_hsolver/diago_bpcg.h +++ b/source/module_hsolver/diago_bpcg.h @@ -24,7 +24,7 @@ namespace hsolver { * @tparam Device The device used for calculations (e.g., cpu or gpu). */ template , typename Device = base_device::DEVICE_CPU> -class DiagoBPCG +class DiagoBPCG : public DiagH { private: // Note GetTypeReal::type will @@ -56,15 +56,15 @@ class DiagoBPCG void init_iter(const psi::Psi &psi_in); /** - * @brief Diagonalize the Hamiltonian using the BPCG method. + * @brief Diagonalize the Hamiltonian using the CG method. * - * This function is called by the HsolverPW::solve() function. + * This function is an override function for the CG method. It is called by the HsolverPW::solve() function. * * @param phm_in A pointer to the hamilt::Hamilt object representing the Hamiltonian operator. * @param psi The input wavefunction psi matrix with [dim: n_basis x n_band, column major]. * @param eigenvalue_in Pointer to the eigen array with [dim: n_band, column major]. */ - void diag(hamilt::Hamilt *phm_in, psi::Psi &psi, Real *eigenvalue_in); + void diag(hamilt::Hamilt *phm_in, psi::Psi &psi, Real *eigenvalue_in) override; private: diff --git a/source/module_hsolver/diago_cg.h b/source/module_hsolver/diago_cg.h index 1b64cfe4e8..bb106cd8c1 100644 --- a/source/module_hsolver/diago_cg.h +++ b/source/module_hsolver/diago_cg.h @@ -13,7 +13,7 @@ namespace hsolver { template -class DiagoCG final +class DiagoCG final : public DiagH { // private: accessibility within class is private by default // Note GetTypeReal::type will @@ -36,11 +36,11 @@ class DiagoCG final const int& pw_diag_nmax, const int& nproc_in_pool); - ~DiagoCG(); + ~DiagoCG() override; // virtual void init(){}; // refactor hpsi_info - // this is the diag() function for CG method + // this is the override function diag() for CG method void diag(const Func& hpsi_func, const Func& spsi_func, ct::Tensor& psi, ct::Tensor& eigen, const ct::Tensor& prec = {}); private: diff --git a/source/module_hsolver/diago_dav_subspace.h b/source/module_hsolver/diago_dav_subspace.h index 06d66aa8d2..8e26a08b45 100644 --- a/source/module_hsolver/diago_dav_subspace.h +++ b/source/module_hsolver/diago_dav_subspace.h @@ -11,7 +11,7 @@ namespace hsolver { template , typename Device = base_device::DEVICE_CPU> -class Diago_DavSubspace +class Diago_DavSubspace : public DiagH { private: // Note GetTypeReal::type will @@ -29,7 +29,7 @@ class Diago_DavSubspace const bool& need_subspace_in, const diag_comm_info& diag_comm_in); - ~Diago_DavSubspace(); + virtual ~Diago_DavSubspace() override; // See diago_david.h for information on the HPsiFunc function type using HPsiFunc = std::function; diff --git a/source/module_hsolver/diago_david.h b/source/module_hsolver/diago_david.h index 576e36eed4..2a59eb2134 100644 --- a/source/module_hsolver/diago_david.h +++ b/source/module_hsolver/diago_david.h @@ -8,7 +8,7 @@ namespace hsolver { template , typename Device = base_device::DEVICE_CPU> -class DiagoDavid +class DiagoDavid : public DiagH { private: // Note GetTypeReal::type will @@ -25,7 +25,7 @@ class DiagoDavid const bool use_paw_in, const diag_comm_info& diag_comm_in); - ~DiagoDavid(); + virtual ~DiagoDavid() override; // declare type of matrix-blockvector functions. diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 5ef07216c3..7e4bd3349d 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -228,7 +228,7 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, const std::initializer_list _methods = {"cg", "dav", "dav_subspace", "bpcg"}; if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods)) { - ModuleBase::WARNING_QUIT("HSolverPW::solve", "This type of eigensolver is not supported!"); + ModuleBase::WARNING_QUIT("HSolverPW::solve", "This method of DiagH is not supported!"); } // prepare for the precondition of diagonalization diff --git a/source/module_hsolver/test/test_hsolver.cpp b/source/module_hsolver/test/test_hsolver.cpp index 4f5adb96ac..80ea102f97 100644 --- a/source/module_hsolver/test/test_hsolver.cpp +++ b/source/module_hsolver/test/test_hsolver.cpp @@ -83,4 +83,28 @@ class TestHSolver : public ::testing::Test // double test_diagethr_d = hs_d.set_diagethr(0.0, 0, 0, 0.0); // EXPECT_EQ(test_diagethr_d, 0.0); -// } \ No newline at end of file +// } +namespace hsolver +{ +template +class DiagH_mock : public DiagH +{ + private: + using Real = typename GetTypeReal::type; + + public: + DiagH_mock() + { + } + ~DiagH_mock() + { + } + + void diag(hamilt::Hamilt* phm_in, psi::Psi& psi, Real* eigenvalue_in) + { + return; + } + }; + template class DiagH_mock>; + template class DiagH_mock>; +} \ No newline at end of file