 from pytensor.tensor.slinalg import (
     Cholesky,
     CholeskySolve,
+    LUSolve,
     Solve,
     SolveBase,
     SolveTriangular,
@@ -703,7 +704,8 @@ def test_lu_factor(permutation_indices):
 
 @pytest.mark.parametrize("b_shape", [(5,), (5, 5)])
 @pytest.mark.parametrize("trans", [True, False])
-def test_lu_solve(b_shape: tuple[int], trans):
+@pytest.mark.parametrize("use_op", [True, False])
+def test_lu_solve(b_shape: tuple[int], trans, use_op):
     def T(x):
         if trans:
             return x.T
@@ -717,7 +719,14 @@ def T(x):
     b_val = rng.normal(size=b_shape).astype(config.floatX)
 
     LU_and_pivots = lu_factor(A)
-    x = lu_solve(LU_and_pivots, b, trans=trans)
+
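+    # use_op=True applies the LUSolve Op directly; use_op=False goes through the lu_solve helper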
+    if use_op:
+        x = LUSolve(b_ndim=len(b_shape), trans=trans, check_finite=False)(
+            *LU_and_pivots, b
+        )
+    else:
+        x = lu_solve(LU_and_pivots, b, trans=trans)
 
     f = pytensor.function([A, b], x)
     x_pt = f(A_val.copy(), b_val.copy())
@@ -735,26 +744,6 @@ def T(x):
     )
     np.testing.assert_allclose(x_pt, x_sp)
 
-    # import jax
-    # import jax.scipy as jsp
-    #
-    # def jax_f(A, b):
-    #     LU_and_pivots = jsp.linalg.lu_factor(A)
-    #     x = jsp.linalg.lu_solve(LU_and_pivots, b, trans=trans)
-    #     return x.sum()
-
-    # jax_res = jax.value_and_grad(jax_f, [0, 1])(A_val, b_val)
-    # g = grad(x.sum(), [A, b])
-    # fg = pytensor.function([A, b], [x.sum(), *g])
-
-    # for a, b in zip(fg(A_val, b_val), [jax_res[0], *jax_res[1]]):
-    #     print(a - b)
-
-    # LU, pivots = pt.tensor('LU', shape=(5, 5)), pt.tensor('pivots', shape=(5,), dtype='int')
-    # x = lu_solve((LU, pivots), b)
-
-    # LU_val, pivots_val = scipy.linalg.lu_factor(A_val)
-
     utt.verify_grad(
         lambda A, b: lu_solve(lu_factor(A), b, trans=trans).sum(),
         pt=[A_val.copy(), b_val.copy()],
@@ -776,15 +765,6 @@ def test_fn(A, b):
         x = lu_solve(lu_and_pivots, b)
         return x.sum()
 
-    # A = pt.tensor("A", shape=(5, 5))
-    # b = pt.tensor("b", shape=b_shape)
-
-    # fg = pytensor.function([A, b], grad(test_fn(A, b), [A, b]))
-    # fg2 = pytensor.function([A, b], grad(pt.linalg.solve(A, b).sum(), [A, b]))
-
-    # print(fg(A_val, b_val))
-    # print(fg2(A_val, b_val))
-
     utt.verify_grad(test_fn, [A_val, b_val], 3, rng)
 
 
@@ -1065,3 +1045,36 @@ def test_block_diagonal_blockwise():
     B = np.random.normal(size=(1, batch_size, 4, 4)).astype(config.floatX)
     result = block_diag(A, B).eval()
     assert result.shape == (10, batch_size, 6, 6)
+
+
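+# Benchmark helpers: lu_solve_1 solves through the lu_solve helper function,
+# while lu_solve_2 applies the LUSolve Op directly to the factorized inputs.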
+def lu_solve_1(A, b):
+    lu, pivots = pt.linalg.lu_factor(A)
+    return pt.linalg.lu_solve((lu, pivots), b)
+
+
+def lu_solve_2(A, b, b_ndim=1, trans=0, check_finite=False):
+    lu, pivots = pt.linalg.lu_factor(A)
+    return LUSolve(b_ndim=b_ndim, trans=trans, check_finite=check_finite)(lu, pivots, b)
+
+
+@pytest.mark.parametrize(
+    "op", [lu_solve_1, lu_solve_2, pt.linalg.solve], ids=["lu_1", "lu_2", "solve"]
+)
+@pytest.mark.parametrize("n", [500])
+def test_solve_methods(op, n, benchmark):
+    A = pt.tensor("A", shape=(n, n))
+    b = pt.tensor("b", shape=(n,))
+
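+    # Compile one function returning the solution together with its gradients w.r.t. A and b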
+    x = op(A, b)
+    gx = pt.grad(x.sum(), [A, b])
+    f = pytensor.function([A, b], [x, *gx])
+
+    A_val = np.random.normal(size=(n, n)).astype(config.floatX)
+    b_val = np.random.normal(size=(n,)).astype(config.floatX)
+
+    # Trigger compilation up front so JIT backends don't compile inside the benchmark loop
+    f(A_val, b_val)
+    benchmark(f, A_val, b_val)