@@ -401,16 +401,39 @@ def lu_solve(a, b, trans, overwrite_a, overwrite_b):
401401 assert_allclose (x , x_sp )
402402
403403
404+ solve_test_cases = [
405+ ("gen" , False , False ),
406+ ("gen" , False , True ),
407+ ("sym" , False , False ),
408+ ("sym" , True , False ),
409+ ("sym" , True , True ),
410+ ("pos" , False , False ),
411+ ("pos" , True , False ),
412+ ("pos" , True , True ),
413+ ]
414+ solve_test_ids = [
415+ f'{ assume_a } _{ "lower" if lower else "upper" } _{ "A^T" if transposed else "A" } '
416+ for assume_a , lower , transposed in solve_test_cases
417+ ]
418+
419+
404420@pytest .mark .parametrize (
405421 "b_shape" ,
406422 [(5 , 1 ), (5 , 5 ), (5 ,)],
407423 ids = ["b_col_vec" , "b_matrix" , "b_vec" ],
408424)
409- @pytest .mark .parametrize ("assume_a" , ["gen" , "sym" , "pos" ], ids = str )
425+ @pytest .mark .parametrize (
426+ "assume_a, lower, transposed" , solve_test_cases , ids = solve_test_ids
427+ )
410428@pytest .mark .filterwarnings (
411429 'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
412430)
413- def test_solve (b_shape : tuple [int ], assume_a : Literal ["gen" , "sym" , "pos" ]):
431+ def test_solve (
432+ b_shape : tuple [int ],
433+ assume_a : Literal ["gen" , "sym" , "pos" ],
434+ lower : bool ,
435+ transposed : bool ,
436+ ):
414437 A = pt .matrix ("A" , dtype = floatX )
415438 b = pt .tensor ("b" , shape = b_shape , dtype = floatX )
416439
@@ -424,10 +447,17 @@ def A_func(x):
424447 x = (x + x .T ) / 2
425448 return x
426449
450+ def T (x ):
451+ if transposed :
452+ return x .T
453+ return x
454+
427455 X = pt .linalg .solve (
428456 A_func (A ),
429457 b ,
430458 assume_a = assume_a ,
459+ lower = lower ,
460+ transposed = transposed ,
431461 b_ndim = len (b_shape ),
432462 )
433463 f = pytensor .function (
@@ -459,13 +489,18 @@ def A_func(x):
459489
460490 # Test that the result is numerically correct. Need to use the unmodified copy
461491 np .testing .assert_allclose (
462- A_func (A_val_copy ) @ X_np , b_val_copy , atol = ATOL , rtol = RTOL
492+ T ( A_func (A_val_copy ) ) @ X_np , b_val_copy , atol = ATOL , rtol = RTOL
463493 )
464494
465495 # See the note in tensor/test_slinalg.py::test_solve_correctness for details about the setup here
466496 utt .verify_grad (
467497 lambda A , b : pt .linalg .solve (
468- A_func (A ), b , lower = False , assume_a = assume_a , b_ndim = len (b_shape )
498+ A_func (A ),
499+ b ,
500+ lower = lower ,
501+ transposed = transposed ,
502+ assume_a = assume_a ,
503+ b_ndim = len (b_shape ),
469504 ),
470505 [A_val_copy , b_val_copy ],
471506 mode = "NUMBA" ,
0 commit comments