Skip to content

Commit 3844f04

Browse files
Test that solve inputs are destroyed in numba mode
1 parent 64e57ac commit 3844f04

File tree

2 files changed

+28
-16
lines changed

2 files changed

+28
-16
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -565,10 +565,10 @@ def impl(
565565
norm = _xlange(A, order=order)
566566

567567
N = A.shape[1]
568-
LU, IPIV, INFO = _getrf(A)
568+
LU, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
569569
_solve_check(N, INFO)
570570

571-
X, INFO = _getrs(LU, B, IPIV, transposed)
571+
X, INFO = _getrs(LU, B, IPIV, transposed, overwrite_b=overwrite_b)
572572
_solve_check(N, INFO)
573573
RCOND, INFO = _xgecon(LU, norm, "1")
574574

tests/link/numba/test_slinalg.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def transpose_func(x, trans):
4242
@pytest.mark.filterwarnings(
4343
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
4444
)
45-
@pytest.mark.parametrize("overwrite_b", [True, False])
4645
def test_solve_triangular(
4746
b_func, b_size, lower, trans, unit_diag, complex, overwrite_b
4847
):
@@ -58,7 +57,7 @@ def test_solve_triangular(
5857
b = b_func("b", dtype=dtype)
5958

6059
X = pt.linalg.solve_triangular(
61-
A, b, lower=lower, trans=trans, unit_diagonal=unit_diag, overwrite_b=overwrite_b
60+
A, b, lower=lower, trans=trans, unit_diagonal=unit_diag
6261
)
6362
f = pytensor.function([A, b], X, mode="NUMBA")
6463

@@ -322,9 +321,7 @@ def lu_solve(a, b, trans, overwrite_a, overwrite_b):
322321
@pytest.mark.filterwarnings(
323322
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
324323
)
325-
@pytest.mark.parametrize("overwrite_a", [True, False])
326-
@pytest.mark.parametrize("overwrite_b", [True, False])
327-
def test_solve(b_func, b_size, assume_a, transposed, overwrite_a, overwrite_b):
324+
def test_solve(b_func, b_size, assume_a, transposed):
328325
A = pt.matrix("A", dtype=floatX)
329326
b = b_func("b", dtype=floatX)
330327

@@ -333,24 +330,39 @@ def test_solve(b_func, b_size, assume_a, transposed, overwrite_a, overwrite_b):
333330
b,
334331
lower=False,
335332
assume_a=assume_a,
336-
overwrite_a=overwrite_a,
337-
overwrite_b=overwrite_b,
338333
transposed=transposed,
339334
b_ndim=len(b_size),
340335
)
341-
f = pytensor.function([A, b], X, mode="NUMBA")
336+
f = pytensor.function(
337+
[pytensor.In(A, mutable=True), pytensor.In(b, mutable=True)], X, mode="NUMBA"
338+
)
339+
340+
A_val = np.random.normal(size=(5, 5)).astype(floatX)
342341

343-
A = np.random.normal(size=(5, 5)).astype(floatX)
344342
if assume_a in ["sym", "pos"]:
345-
A = A @ A.conj().T
346-
b = np.random.normal(size=b_size)
347-
b = b.astype(floatX)
343+
A_val = A_val @ A_val.conj().T
344+
A_val = np.asfortranarray(A_val)
345+
346+
b_val = np.random.normal(size=b_size)
347+
b_val = b_val.astype(floatX)
348+
b_val = np.asfortranarray(b_val)
349+
350+
A_val_copy = A_val.copy()
351+
b_val_copy = b_val.copy()
352+
353+
X_np = f(A_val, b_val)
354+
op = f.maker.fgraph.outputs[0].owner.op
355+
# overwrite_b is preferred when both inputs can be destroyed
356+
assert op.destroy_map == {0: [1]}
348357

349-
X_np = f(A, b)
350358
np.testing.assert_allclose(
351-
transpose_func(A, transposed) @ X_np, b, atol=ATOL, rtol=RTOL
359+
transpose_func(A_val_copy, transposed) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL
352360
)
353361

362+
# Confirm input was destroyed
363+
assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
364+
assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])
365+
354366

355367
@pytest.mark.parametrize(
356368
"b_func, b_size",

0 commit comments

Comments
 (0)