@@ -182,13 +182,12 @@ def test_numba_Cholesky_raise_on(on_error):
182182
183183
184184@pytest .mark .parametrize ("lower" , [True , False ], ids = ["lower=True" , "lower=False" ])
185- @pytest .mark .parametrize ("trans" , [True , False ], ids = ["trans=True" , "trans=False" ])
186- def test_numba_Cholesky_grad (lower , trans ):
185+ def test_numba_Cholesky_grad (lower ):
187186 rng = np .random .default_rng (utt .fetch_seed ())
188187 L = rng .normal (size = (5 , 5 )).astype (floatX )
189188 X = L @ L .T
190189
191- chol_op = partial (pt .linalg .cholesky , lower = lower , trans = trans )
190+ chol_op = partial (pt .linalg .cholesky , lower = lower )
192191 utt .verify_grad (chol_op , [X ], mode = "NUMBA" )
193192
194193
@@ -359,10 +358,9 @@ def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]):
359358
360359 def A_func (x ):
361360 if assume_a == "pos" :
362- x = x . T @ x
361+ x = x @ x . T
363362 elif assume_a == "sym" :
364- x = (x .T + x ) / 2
365-
363+ x = (x + x .T ) / 2
366364 return x
367365
368366 X = pt .linalg .solve (
@@ -376,13 +374,15 @@ def A_func(x):
376374 )
377375 op = f .maker .fgraph .outputs [0 ].owner .op
378376
379- compare_numba_and_py (f .maker .fgraph , inputs = [A_func (A_val .copy ()), b_val .copy ()])
377+ compare_numba_and_py (
378+ ([A , b ], [X ]), inputs = [A_val .copy (), b_val .copy ()], inplace = True
379+ )
380380
381381 # Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first.
382382 A_val_copy = A_val .copy ()
383383 b_val_copy = b_val .copy ()
384384
385- X_np = f (A_func ( A_val ) , b_val )
385+ X_np = f (A_val , b_val )
386386
387387 # overwrite_b is preferred when both inputs can be destroyed
388388 assert op .destroy_map == {0 : [1 ]}
0 commit comments