Skip to content

Commit 176ab32

Browse files
committed
Fix Numba pos solve condition number calculation
1 parent d772187 commit 176ab32

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,7 @@ def _posv(
864864
overwrite_b: bool,
865865
check_finite: bool,
866866
transposed: bool,
867-
) -> tuple[np.ndarray, int]:
867+
) -> tuple[np.ndarray, np.ndarray, int]:
868868
"""
869869
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
870870
"""
@@ -881,7 +881,8 @@ def posv_impl(
881881
check_finite: bool,
882882
transposed: bool,
883883
) -> Callable[
884-
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], tuple[np.ndarray, int]
884+
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool],
885+
tuple[np.ndarray, np.ndarray, int],
885886
]:
886887
ensure_lapack()
887888
_check_scipy_linalg_matrix(A, "solve")
@@ -939,8 +940,9 @@ def impl(
939940
)
940941

941942
if B_is_1d:
942-
return B_copy[..., 0], int_ptr_to_val(INFO)
943-
return B_copy, int_ptr_to_val(INFO)
943+
B_copy = B_copy[..., 0]
944+
945+
return A_copy, B_copy, int_ptr_to_val(INFO)
944946

945947
return impl
946948

@@ -1041,10 +1043,12 @@ def impl(
10411043
) -> np.ndarray:
10421044
_solve_check_input_shapes(A, B)
10431045

1044-
x, info = _posv(A, B, lower, overwrite_a, overwrite_b, check_finite, transposed)
1046+
lu, x, info = _posv(
1047+
A, B, lower, overwrite_a, overwrite_b, check_finite, transposed
1048+
)
10451049
_solve_check(A.shape[-1], info)
10461050

1047-
rcond, info = _pocon(x, _xlange(A))
1051+
rcond, info = _pocon(lu, _xlange(A))
10481052
_solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond)
10491053

10501054
return x

0 commit comments

Comments
 (0)