Skip to content

Commit 762d012

Browse files
committed
fix pyabacus
1 parent 681b778 commit 762d012

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

python/pyabacus/src/hsolver/py_diago_david.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ class PyDiagoDavid
104104
std::vector<double> precond_vec,
105105
int dav_ndim,
106106
double tol,
107+
std::vector<double> diag_ethr,
107108
int max_iter,
108109
bool use_paw,
109110
hsolver::diag_comm_info comm_info
@@ -146,7 +147,7 @@ class PyDiagoDavid
146147
comm_info
147148
);
148149

149-
return obj->diag(hpsi_func, spsi_func, nbasis, psi, eigenvalue, tol, max_iter);
150+
return obj->diag(hpsi_func, spsi_func, nbasis, psi, eigenvalue, diag_ethr, max_iter);
150151
}
151152

152153
private:

python/pyabacus/src/pyabacus/hsolver/_hsolver.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def davidson(
117117
precondition: NDArray[np.float64],
118118
dav_ndim: int = 2,
119119
tol: float = 1e-2,
120+
diag_ethr: Union[List[float], None] = None,
120121
max_iter: int = 1000,
121122
use_paw: bool = False,
122123
# scf_type: bool = False
@@ -164,12 +165,16 @@ def davidson(
164165
_diago_obj_david.init_eigenvalue()
165166

166167
comm_info = diag_comm_info(0, 1)
168+
169+
if diag_ethr is None:
170+
diag_ethr = [tol] * num_eigs
167171

168172
_ = _diago_obj_david.diag(
169173
mvv_op,
170174
precondition,
171175
dav_ndim,
172176
tol,
177+
diag_ethr,
173178
max_iter,
174179
use_paw,
175180
comm_info

0 commit comments

Comments
 (0)