@@ -42,7 +42,6 @@ def transpose_func(x, trans):
4242@pytest .mark .filterwarnings (
4343 'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
4444)
45- @pytest .mark .parametrize ("overwrite_b" , [True , False ])
4645def test_solve_triangular (
4746 b_func , b_size , lower , trans , unit_diag , complex , overwrite_b
4847):
@@ -58,7 +57,7 @@ def test_solve_triangular(
5857 b = b_func ("b" , dtype = dtype )
5958
6059 X = pt .linalg .solve_triangular (
61- A , b , lower = lower , trans = trans , unit_diagonal = unit_diag , overwrite_b = overwrite_b
60+ A , b , lower = lower , trans = trans , unit_diagonal = unit_diag
6261 )
6362 f = pytensor .function ([A , b ], X , mode = "NUMBA" )
6463
@@ -322,9 +321,7 @@ def lu_solve(a, b, trans, overwrite_a, overwrite_b):
322321@pytest .mark .filterwarnings (
323322 'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
324323)
325- @pytest .mark .parametrize ("overwrite_a" , [True , False ])
326- @pytest .mark .parametrize ("overwrite_b" , [True , False ])
327- def test_solve (b_func , b_size , assume_a , transposed , overwrite_a , overwrite_b ):
324+ def test_solve (b_func , b_size , assume_a , transposed ):
328325 A = pt .matrix ("A" , dtype = floatX )
329326 b = b_func ("b" , dtype = floatX )
330327
@@ -333,24 +330,39 @@ def test_solve(b_func, b_size, assume_a, transposed, overwrite_a, overwrite_b):
333330 b ,
334331 lower = False ,
335332 assume_a = assume_a ,
336- overwrite_a = overwrite_a ,
337- overwrite_b = overwrite_b ,
338333 transposed = transposed ,
339334 b_ndim = len (b_size ),
340335 )
341- f = pytensor .function ([A , b ], X , mode = "NUMBA" )
336+ f = pytensor .function (
337+ [pytensor .In (A , mutable = True ), pytensor .In (b , mutable = True )], X , mode = "NUMBA"
338+ )
339+
340+ A_val = np .random .normal (size = (5 , 5 )).astype (floatX )
342341
343- A = np .random .normal (size = (5 , 5 )).astype (floatX )
344342 if assume_a in ["sym" , "pos" ]:
345- A = A @ A .conj ().T
346- b = np .random .normal (size = b_size )
347- b = b .astype (floatX )
343+ A_val = A_val @ A_val .conj ().T
344+ A_val = np .asfortranarray (A_val )
345+
346+ b_val = np .random .normal (size = b_size )
347+ b_val = b_val .astype (floatX )
348+ b_val = np .asfortranarray (b_val )
349+
350+ A_val_copy = A_val .copy ()
351+ b_val_copy = b_val .copy ()
352+
353+ X_np = f (A_val , b_val )
354+ op = f .maker .fgraph .outputs [0 ].owner .op
355+ # overwrite_b is preferred when both inputs can be destroyed
356+ assert op .destroy_map == {0 : [1 ]}
348357
349- X_np = f (A , b )
350358 np .testing .assert_allclose (
351- transpose_func (A , transposed ) @ X_np , b , atol = ATOL , rtol = RTOL
359+ transpose_func (A_val_copy , transposed ) @ X_np , b_val_copy , atol = ATOL , rtol = RTOL
352360 )
353361
362+ # Confirm input was destroyed
363+ assert (A_val == A_val_copy ).all () == (op .destroy_map .get (0 , None ) != [0 ])
364+ assert (b_val == b_val_copy ).all () == (op .destroy_map .get (0 , None ) != [1 ])
365+
354366
355367@pytest .mark .parametrize (
356368 "b_func, b_size" ,
0 commit comments