Skip to content

Commit c61bf79

Browse files
Tests pass
1 parent 23d2737 commit c61bf79

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def block_diag(*arrs):
333333
return block_diag
334334

335335

336-
def _xlamch():
336+
def _xlamch(kind: str = "E"):
337337
"""
338338
Placeholder for getting machine precision; used by linalg.solve. Not used by pytensor to numbify graphs.
339339
"""
@@ -830,8 +830,10 @@ def impl(
830830
transposed: bool,
831831
) -> np.ndarray:
832832
_solve_check_input_shapes(A, B)
833+
833834
x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)
834835
_solve_check(A.shape[-1], info)
836+
835837
rcond, info = _sycon(A, ipiv, _xlange(A, order="I"))
836838
_solve_check(A.shape[-1], info, True, rcond)
837839

@@ -1024,8 +1026,10 @@ def impl(
10241026
transposed: bool,
10251027
) -> np.ndarray:
10261028
_solve_check_input_shapes(A, B)
1029+
10271030
x, info = _posv(A, B, lower, overwrite_a, overwrite_b, check_finite, transposed)
10281031
_solve_check(A.shape[-1], info)
1032+
10291033
rcond, info = _pocon(x, _xlange(A))
10301034
_solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond)
10311035

tests/link/numba/test_slinalg.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,12 @@ def test_numba_Cholesky_raise_on(on_error):
182182

183183

184184
@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
185-
@pytest.mark.parametrize("trans", [True, False], ids=["trans=True", "trans=False"])
186-
def test_numba_Cholesky_grad(lower, trans):
185+
def test_numba_Cholesky_grad(lower):
187186
rng = np.random.default_rng(utt.fetch_seed())
188187
L = rng.normal(size=(5, 5)).astype(floatX)
189188
X = L @ L.T
190189

191-
chol_op = partial(pt.linalg.cholesky, lower=lower, trans=trans)
190+
chol_op = partial(pt.linalg.cholesky, lower=lower)
192191
utt.verify_grad(chol_op, [X], mode="NUMBA")
193192

194193

@@ -359,10 +358,9 @@ def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]):
359358

360359
def A_func(x):
361360
if assume_a == "pos":
362-
x = x.T @ x
361+
x = x @ x.T
363362
elif assume_a == "sym":
364-
x = (x.T + x) / 2
365-
363+
x = (x + x.T) / 2
366364
return x
367365

368366
X = pt.linalg.solve(
@@ -376,13 +374,15 @@ def A_func(x):
376374
)
377375
op = f.maker.fgraph.outputs[0].owner.op
378376

379-
compare_numba_and_py(f.maker.fgraph, inputs=[A_func(A_val.copy()), b_val.copy()])
377+
compare_numba_and_py(
378+
([A, b], [X]), inputs=[A_val.copy(), b_val.copy()], inplace=True
379+
)
380380

381381
# Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first.
382382
A_val_copy = A_val.copy()
383383
b_val_copy = b_val.copy()
384384

385-
X_np = f(A_func(A_val), b_val)
385+
X_np = f(A_val, b_val)
386386

387387
# overwrite_b is preferred when both inputs can be destroyed
388388
assert op.destroy_map == {0: [1]}

0 commit comments

Comments
 (0)