Skip to content

Commit 9a0926b

Browse files
committed
Fix Numba pos solve condition number calculation
1 parent 87ba8ed commit 9a0926b

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -877,7 +877,7 @@ def _posv(
877877
overwrite_b: bool,
878878
check_finite: bool,
879879
transposed: bool,
880-
) -> tuple[np.ndarray, int]:
880+
) -> tuple[np.ndarray, np.ndarray, int]:
881881
"""
882882
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
883883
"""
@@ -894,7 +894,8 @@ def posv_impl(
894894
check_finite: bool,
895895
transposed: bool,
896896
) -> Callable[
897-
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], tuple[np.ndarray, int]
897+
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool],
898+
tuple[np.ndarray, np.ndarray, int],
898899
]:
899900
ensure_lapack()
900901
_check_scipy_linalg_matrix(A, "solve")
@@ -955,8 +956,9 @@ def impl(
955956
)
956957

957958
if B_is_1d:
958-
return B_copy[..., 0], int_ptr_to_val(INFO)
959-
return B_copy, int_ptr_to_val(INFO)
959+
B_copy = B_copy[..., 0]
960+
961+
return A_copy, B_copy, int_ptr_to_val(INFO)
960962

961963
return impl
962964

@@ -1057,10 +1059,12 @@ def impl(
10571059
) -> np.ndarray:
10581060
_solve_check_input_shapes(A, B)
10591061

1060-
x, info = _posv(A, B, lower, overwrite_a, overwrite_b, check_finite, transposed)
1062+
lu, x, info = _posv(
1063+
A, B, lower, overwrite_a, overwrite_b, check_finite, transposed
1064+
)
10611065
_solve_check(A.shape[-1], info)
10621066

1063-
rcond, info = _pocon(x, _xlange(A))
1067+
rcond, info = _pocon(lu, _xlange(A))
10641068
_solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond)
10651069

10661070
return x

0 commit comments

Comments
 (0)