@@ -391,6 +391,8 @@ class LU(Op):
391391 def __init__ (
392392 self , * , permute_l = False , overwrite_a = False , check_finite = True , p_indices = False
393393 ):
394+ if permute_l and p_indices :
395+ raise ValueError ("Only one of permute_l and p_indices can be True" )
394396 self .permute_l = permute_l
395397 self .check_finite = check_finite
396398 self .p_indices = p_indices
@@ -432,12 +434,12 @@ def make_node(self, x):
432434 if self .permute_l :
433435 # In this case, L is actually P @ L
434436 return Apply (self , inputs = [x ], outputs = [L , U ])
435- elif self .p_indices :
436- p = tensor (shape = (x .type .shape [0 ],), dtype = p_dtype )
437- return Apply (self , inputs = [x ], outputs = [p , L , U ])
438- else :
439- P = tensor (shape = x .type .shape , dtype = p_dtype )
440- return Apply (self , inputs = [x ], outputs = [P , L , U ])
437+ if self .p_indices :
438+ p_indices = tensor (shape = (x .type .shape [0 ],), dtype = p_dtype )
439+ return Apply (self , inputs = [x ], outputs = [p_indices , L , U ])
440+
441+ P = tensor (shape = x .type .shape , dtype = p_dtype )
442+ return Apply (self , inputs = [x ], outputs = [P , L , U ])
441443
442444 def perform (self , node , inputs , outputs ):
443445 [A ] = inputs
@@ -479,30 +481,24 @@ def L_op(
479481 A = cast (TensorVariable , A )
480482
481483 if self .permute_l :
482- PL_bar , U_bar = output_grads
484+ # P has no gradient contribution (by assumption...), so PL_bar is the same as L_bar
485+ L_bar , U_bar = output_grads
483486
484487 # TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
485- P , L , U = lu ( # type: ignore
488+ # We need L, not PL. It's not possible to recover it from PL, though. So we need to do a new forward pass
489+ P_or_indices , L , U = lu ( # type: ignore
486490 A , permute_l = False , check_finite = self .check_finite , p_indices = False
487491 )
488492
489- # Permutation matrix is orthogonal
490- L_bar = (
491- P .T @ PL_bar
492- if not isinstance (PL_bar .type , DisconnectedType )
493- else pt .zeros_like (A )
494- )
495-
496- elif self .p_indices :
497- p , L , U = outputs
498-
499- # TODO: rewrite to p_indices = False for graphs where we need to compute the gradient
500- P = pt .eye (A .shape [- 1 ])[p ]
501- _ , L_bar , U_bar = output_grads
502493 else :
503- P , L , U = outputs
494+ # In both other cases, there are 3 outputs. The first output will either be the permutation index itself,
495+ # or indices that can be used to reconstruct the permutation matrix.
496+ P_or_indices , L , U = outputs
504497 _ , L_bar , U_bar = output_grads
505498
499+ L = pytensor .printing .Print ("L" )(L )
500+ U = pytensor .printing .Print ("U" )(U )
501+
506502 L_bar = (
507503 L_bar if not isinstance (L_bar .type , DisconnectedType ) else pt .zeros_like (A )
508504 )
@@ -513,9 +509,17 @@ def L_op(
513509 x1 = ptb .tril (L .T @ L_bar , k = - 1 )
514510 x2 = ptb .triu (U_bar @ U .T )
515511
516- L_inv_x = solve_triangular (L .T , x1 + x2 , lower = False , unit_diagonal = True )
517- A_bar = P @ solve_triangular (U , L_inv_x .T , lower = False ).T
512+ LT_inv_x = solve_triangular (L .T , x1 + x2 , lower = False , unit_diagonal = True )
518513
514+ # Where B = P.T @ A is a change of variable to avoid the permutation matrix in the gradient derivation
515+ B_bar = solve_triangular (U , LT_inv_x .T , lower = False ).T
516+
517+ if not self .p_indices :
518+ A_bar = P_or_indices @ B_bar
519+ else :
520+ A_bar = B_bar [P_or_indices ]
521+
522+ A_bar = pytensor .printing .Print ("A_bar" )(A_bar )
519523 return [A_bar ]
520524
521525
@@ -556,16 +560,14 @@ def lu(
556560 U: TensorVariable
557561 Upper triangular matrix
558562 """
559- op = cast (
563+ return cast (
560564 tuple [TensorVariable , TensorVariable , TensorVariable ]
561565 | tuple [TensorVariable , TensorVariable ],
562566 Blockwise (
563- LU (permute_l = permute_l , check_finite = check_finite , p_indices = p_indices )
564- ),
567+ LU (permute_l = permute_l , p_indices = p_indices , check_finite = check_finite )
568+ )( a ) ,
565569 )
566570
567- return op (a )
568-
569571
570572class SolveTriangular (SolveBase ):
571573 """Solve a system of linear equations."""
0 commit comments