Skip to content

Commit 330d6a3

Browse files
committed
update pyabacus
1 parent 8b2ec37 commit 330d6a3

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

python/pyabacus/src/hsolver/py_diago_cg.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,13 @@ class PyDiagoCG
114114
);
115115
}
116116

117-
void diag(
118-
std::function<py::array_t<std::complex<double>>(py::array_t<std::complex<double>>)> mm_op,
119-
int diag_ndim,
120-
double tol,
121-
bool need_subspace,
122-
bool scf_type,
123-
int nproc_in_pool = 1
117+
void diag(std::function<py::array_t<std::complex<double>>(py::array_t<std::complex<double>>)> mm_op,
118+
int diag_ndim,
119+
double tol,
120+
const std::vector<double>& diag_ethr,
121+
bool need_subspace,
122+
bool scf_type,
123+
int nproc_in_pool = 1
124124
) {
125125
const std::string basis_type = "pw";
126126
const std::string calculation = scf_type ? "scf" : "nscf";
@@ -171,7 +171,7 @@ class PyDiagoCG
171171
nproc_in_pool
172172
);
173173

174-
cg->diag(hpsi_func, spsi_func, *psi, *eig, *prec);
174+
cg->diag(hpsi_func, spsi_func, *psi, *eig, diag_ethr, *prec);
175175
}
176176

177177
private:

python/pyabacus/src/pyabacus/hsolver/_hsolver.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,8 @@ def cg(
244244
if init_v.ndim == 2:
245245
init_v = init_v.T
246246
init_v = init_v.flatten().astype(np.complex128, order='C')
247+
248+
diag_ethr = [tol] * num_eigs
247249

248250
_diago_obj_cg = diago_cg(dim, num_eigs)
249251
_diago_obj_cg.set_psi(init_v)
@@ -255,6 +257,7 @@ def cg(
255257
mvv_op,
256258
max_iter,
257259
tol,
260+
diag_ethr
258261
need_subspace,
259262
scf_type,
260263
nproc_in_pool

0 commit comments

Comments
 (0)