@@ -226,6 +226,7 @@ def __init__(
226226 ):
227227 self .lower = lower
228228 self .check_finite = check_finite
229+
229230 assert b_ndim in (1 , 2 )
230231 self .b_ndim = b_ndim
231232 if b_ndim == 1 :
@@ -303,10 +304,14 @@ def L_op(self, inputs, outputs, output_gradients):
303304
304305 solve_op = type (self )(** props_dict )
305306
306- b_bar = solve_op (A .T , c_bar )
307+ b_bar = solve_op (A .mT , c_bar )
307308 # force outer product if vector second input
308309 A_bar = - ptm .outer (b_bar , c ) if c .ndim == 1 else - b_bar .dot (c .T )
309310
311+ if props_dict .get ("unit_diagonal" , False ):
312+ n = A_bar .shape [- 1 ]
313+ A_bar = A_bar [pt .arange (n ), pt .arange (n )].set (pt .zeros (n ))
314+
310315 return [A_bar , b_bar ]
311316
312317
@@ -577,12 +582,42 @@ def lu(
577582 )
578583
579584
585+ def _pivot_to_permutation (pivots ):
586+ """
587+ Converts a sequence of row exchanges to a permutation matrix that represents the same row exchanges. This
588+ represents the inverse permutation, which can be used to reconstruct the original matrix from its LU factorization.
589+ To get the actual permutation, the inverse permutation must be argsorted.
590+ """
591+
592+ def step (i , permutation , swaps ):
593+ j = swaps [i ]
594+ x = permutation [i ]
595+ y = permutation [j ]
596+
597+ permutation = permutation [i ].set (y )
598+ return permutation [j ].set (x )
599+
600+ pivots = as_tensor_variable (pivots )
601+ n = pivots .shape [0 ]
602+ p_inv , _ = pytensor .scan (
603+ step ,
604+ sequences = [pt .arange (n .copy ())],
605+ outputs_info = [pt .arange (n .copy ())],
606+ non_sequences = [pivots ],
607+ )
608+
609+ return p_inv [- 1 ]
610+
611+
580612class LUFactor (Op ):
581- __props__ = ("overwrite_a" , "check_finite" )
613+ __props__ = ("overwrite_a" , "check_finite" , "permutation_indices" )
582614
583- def __init__ (self , * , overwrite_a = False , check_finite = True ):
615+ def __init__ (
616+ self , * , overwrite_a = False , check_finite = True , permutation_indices = False
617+ ):
584618 self .overwrite_a = overwrite_a
585619 self .check_finite = check_finite
620+ self .permutation_indices = permutation_indices
586621 self .gufunc_signature = "(m,m)->(m,m),(m)"
587622
588623 if self .overwrite_a :
@@ -596,8 +631,9 @@ def make_node(self, A):
596631 )
597632
598633 LU = matrix (shape = A .type .shape , dtype = A .type .dtype )
599- pivots = vector (shape = (A .type .shape [0 ],), dtype = "int32" )
600- return Apply (self , [A ], [LU , pivots ])
634+ pivots_or_permutations = vector (shape = (A .type .shape [0 ],), dtype = "int32" )
635+
636+ return Apply (self , [A ], [LU , pivots_or_permutations ])
601637
602638 def infer_shape (self , fgraph , node , shapes ):
603639 n = shapes [0 ][0 ]
@@ -613,25 +649,40 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
613649
614650 def perform (self , node , inputs , outputs ):
615651 A = inputs [0 ]
616- LU , pivots = scipy_linalg .lu_factor (
617- A ,
618- overwrite_a = self .overwrite_a ,
619- check_finite = self .check_finite ,
620- )
652+
653+ if self .permutation_indices :
654+ p , L , U = cast (
655+ tuple [np .ndarray , np .ndarray , np .ndarray ],
656+ scipy_linalg .lu (
657+ A ,
658+ overwrite_a = self .overwrite_a ,
659+ check_finite = self .check_finite ,
660+ p_indices = True ,
661+ permute_l = False ,
662+ ),
663+ )
664+ LU = np .tril (L , k = - 1 ) + U
665+
666+ else :
667+ LU , p = scipy_linalg .lu_factor (
668+ A , overwrite_a = self .overwrite_a , check_finite = self .check_finite
669+ )
621670
622671 outputs [0 ][0 ] = LU
623- outputs [1 ][0 ] = pivots
672+ outputs [1 ][0 ] = p
624673
625674 def L_op (self , inputs , outputs , output_gradients ):
626- A = inputs [ 0 ]
675+ [ A ] = inputs
627676 LU_bar , _ = output_gradients
677+ LU , p_indices = outputs
628678
629- # We need the permutation matrix P, not the pivot indices. Easiest way is to just do another LU forward.
630- # Alternative is to do a scan over the pivot indices to convert them to permutation indices. I don't know if
631- # that's faster or slower.
632- P , L , U = lu (
633- A , permute_l = False , check_finite = self .check_finite , p_indices = False
634- )
679+ eye = ptb .identity_like (A )
680+ L = cast (TensorVariable , ptb .tril (LU , k = - 1 ) + eye )
681+ U = cast (TensorVariable , ptb .triu (LU ))
682+
683+ if not self .permutation_indices :
684+ p_indices_inv = _pivot_to_permutation (cast (TensorVariable , p_indices ))
685+ p_indices = pt .argsort (p_indices_inv )
635686
636687 # Split LU_bar into L_bar and U_bar. This is valid because of the triangular structure of L and U
637688 L_bar = ptb .tril (LU_bar , k = - 1 )
@@ -642,13 +693,14 @@ def L_op(self, inputs, outputs, output_gradients):
642693 x2 = ptb .triu (U_bar @ U .T )
643694
644695 LT_inv_x = solve_triangular (L .T , x1 + x2 , lower = False , unit_diagonal = True )
645- A_bar = P @ solve_triangular (U , LT_inv_x .T , lower = False ).T
696+ B_bar = solve_triangular (U , LT_inv_x .T , lower = False ).T
697+ A_bar = B_bar [p_indices ]
646698
647699 return [A_bar ]
648700
649701
650702def lu_factor (
651- a : TensorLike , * , check_finite = True
703+ a : TensorLike , * , check_finite : bool = True , permutation_indices : bool = False
652704) -> tuple [TensorVariable , TensorVariable ]:
653705 """
654706 LU factorization with partial pivoting.
@@ -659,21 +711,63 @@ def lu_factor(
659711 Matrix to be factorized
660712 check_finite: bool
661713 Whether to check that the input matrix contains only finite numbers.
714+ permutation_indices: bool
715+ If True, returns permutation indices such that L[p] @ U = A. Otherwise returns the pivot indices, which give
716+ a record of row swaps that occured at each iteration of the LU factorization. Default is False, which matches
717+ the behavior of scipy.linalg.lu_factor.
662718
663719 Returns
664720 -------
665721 LU: TensorVariable
666722 LU decomposition of `a`
667- pivots: TensorVariable
668- Permutation indices
723+ pivots_or_permutations: TensorVariable
724+ An array of integers representing either the pivot indices or permutation indices, depending on the value of
725+ `permutation_indices`.
669726 """
670727
671728 return cast (
672729 tuple [TensorVariable , TensorVariable ],
673- Blockwise (LUFactor (check_finite = check_finite ))(a ),
730+ Blockwise (
731+ LUFactor (check_finite = check_finite , permutation_indices = permutation_indices )
732+ )(a ),
674733 )
675734
676735
736+ def lu_solve (
737+ LU_and_pivots : tuple [TensorVariable , TensorVariable ],
738+ b : TensorVariable ,
739+ trans = False ,
740+ b_ndim = None ,
741+ check_finite = True ,
742+ ):
743+ LU , pivots = LU_and_pivots
744+ inv_permutation = _pivot_to_permutation (pivots )
745+
746+ x = b [inv_permutation ] if not trans else b
747+
748+ x = solve_triangular (
749+ LU ,
750+ x ,
751+ lower = not trans ,
752+ unit_diagonal = not trans ,
753+ trans = trans ,
754+ b_ndim = b_ndim ,
755+ check_finite = check_finite ,
756+ )
757+
758+ x = solve_triangular (
759+ LU ,
760+ x ,
761+ lower = trans ,
762+ unit_diagonal = trans ,
763+ trans = trans ,
764+ b_ndim = b_ndim ,
765+ check_finite = check_finite ,
766+ )
767+
768+ return x [pt .argsort (inv_permutation )] if trans else x
769+
770+
677771class SolveTriangular (SolveBase ):
678772 """Solve a system of linear equations."""
679773
@@ -688,6 +782,9 @@ class SolveTriangular(SolveBase):
688782 def __init__ (self , * , unit_diagonal = False , ** kwargs ):
689783 if kwargs .get ("overwrite_a" , False ):
690784 raise ValueError ("overwrite_a is not supported for SolverTriangulare" )
785+
786+ # There's a naming inconsistency between solve_triangular (trans) and solve (transposed). Internally, we can use
787+ # transpose everywhere, but expose the same API as scipy.linalg.solve_triangular
691788 super ().__init__ (** kwargs )
692789 self .unit_diagonal = unit_diagonal
693790
@@ -1546,4 +1643,5 @@ def block_diag(*matrices: TensorVariable):
15461643 "cho_solve" ,
15471644 "lu" ,
15481645 "lu_factor" ,
1646+ "lu_solve" ,
15491647]
0 commit comments