@@ -358,16 +358,39 @@ def lu_solve(a, b, trans, overwrite_a, overwrite_b):
358358 assert_allclose (x , x_sp )
359359
360360
361+ solve_test_cases = [
362+ ("gen" , False , False ),
363+ ("gen" , False , True ),
364+ ("sym" , False , False ),
365+ ("sym" , True , False ),
366+ ("sym" , True , True ),
367+ ("pos" , False , False ),
368+ ("pos" , True , False ),
369+ ("pos" , True , True ),
370+ ]
371+ solve_test_ids = [
372+ f'{ assume_a } _{ "lower" if lower else "upper" } _{ "A^T" if transposed else "A" } '
373+ for assume_a , lower , transposed in solve_test_cases
374+ ]
375+
376+
361377@pytest .mark .parametrize (
362378 "b_shape" ,
363379 [(5 , 1 ), (5 , 5 ), (5 ,)],
364380 ids = ["b_col_vec" , "b_matrix" , "b_vec" ],
365381)
366- @pytest .mark .parametrize ("assume_a" , ["gen" , "sym" , "pos" ], ids = str )
382+ @pytest .mark .parametrize (
383+ "assume_a, lower, transposed" , solve_test_cases , ids = solve_test_ids
384+ )
367385@pytest .mark .filterwarnings (
368386 'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
369387)
370- def test_solve (b_shape : tuple [int ], assume_a : Literal ["gen" , "sym" , "pos" ]):
388+ def test_solve (
389+ b_shape : tuple [int ],
390+ assume_a : Literal ["gen" , "sym" , "pos" ],
391+ lower : bool ,
392+ transposed : bool ,
393+ ):
371394 A = pt .matrix ("A" , dtype = floatX )
372395 b = pt .tensor ("b" , shape = b_shape , dtype = floatX )
373396
@@ -381,10 +404,17 @@ def A_func(x):
381404 x = (x + x .T ) / 2
382405 return x
383406
407+ def T (x ):
408+ if transposed :
409+ return x .T
410+ return x
411+
384412 X = pt .linalg .solve (
385413 A_func (A ),
386414 b ,
387415 assume_a = assume_a ,
416+ lower = lower ,
417+ transposed = transposed ,
388418 b_ndim = len (b_shape ),
389419 )
390420 f = pytensor .function (
@@ -416,13 +446,18 @@ def A_func(x):
416446
417447 # Test that the result is numerically correct. Need to use the unmodified copy
418448 np .testing .assert_allclose (
419- A_func (A_val_copy ) @ X_np , b_val_copy , atol = ATOL , rtol = RTOL
449+ T ( A_func (A_val_copy ) ) @ X_np , b_val_copy , atol = ATOL , rtol = RTOL
420450 )
421451
422452 # See the note in tensor/test_slinalg.py::test_solve_correctness for details about the setup here
423453 utt .verify_grad (
424454 lambda A , b : pt .linalg .solve (
425- A_func (A ), b , lower = False , assume_a = assume_a , b_ndim = len (b_shape )
455+ A_func (A ),
456+ b ,
457+ lower = lower ,
458+ transposed = transposed ,
459+ assume_a = assume_a ,
460+ b_ndim = len (b_shape ),
426461 ),
427462 [A_val_copy , b_val_copy ],
428463 mode = "NUMBA" ,
0 commit comments