Skip to content

Commit 7087829

Browse files
haozhihanFisherd99
authored andcommitted
Refactor: add smooth threshold support for cg method (deepmodeling#5713)
* add smooth threshold for cg * update pyabacus * fix test build nbug * fix pyabacus bug * fix pyabacus bug * fix pyabacus bug * fix sdft bug * fix bug * fix bug * fix bug
1 parent 6c23342 commit 7087829

File tree

13 files changed

+59
-51
lines changed

13 files changed

+59
-51
lines changed

python/pyabacus/src/hsolver/py_diago_cg.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,13 @@ class PyDiagoCG
114114
);
115115
}
116116

117-
void diag(
118-
std::function<py::array_t<std::complex<double>>(py::array_t<std::complex<double>>)> mm_op,
119-
int diag_ndim,
120-
double tol,
121-
bool need_subspace,
122-
bool scf_type,
123-
int nproc_in_pool = 1
117+
void diag(std::function<py::array_t<std::complex<double>>(py::array_t<std::complex<double>>)> mm_op,
118+
int diag_ndim,
119+
double tol,
120+
const std::vector<double>& diag_ethr,
121+
bool need_subspace,
122+
bool scf_type,
123+
int nproc_in_pool = 1
124124
) {
125125
const std::string basis_type = "pw";
126126
const std::string calculation = scf_type ? "scf" : "nscf";
@@ -171,7 +171,7 @@ class PyDiagoCG
171171
nproc_in_pool
172172
);
173173

174-
cg->diag(hpsi_func, spsi_func, *psi, *eig, *prec);
174+
cg->diag(hpsi_func, spsi_func, *psi, *eig, diag_ethr, *prec);
175175
}
176176

177177
private:

python/pyabacus/src/hsolver/py_hsolver.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,9 @@ void bind_hsolver(py::module& m)
158158
for invoking this class is a function defined in _hsolver.py,
159159
which uses this class to perform the calculations.
160160
)pbdoc")
161-
.def("diag", &py_hsolver::PyDiagoCG::diag, R"pbdoc(
161+
.def("diag",
162+
&py_hsolver::PyDiagoCG::diag,
163+
R"pbdoc(
162164
Diagonalize the linear operator using the Conjugate Gradient Method.
163165
164166
Parameters
@@ -179,6 +181,7 @@ void bind_hsolver(py::module& m)
179181
"mm_op"_a,
180182
"max_iter"_a,
181183
"tol"_a,
184+
"diag_ethr"_a,
182185
"need_subspace"_a,
183186
"scf_type"_a,
184187
"nproc_in_pool"_a)

python/pyabacus/src/pyabacus/hsolver/_hsolver.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def cg(
195195
precondition: NDArray[np.float64],
196196
tol: float = 1e-2,
197197
max_iter: int = 1000,
198+
diag_ethr: Union[List[float], None] = None,
198199
need_subspace: bool = False,
199200
scf_type: bool = False,
200201
nproc_in_pool: int = 1
@@ -244,6 +245,9 @@ def cg(
244245
if init_v.ndim == 2:
245246
init_v = init_v.T
246247
init_v = init_v.flatten().astype(np.complex128, order='C')
248+
249+
if diag_ethr is None:
250+
diag_ethr = [tol] * num_eigs
247251

248252
_diago_obj_cg = diago_cg(dim, num_eigs)
249253
_diago_obj_cg.set_psi(init_v)
@@ -255,6 +259,7 @@ def cg(
255259
mvv_op,
256260
max_iter,
257261
tol,
262+
diag_ethr,
258263
need_subspace,
259264
scf_type,
260265
nproc_in_pool

source/module_hsolver/diago_cg.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@ DiagoCG<T, Device>::~DiagoCG()
5454
}
5555

5656
template <typename T, typename Device>
57-
void DiagoCG<T, Device>::diag_mock(const ct::Tensor& prec_in, ct::Tensor& psi, ct::Tensor& eigen)
57+
void DiagoCG<T, Device>::diag_mock(const ct::Tensor& prec_in,
58+
ct::Tensor& psi,
59+
ct::Tensor& eigen,
60+
const std::vector<double>& ethr_band)
5861
{
5962
ModuleBase::TITLE("DiagoCG", "diag_once");
6063
ModuleBase::timer::tick("DiagoCG", "diag_once");
@@ -153,6 +156,7 @@ void DiagoCG<T, Device>::diag_mock(const ct::Tensor& prec_in, ct::Tensor& psi, c
153156
converged = this->update_psi(pphi,
154157
cg,
155158
scg, // const Tensor&
159+
ethr_band[m],
156160
cg_norm,
157161
theta,
158162
eigen_pack[m], // Real&
@@ -392,6 +396,7 @@ template <typename T, typename Device>
392396
bool DiagoCG<T, Device>::update_psi(const ct::Tensor& pphi,
393397
const ct::Tensor& cg,
394398
const ct::Tensor& scg,
399+
const double& ethreshold,
395400
Real& cg_norm,
396401
Real& theta,
397402
Real& eigen,
@@ -441,7 +446,7 @@ bool DiagoCG<T, Device>::update_psi(const ct::Tensor& pphi,
441446
cg.data<T>(),
442447
sint_norm);
443448

444-
if (std::abs(eigen - e0) < pw_diag_thr_)
449+
if (std::abs(eigen - e0) < ethreshold)
445450
{
446451
// ModuleBase::timer::tick("DiagoCG","update");
447452
return true;
@@ -582,6 +587,7 @@ void DiagoCG<T, Device>::diag(const Func& hpsi_func,
582587
const Func& spsi_func,
583588
ct::Tensor& psi,
584589
ct::Tensor& eigen,
590+
const std::vector<double>& ethr_band,
585591
const ct::Tensor& prec)
586592
{
587593
/// record the times of trying iterative diagonalization
@@ -603,7 +609,7 @@ void DiagoCG<T, Device>::diag(const Func& hpsi_func,
603609

604610
++ntry;
605611
avg_iter_ += 1.0;
606-
this->diag_mock(prec, psi_temp, eigen);
612+
this->diag_mock(prec, psi_temp, eigen, ethr_band);
607613
} while (this->test_exit_cond(ntry, this->notconv_));
608614

609615
if (this->notconv_ > std::max(5, this->n_band_ / 4))

source/module_hsolver/diago_cg.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,12 @@ class DiagoCG final
4040
// virtual void init(){};
4141
// refactor hpsi_info
4242
// this is the diag() function for CG method
43-
void diag(const Func& hpsi_func, const Func& spsi_func, ct::Tensor& psi, ct::Tensor& eigen, const ct::Tensor& prec = {});
43+
void diag(const Func& hpsi_func,
44+
const Func& spsi_func,
45+
ct::Tensor& psi,
46+
ct::Tensor& eigen,
47+
const std::vector<double>& ethr_band,
48+
const ct::Tensor& prec = {});
4449

4550
private:
4651
Device * ctx_ = {};
@@ -103,6 +108,7 @@ class DiagoCG final
103108
const ct::Tensor& pphi,
104109
const ct::Tensor& cg,
105110
const ct::Tensor& scg,
111+
const double& ethreshold,
106112
Real &cg_norm,
107113
Real &theta,
108114
Real &eigen,
@@ -113,7 +119,10 @@ class DiagoCG final
113119
void schmit_orth(const int& m, const ct::Tensor& psi, const ct::Tensor& sphi, ct::Tensor& phi_m);
114120

115121
// used in diag() for template replace Hamilt with Hamilt_PW
116-
void diag_mock(const ct::Tensor& prec, ct::Tensor& psi, ct::Tensor& eigen);
122+
void diag_mock(const ct::Tensor& prec,
123+
ct::Tensor& psi,
124+
ct::Tensor& eigen,
125+
const std::vector<double>& ethr_band);
117126

118127
bool test_exit_cond(const int& ntry, const int& notconv) const;
119128

source/module_hsolver/hsolver_pw.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
279279
// prepare for the precondition of diagonalization
280280
std::vector<Real> precondition(psi.get_nbasis(), 0.0);
281281
std::vector<Real> eigenvalues(this->wfc_basis->nks * psi.get_nbands(), 0.0);
282-
ethr_band.resize(psi.get_nbands(), DiagoIterAssist<T, Device>::PW_DIAG_THR);
282+
ethr_band.resize(psi.get_nbands(), this->diag_thr);
283283

284284
/// Loop over k points for solve Hamiltonian to charge density
285285
for (int ik = 0; ik < this->wfc_basis->nks; ++ik)
@@ -467,7 +467,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
467467
.to_device<ct_Device>()
468468
.slice({0}, {psi.get_current_nbas()});
469469

470-
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor);
470+
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, this->ethr_band, prec_tensor);
471471
// TODO: Double check tensormap's potential problem
472472
ct::TensorMap(psi.get_pointer(), psi_tensor, {psi.get_nbands(), psi.get_nbasis()}).sync(psi_tensor);
473473
}

source/module_hsolver/hsolver_pw.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,12 @@ class HSolverPW
8484
int rank_in_pool = 0;
8585
int nproc_in_pool = 1;
8686

87+
std::vector<double> ethr_band;
88+
8789
private:
8890
/// @brief calculate the threshold for iterative-diagonalization for each band
8991
void cal_ethr_band(const double& wk, const double* wg, const double& ethr, std::vector<double>& ethrs);
9092

91-
std::vector<double> ethr_band;
92-
9393
#ifdef USE_PAW
9494
void paw_func_in_kloop(const int ik,
9595
const double tpiba);

source/module_hsolver/hsolver_pw_sdft.cpp

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,33 +29,11 @@ void HSolverPW_SDFT<T, Device>::solve(const UnitCell& ucell,
2929
const int nbands = psi.get_nbands();
3030
const int nks = psi.get_nk();
3131

32-
//---------------------------------------------------------------------------------------------------------------
33-
//---------------------------------for psi init guess!!!!--------------------------------------------------------
34-
//---------------------------------------------------------------------------------------------------------------
35-
// if (!PARAM.inp.psi_initializer && !this->initialed_psi && this->basis_type == "pw")
36-
// {
37-
// for (int ik = 0; ik < nks; ++ik)
38-
// {
39-
// /// update H(k) for each k point
40-
// pHamilt->updateHk(ik);
41-
42-
// if (nbands > 0 && GlobalV::MY_STOGROUP == 0)
43-
// {
44-
// /// update psi pointer for each k point
45-
// psi.fix_k(ik);
46-
47-
// /// for psi init guess!!!!
48-
// hamilt::diago_PAO_in_pw_k2(this->ctx, ik, psi, this->wfc_basis, this->pwf, pHamilt);
49-
// }
50-
// }
51-
// }
52-
//---------------------------------------------------------------------------------------------------------------
53-
//---------------------------------------------------------------------------------------------------------------
54-
//---------------------------------------------------------------------------------------------------------------
55-
5632
// prepare for the precondition of diagonalization
5733
std::vector<double> precondition(psi.get_nbasis(), 0.0);
5834

35+
this->ethr_band.resize(psi.get_nbands(), this->diag_thr);
36+
5937
// report if the specified diagonalization method is not supported
6038
const std::initializer_list<std::string> _methods = {"cg", "dav", "dav_subspace", "bpcg"};
6139
if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods))

source/module_hsolver/test/diago_cg_float_test.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,9 @@ class DiagoCGPrepare
193193
ct::DataType::DT_FLOAT,
194194
ct::DeviceType::CpuDevice,
195195
ct::TensorShape({static_cast<int>(psi_local.get_current_nbas())})).slice({0}, {psi_local.get_current_nbas()});
196-
197-
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor);
196+
197+
std::vector<double> ethr_band(nband, 1e-5);
198+
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor);
198199
// TODO: Double check tensormap's potential problem
199200
ct::TensorMap(psi_local.get_pointer(), psi_tensor, {psi_local.get_nbands(), psi_local.get_nbasis()}).sync(psi_tensor);
200201
/**************************************************************/

source/module_hsolver/test/diago_cg_real_test.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,9 @@ class DiagoCGPrepare
196196
ct::DataType::DT_DOUBLE,
197197
ct::DeviceType::CpuDevice,
198198
ct::TensorShape({static_cast<int>(psi_local.get_current_nbas())})).slice({0}, {psi_local.get_current_nbas()});
199-
200-
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor);
199+
200+
std::vector<double> ethr_band(nband, 1e-5);
201+
cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, ethr_band, prec_tensor);
201202
// TODO: Double check tensormap's potential problem
202203
ct::TensorMap(psi_local.get_pointer(), psi_tensor, {psi_local.get_nbands(), psi_local.get_nbasis()}).sync(psi_tensor);
203204
/**************************************************************/

0 commit comments

Comments
 (0)