@@ -477,30 +477,38 @@ def test_solve_dtype(self):
477477@pytest .mark .parametrize ("permute_l" , [True , False ], ids = ["permute_l" , "no_permute_l" ])
478478@pytest .mark .parametrize ("p_indices" , [True , False ], ids = ["p_indices" , "no_p_indices" ])
479479@pytest .mark .parametrize ("complex" , [False , True ], ids = ["real" , "complex" ])
480- def test_lu_decomposition (permute_l , p_indices , complex ):
480+ @pytest .mark .parametrize ("shape" , [(3 , 5 , 5 ), (5 , 5 )], ids = ["batched" , "not_batched" ])
481+ def test_lu_decomposition (
482+ permute_l : bool , p_indices : bool , complex : bool , shape : tuple [int ]
483+ ):
481484 dtype = config .floatX if not complex else f"complex{ int (config .floatX [- 2 :]) * 2 } "
482- A = tensor ("A" , shape = (None , None ), dtype = dtype )
485+
486+ A = tensor ("A" , shape = shape , dtype = dtype )
483487 out = lu (A , permute_l = permute_l , p_indices = p_indices )
484488
485489 f = pytensor .function ([A ], out )
486490
487491 rng = np .random .default_rng (utt .fetch_seed ())
488- x = rng .normal (size = ( 5 , 5 ) ).astype (config .floatX )
492+ x = rng .normal (size = shape ).astype (config .floatX )
489493 if complex :
490- x = x + 1j * rng .normal (size = ( 5 , 5 ) ).astype (config .floatX )
494+ x = x + 1j * rng .normal (size = shape ).astype (config .floatX )
491495
492496 out = f (x )
493497
494498 if permute_l :
495499 PL , U = out
496- x_rebuilt = PL @ U
497500 elif p_indices :
498501 p , L , U = out
499- P = np .eye (5 )[p ]
500- x_rebuilt = P @ L @ U
502+ if len (shape ) == 2 :
503+ P = np .eye (5 )[p ]
504+ else :
505+ P = np .stack ([np .eye (5 )[idx ] for idx in p ])
506+ PL = np .einsum ("...nk,...km->...nm" , P , L )
501507 else :
502508 P , L , U = out
503- x_rebuilt = P @ L @ U
509+ PL = np .einsum ("...nk,...km->...nm" , P , L )
510+
511+ x_rebuilt = np .einsum ("...nk,...km->...nm" , PL , U )
504512
505513 np .testing .assert_allclose (x , x_rebuilt )
506514 scipy_out = scipy .linalg .lu (x , permute_l = permute_l , p_indices = p_indices )
@@ -512,9 +520,10 @@ def test_lu_decomposition(permute_l, p_indices, complex):
512520@pytest .mark .parametrize ("grad_case" , [0 , 1 , 2 ], ids = ["U_only" , "L_only" , "U_and_L" ])
513521@pytest .mark .parametrize ("permute_l" , [True , False ])
514522@pytest .mark .parametrize ("p_indices" , [True , False ])
515- def test_lu_grad (grad_case , permute_l , p_indices ):
523+ @pytest .mark .parametrize ("shape" , [(3 , 5 , 5 ), (5 , 5 )], ids = ["batched" , "not_batched" ])
524+ def test_lu_grad (grad_case , permute_l , p_indices , shape ):
516525 rng = np .random .default_rng (utt .fetch_seed ())
517- A_value = rng .normal (size = ( 5 , 5 ) )
526+ A_value = rng .normal (size = shape )
518527
519528 def f_pt (A ):
520529 out = lu (A , permute_l = permute_l , p_indices = p_indices )
0 commit comments