Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions python/pyabacus/src/hsolver/py_diago_cg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,13 @@ class PyDiagoCG
);
}

void diag(
std::function<py::array_t<std::complex<double>>(py::array_t<std::complex<double>>)> mm_op,
int diag_ndim,
double tol,
bool need_subspace,
bool scf_type,
int nproc_in_pool = 1
void diag(std::function<py::array_t<std::complex<double>>(py::array_t<std::complex<double>>)> mm_op,
int diag_ndim,
double tol,
const std::vector<double>& diag_ethr,
bool need_subspace,
bool scf_type,
int nproc_in_pool = 1
) {
const std::string basis_type = "pw";
const std::string calculation = scf_type ? "scf" : "nscf";
Expand Down Expand Up @@ -171,7 +171,7 @@ class PyDiagoCG
nproc_in_pool
);

cg->diag(hpsi_func, spsi_func, *psi, *eig, *prec);
cg->diag(hpsi_func, spsi_func, *psi, *eig, diag_ethr, *prec);
}

private:
Expand Down
5 changes: 4 additions & 1 deletion python/pyabacus/src/hsolver/py_hsolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ void bind_hsolver(py::module& m)
for invoking this class is a function defined in _hsolver.py,
which uses this class to perform the calculations.
)pbdoc")
.def("diag", &py_hsolver::PyDiagoCG::diag, R"pbdoc(
.def("diag",
&py_hsolver::PyDiagoCG::diag,
R"pbdoc(
Diagonalize the linear operator using the Conjugate Gradient Method.

Parameters
Expand All @@ -179,6 +181,7 @@ void bind_hsolver(py::module& m)
"mm_op"_a,
"max_iter"_a,
"tol"_a,
"diag_ethr"_a,
"need_subspace"_a,
"scf_type"_a,
"nproc_in_pool"_a)
Expand Down
5 changes: 5 additions & 0 deletions python/pyabacus/src/pyabacus/hsolver/_hsolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def cg(
precondition: NDArray[np.float64],
tol: float = 1e-2,
max_iter: int = 1000,
diag_ethr: Union[List[float], None] = None,
need_subspace: bool = False,
scf_type: bool = False,
nproc_in_pool: int = 1
Expand Down Expand Up @@ -244,6 +245,9 @@ def cg(
if init_v.ndim == 2:
init_v = init_v.T
init_v = init_v.flatten().astype(np.complex128, order='C')

if diag_ethr is None:
diag_ethr = [tol] * num_eigs

_diago_obj_cg = diago_cg(dim, num_eigs)
_diago_obj_cg.set_psi(init_v)
Expand All @@ -255,6 +259,7 @@ def cg(
mvv_op,
max_iter,
tol,
diag_ethr,
need_subspace,
scf_type,
nproc_in_pool
Expand Down
12 changes: 9 additions & 3 deletions source/module_hsolver/diago_cg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ DiagoCG<T, Device>::~DiagoCG()
}

template <typename T, typename Device>
void DiagoCG<T, Device>::diag_mock(const ct::Tensor& prec_in, ct::Tensor& psi, ct::Tensor& eigen)
void DiagoCG<T, Device>::diag_mock(const ct::Tensor& prec_in,
ct::Tensor& psi,
ct::Tensor& eigen,
const std::vector<double>& ethr_band)
{
ModuleBase::TITLE("DiagoCG", "diag_once");
ModuleBase::timer::tick("DiagoCG", "diag_once");
Expand Down Expand Up @@ -153,6 +156,7 @@ void DiagoCG<T, Device>::diag_mock(const ct::Tensor& prec_in, ct::Tensor& psi, c
converged = this->update_psi(pphi,
cg,
scg, // const Tensor&
ethr_band[m],
cg_norm,
theta,
eigen_pack[m], // Real&
Expand Down Expand Up @@ -392,6 +396,7 @@ template <typename T, typename Device>
bool DiagoCG<T, Device>::update_psi(const ct::Tensor& pphi,
const ct::Tensor& cg,
const ct::Tensor& scg,
const double& ethreshold,
Real& cg_norm,
Real& theta,
Real& eigen,
Expand Down Expand Up @@ -441,7 +446,7 @@ bool DiagoCG<T, Device>::update_psi(const ct::Tensor& pphi,
cg.data<T>(),
sint_norm);

if (std::abs(eigen - e0) < pw_diag_thr_)
if (std::abs(eigen - e0) < ethreshold)
{
// ModuleBase::timer::tick("DiagoCG","update");
return true;
Expand Down Expand Up @@ -582,6 +587,7 @@ void DiagoCG<T, Device>::diag(const Func& hpsi_func,
const Func& spsi_func,
ct::Tensor& psi,
ct::Tensor& eigen,
const std::vector<double>& ethr_band,
const ct::Tensor& prec)
{
/// record the times of trying iterative diagonalization
Expand All @@ -603,7 +609,7 @@ void DiagoCG<T, Device>::diag(const Func& hpsi_func,

++ntry;
avg_iter_ += 1.0;
this->diag_mock(prec, psi_temp, eigen);
this->diag_mock(prec, psi_temp, eigen, ethr_band);
} while (this->test_exit_cond(ntry, this->notconv_));

if (this->notconv_ > std::max(5, this->n_band_ / 4))
Expand Down
13 changes: 11 additions & 2 deletions source/module_hsolver/diago_cg.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ class DiagoCG final
// virtual void init(){};
// refactor hpsi_info
// this is the diag() function for CG method
void diag(const Func& hpsi_func, const Func& spsi_func, ct::Tensor& psi, ct::Tensor& eigen, const ct::Tensor& prec = {});
void diag(const Func& hpsi_func,
const Func& spsi_func,
ct::Tensor& psi,
ct::Tensor& eigen,
const std::vector<double>& ethr_band,
const ct::Tensor& prec = {});

private:
Device * ctx_ = {};
Expand Down Expand Up @@ -103,6 +108,7 @@ class DiagoCG final
const ct::Tensor& pphi,
const ct::Tensor& cg,
const ct::Tensor& scg,
const double& ethreshold,
Real &cg_norm,
Real &theta,
Real &eigen,
Expand All @@ -113,7 +119,10 @@ class DiagoCG final
void schmit_orth(const int& m, const ct::Tensor& psi, const ct::Tensor& sphi, ct::Tensor& phi_m);

// used in diag() for template replace Hamilt with Hamilt_PW
void diag_mock(const ct::Tensor& prec, ct::Tensor& psi, ct::Tensor& eigen);
void diag_mock(const ct::Tensor& prec,
ct::Tensor& psi,
ct::Tensor& eigen,
const std::vector<double>& ethr_band);

bool test_exit_cond(const int& ntry, const int& notconv) const;

Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
// prepare for the precondition of diagonalization
std::vector<Real> precondition(psi.get_nbasis(), 0.0);
std::vector<Real> eigenvalues(this->wfc_basis->nks * psi.get_nbands(), 0.0);
ethr_band.resize(psi.get_nbands(), DiagoIterAssist<T, Device>::PW_DIAG_THR);
ethr_band.resize(psi.get_nbands(), this->diag_thr);

/// Loop over k points for solve Hamiltonian to charge density
for (int ik = 0; ik < this->wfc_basis->nks; ++ik)
Expand Down Expand Up @@ -467,7 +467,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
.to_device<ct_Device>()
.slice({0}, {psi.get_current_nbas()});

cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor);
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, this->ethr_band, prec_tensor);
// TODO: Double check tensormap's potential problem
ct::TensorMap(psi.get_pointer(), psi_tensor, {psi.get_nbands(), psi.get_nbasis()}).sync(psi_tensor);
}
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/hsolver_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ class HSolverPW
int rank_in_pool = 0;
int nproc_in_pool = 1;

std::vector<double> ethr_band;

private:
/// @brief calculate the threshold for iterative-diagonalization for each band
void cal_ethr_band(const double& wk, const double* wg, const double& ethr, std::vector<double>& ethrs);

std::vector<double> ethr_band;

#ifdef USE_PAW
void paw_func_in_kloop(const int ik,
const double tpiba);
Expand Down
26 changes: 2 additions & 24 deletions source/module_hsolver/hsolver_pw_sdft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,11 @@ void HSolverPW_SDFT<T, Device>::solve(const UnitCell& ucell,
const int nbands = psi.get_nbands();
const int nks = psi.get_nk();

//---------------------------------------------------------------------------------------------------------------
//---------------------------------for psi init guess!!!!--------------------------------------------------------
//---------------------------------------------------------------------------------------------------------------
// if (!PARAM.inp.psi_initializer && !this->initialed_psi && this->basis_type == "pw")
// {
// for (int ik = 0; ik < nks; ++ik)
// {
// /// update H(k) for each k point
// pHamilt->updateHk(ik);

// if (nbands > 0 && GlobalV::MY_STOGROUP == 0)
// {
// /// update psi pointer for each k point
// psi.fix_k(ik);

// /// for psi init guess!!!!
// hamilt::diago_PAO_in_pw_k2(this->ctx, ik, psi, this->wfc_basis, this->pwf, pHamilt);
// }
// }
// }
//---------------------------------------------------------------------------------------------------------------
//---------------------------------------------------------------------------------------------------------------
//---------------------------------------------------------------------------------------------------------------

// prepare for the precondition of diagonalization
std::vector<double> precondition(psi.get_nbasis(), 0.0);

this->ethr_band.resize(psi.get_nbands(), this->diag_thr);

// report if the specified diagonalization method is not supported
const std::initializer_list<std::string> _methods = {"cg", "dav", "dav_subspace", "bpcg"};
if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods))
Expand Down
5 changes: 3 additions & 2 deletions source/module_hsolver/test/diago_cg_float_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,9 @@ class DiagoCGPrepare
ct::DataType::DT_FLOAT,
ct::DeviceType::CpuDevice,
ct::TensorShape({static_cast<int>(psi_local.get_current_nbas())})).slice({0}, {psi_local.get_current_nbas()});

cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor);

std::vector<double> ethr_band(nband, 1e-5);
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor);
// TODO: Double check tensormap's potential problem
ct::TensorMap(psi_local.get_pointer(), psi_tensor, {psi_local.get_nbands(), psi_local.get_nbasis()}).sync(psi_tensor);
/**************************************************************/
Expand Down
5 changes: 3 additions & 2 deletions source/module_hsolver/test/diago_cg_real_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,9 @@ class DiagoCGPrepare
ct::DataType::DT_DOUBLE,
ct::DeviceType::CpuDevice,
ct::TensorShape({static_cast<int>(psi_local.get_current_nbas())})).slice({0}, {psi_local.get_current_nbas()});

cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor);

std::vector<double> ethr_band(nband, 1e-5);
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor);
// TODO: Double check tensormap's potential problem
ct::TensorMap(psi_local.get_pointer(), psi_tensor, {psi_local.get_nbands(), psi_local.get_nbasis()}).sync(psi_tensor);
/**************************************************************/
Expand Down
5 changes: 3 additions & 2 deletions source/module_hsolver/test/diago_cg_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,9 @@ class DiagoCGPrepare
ct::DataType::DT_DOUBLE,
ct::DeviceType::CpuDevice,
ct::TensorShape({static_cast<int>(psi_local.get_current_nbas())})).slice({0}, {psi_local.get_current_nbas()});

cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor);

std::vector<double> ethr_band(nband, 1e-5);
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor);
// TODO: Double check tensormap's potential problem
ct::TensorMap(psi_local.get_pointer(), psi_tensor, {psi_local.get_nbands(), psi_local.get_nbasis()}).sync(psi_tensor);
/**************************************************************/
Expand Down
1 change: 1 addition & 0 deletions source/module_hsolver/test/hsolver_pw_sup.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ void DiagoCG<T, Device>::diag(const Func& hpsi_func,
const Func& spsi_func,
ct::Tensor& psi,
ct::Tensor& eigen,
const std::vector<double>& ethr_band,
const ct::Tensor& prec) {
auto n_bands = psi.shape().dim_size(0);
auto n_basis = psi.shape().dim_size(1);
Expand Down
9 changes: 6 additions & 3 deletions source/module_lr/hsolver_lrtd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,12 @@ namespace LR
std::vector<Real<T>> precondition_(precondition); //since TensorMap does not support const pointer
auto precon_tensor = ct::TensorMap(precondition_.data(), ct::DataTypeToEnum<Real<T>>::value, ct::DeviceType::CpuDevice, ct::TensorShape({ dim }));
auto hpsi_func = [&hm](const ct::Tensor& psi_in, ct::Tensor& hpsi) {hm.hPsi(psi_in.data<T>(), hpsi.data<T>(), psi_in.shape().dim_size(0) /*nbasis_local*/, 1/*band-by-band*/);};
auto spsi_func = [&hm](const ct::Tensor& psi_in, ct::Tensor& spsi)
{ std::memcpy(spsi.data<T>(), psi_in.data<T>(), sizeof(T) * psi_in.NumElements()); };
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, precon_tensor);
auto spsi_func = [&hm](const ct::Tensor& psi_in, ct::Tensor& spsi) {
std::memcpy(spsi.data<T>(), psi_in.data<T>(), sizeof(T) * psi_in.NumElements());
};

std::vector<double> ethr_band(nband, diag_ethr);
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, precon_tensor);
}
else { throw std::runtime_error("HSolverLR::solve: method not implemented"); }
}
Expand Down
Loading