Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions source/module_hsolver/diago_bpcg.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace hsolver {
* @tparam Device The device used for calculations (e.g., cpu or gpu).
*/
template <typename T = std::complex<double>, typename Device = base_device::DEVICE_CPU>
class DiagoBPCG
class DiagoBPCG : public DiagH<T, Device>
{
private:
// Note GetTypeReal<T>::type will
Expand Down Expand Up @@ -56,15 +56,15 @@ class DiagoBPCG
void init_iter(const psi::Psi<T, Device> &psi_in);

/**
* @brief Diagonalize the Hamiltonian using the BPCG method.
* @brief Diagonalize the Hamiltonian using the CG method.
*
* This function is called by the HsolverPW::solve() function.
* This function is an override function for the CG method. It is called by the HsolverPW::solve() function.
*
* @param phm_in A pointer to the hamilt::Hamilt object representing the Hamiltonian operator.
* @param psi The input wavefunction psi matrix with [dim: n_basis x n_band, column major].
* @param eigenvalue_in Pointer to the eigen array with [dim: n_band, column major].
*/
void diag(hamilt::Hamilt<T, Device> *phm_in, psi::Psi<T, Device> &psi, Real *eigenvalue_in);
void diag(hamilt::Hamilt<T, Device> *phm_in, psi::Psi<T, Device> &psi, Real *eigenvalue_in) override;


private:
Expand Down
6 changes: 3 additions & 3 deletions source/module_hsolver/diago_cg.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace hsolver {

template <typename T, typename Device = base_device::DEVICE_CPU>
class DiagoCG final
class DiagoCG final : public DiagH<T, Device>
{
// private: accessibility within class is private by default
// Note GetTypeReal<T>::type will
Expand All @@ -36,11 +36,11 @@ class DiagoCG final
const int& pw_diag_nmax,
const int& nproc_in_pool);

~DiagoCG();
~DiagoCG() override;

// virtual void init(){};
// refactor hpsi_info
// this is the diag() function for CG method
// this is the override function diag() for CG method
void diag(const Func& hpsi_func, const Func& spsi_func, ct::Tensor& psi, ct::Tensor& eigen, const ct::Tensor& prec = {});

private:
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/diago_dav_subspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace hsolver
{

template <typename T = std::complex<double>, typename Device = base_device::DEVICE_CPU>
class Diago_DavSubspace
class Diago_DavSubspace : public DiagH<T, Device>
{
private:
// Note GetTypeReal<T>::type will
Expand All @@ -29,7 +29,7 @@ class Diago_DavSubspace
const bool& need_subspace_in,
const diag_comm_info& diag_comm_in);

~Diago_DavSubspace();
virtual ~Diago_DavSubspace() override;

// See diago_david.h for information on the HPsiFunc function type
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/diago_david.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace hsolver
{

template <typename T = std::complex<double>, typename Device = base_device::DEVICE_CPU>
class DiagoDavid
class DiagoDavid : public DiagH<T, Device>
{
private:
// Note GetTypeReal<T>::type will
Expand All @@ -25,7 +25,7 @@ class DiagoDavid
const bool use_paw_in,
const diag_comm_info& diag_comm_in);

~DiagoDavid();
virtual ~DiagoDavid() override;


// declare type of matrix-blockvector functions.
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
const std::initializer_list<std::string> _methods = {"cg", "dav", "dav_subspace", "bpcg"};
if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods))
{
ModuleBase::WARNING_QUIT("HSolverPW::solve", "This type of eigensolver is not supported!");
ModuleBase::WARNING_QUIT("HSolverPW::solve", "This method of DiagH is not supported!");
}

// prepare for the precondition of diagonalization
Expand Down
26 changes: 25 additions & 1 deletion source/module_hsolver/test/test_hsolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,28 @@ class TestHSolver : public ::testing::Test

// double test_diagethr_d = hs_d.set_diagethr(0.0, 0, 0, 0.0);
// EXPECT_EQ(test_diagethr_d, 0.0);
// }
// }
namespace hsolver
{
template <typename T, typename Device = base_device::DEVICE_CPU>
class DiagH_mock : public DiagH<T, Device>
{
private:
using Real = typename GetTypeReal<T>::type;

public:
DiagH_mock()
{
}
~DiagH_mock()
{
}

void diag(hamilt::Hamilt<T, Device>* phm_in, psi::Psi<T, Device>& psi, Real* eigenvalue_in)
{
return;
}
};
template class DiagH_mock<std::complex<float>>;
template class DiagH_mock<std::complex<double>>;
}
Loading