Skip to content

Commit 3b2e5ad

Browse files
authored
Fix: Update function calls in pyabacus to align with new function signature in hpsi_func (#5176)
* fix some typos in `_hsolver.py` * fix some bugs caused by #5134
1 parent 4f2e453 commit 3b2e5ad

File tree

3 files changed

+26
-21
lines changed

3 files changed

+26
-21
lines changed

python/pyabacus/src/py_diago_dav_subspace.hpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,14 @@ class PyDiagoDavSubspace
110110
bool scf_type,
111111
hsolver::diag_comm_info comm_info
112112
) {
113-
auto hpsi_func = [mm_op] (std::complex<double> *hpsi_out,
114-
std::complex<double> *psi_in, const int nband_in,
115-
const int nbasis_in, const int band_index1,
116-
const int band_index2)
117-
{
113+
auto hpsi_func = [mm_op] (
114+
std::complex<double> *psi_in,
115+
std::complex<double> *hpsi_out,
116+
const int nband_in,
117+
const int nbasis_in,
118+
const int band_index1,
119+
const int band_index2
120+
) {
118121
// Note: numpy's py::array_t is row-major, but
119122
// our raw pointer-array is column-major
120123
py::array_t<std::complex<double>, py::array::f_style> psi({nbasis_in, band_index2 - band_index1 + 1});

python/pyabacus/src/py_diago_david.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ class PyDiagoDavid
109109
hsolver::diag_comm_info comm_info
110110
) {
111111
auto hpsi_func = [mm_op] (
112-
std::complex<double> *hpsi_out,
113-
std::complex<double> *psi_in,
112+
std::complex<double> *psi_in,
113+
std::complex<double> *hpsi_out,
114114
const int nband_in,
115115
const int nbasis_in,
116116
const int band_index1,

python/pyabacus/src/pyabacus/hsolver/_hsolver.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def rank(self) -> int: ...
1616
def nproc(self) -> int: ...
1717

1818
def dav_subspace(
19-
mm_op: Callable[[NDArray[np.complex128]], NDArray[np.complex128]],
19+
mvv_op: Callable[[NDArray[np.complex128]], NDArray[np.complex128]],
2020
init_v: NDArray[np.complex128],
2121
dim: int,
2222
num_eigs: int,
@@ -32,9 +32,10 @@ def dav_subspace(
3232
3333
Parameters
3434
----------
35-
mm_op : Callable[[NDArray[np.complex128]], NDArray[np.complex128]],
36-
The operator to be diagonalized, which is a function that takes a matrix as input
37-
and returns a matrix mv_op(X) = H * X as output.
35+
mvv_op : Callable[[NDArray[np.complex128]], NDArray[np.complex128]],
36+
The operator to be diagonalized, which is a function that takes a set of
37+
vectors X = [x1, ..., xN] as input and returns a matrix(vector block)
38+
mvv_op(X) = H * X ([Hx1, ..., HxN]) as output.
3839
init_v : NDArray[np.complex128]
3940
The initial guess for the eigenvectors.
4041
dim : int
@@ -68,8 +69,8 @@ def dav_subspace(
6869
v : NDArray[np.complex128]
6970
The eigenvectors corresponding to the eigenvalues.
7071
"""
71-
if not callable(mm_op):
72-
raise TypeError("mm_op must be a callable object.")
72+
if not callable(mvv_op):
73+
raise TypeError("mvv_op must be a callable object.")
7374

7475
if is_occupied is None:
7576
is_occupied = [True] * num_eigs
@@ -86,7 +87,7 @@ def dav_subspace(
8687
assert dav_ndim * num_eigs < dim * comm_info.nproc, "dav_ndim * num_eigs must be less than dim * comm_info.nproc."
8788

8889
_ = _diago_obj_dav_subspace.diag(
89-
mm_op,
90+
mvv_op,
9091
pre_condition,
9192
dav_ndim,
9293
tol,
@@ -103,7 +104,7 @@ def dav_subspace(
103104
return e, v
104105

105106
def davidson(
106-
mm_op: Callable[[NDArray[np.complex128]], NDArray[np.complex128]],
107+
mvv_op: Callable[[NDArray[np.complex128]], NDArray[np.complex128]],
107108
init_v: NDArray[np.complex128],
108109
dim: int,
109110
num_eigs: int,
@@ -119,9 +120,10 @@ def davidson(
119120
120121
Parameters
121122
----------
122-
mm_op : Callable[[NDArray[np.complex128]], NDArray[np.complex128]],
123-
The operator to be diagonalized, which is a function that takes a matrix as input
124-
and returns a matrix mv_op(X) = H * X as output.
123+
mvv_op : Callable[[NDArray[np.complex128]], NDArray[np.complex128]],
124+
The operator to be diagonalized, which is a function that takes a set of
125+
vectors X = [x1, ..., xN] as input and returns a matrix(vector block)
126+
mvv_op(X) = H * X ([Hx1, ..., HxN]) as output.
125127
init_v : NDArray[np.complex128]
126128
The initial guess for the eigenvectors.
127129
dim : int
@@ -146,8 +148,8 @@ def davidson(
146148
v : NDArray[np.complex128]
147149
The eigenvectors corresponding to the eigenvalues.
148150
"""
149-
if not callable(mm_op):
150-
raise TypeError("mm_op must be a callable object.")
151+
if not callable(mvv_op):
152+
raise TypeError("mvv_op must be a callable object.")
151153

152154
if init_v.ndim != 1 or init_v.dtype != np.complex128:
153155
init_v = init_v.flatten().astype(np.complex128, order='C')
@@ -159,7 +161,7 @@ def davidson(
159161
comm_info = hsolver.diag_comm_info(0, 1)
160162

161163
_ = _diago_obj_dav_subspace.diag(
162-
mm_op,
164+
mvv_op,
163165
pre_condition,
164166
dav_ndim,
165167
tol,

0 commit comments

Comments
 (0)