Skip to content

Commit f330a9f

Browse files
committed
Fix Numba pos solve condition number calculation
1 parent f736430 commit f330a9f

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,7 @@ def _posv(
874874
overwrite_b: bool,
875875
check_finite: bool,
876876
transposed: bool,
877-
) -> tuple[np.ndarray, int]:
877+
) -> tuple[np.ndarray, np.ndarray, int]:
878878
"""
879879
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
880880
"""
@@ -891,7 +891,8 @@ def posv_impl(
891891
check_finite: bool,
892892
transposed: bool,
893893
) -> Callable[
894-
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], tuple[np.ndarray, int]
894+
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool],
895+
tuple[np.ndarray, np.ndarray, int],
895896
]:
896897
ensure_lapack()
897898
_check_scipy_linalg_matrix(A, "solve")
@@ -952,8 +953,9 @@ def impl(
952953
)
953954

954955
if B_is_1d:
955-
return B_copy[..., 0], int_ptr_to_val(INFO)
956-
return B_copy, int_ptr_to_val(INFO)
956+
B_copy = B_copy[..., 0]
957+
958+
return A_copy, B_copy, int_ptr_to_val(INFO)
957959

958960
return impl
959961

@@ -1054,10 +1056,12 @@ def impl(
10541056
) -> np.ndarray:
10551057
_solve_check_input_shapes(A, B)
10561058

1057-
x, info = _posv(A, B, lower, overwrite_a, overwrite_b, check_finite, transposed)
1059+
lu, x, info = _posv(
1060+
A, B, lower, overwrite_a, overwrite_b, check_finite, transposed
1061+
)
10581062
_solve_check(A.shape[-1], info)
10591063

1060-
rcond, info = _pocon(x, _xlange(A))
1064+
rcond, info = _pocon(lu, _xlange(A))
10611065
_solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond)
10621066

10631067
return x

0 commit comments

Comments
 (0)