1010
1111import pytensor
1212import pytensor .tensor as pt
13+ from pytensor .compile .builders import OpFromGraph
1314from pytensor .gradient import DisconnectedType
1415from pytensor .graph .basic import Apply
1516from pytensor .graph .op import Op
@@ -226,6 +227,7 @@ def __init__(
226227 ):
227228 self .lower = lower
228229 self .check_finite = check_finite
230+
229231 assert b_ndim in (1 , 2 )
230232 self .b_ndim = b_ndim
231233 if b_ndim == 1 :
@@ -303,10 +305,14 @@ def L_op(self, inputs, outputs, output_gradients):
303305
304306 solve_op = type (self )(** props_dict )
305307
306- b_bar = solve_op (A .T , c_bar )
308+ b_bar = solve_op (A .mT , c_bar )
307309 # force outer product if vector second input
308310 A_bar = - ptm .outer (b_bar , c ) if c .ndim == 1 else - b_bar .dot (c .T )
309311
312+ if props_dict .get ("unit_diagonal" , False ):
313+ n = A_bar .shape [- 1 ]
314+ A_bar = A_bar [pt .arange (n ), pt .arange (n )].set (pt .zeros (n ))
315+
310316 return [A_bar , b_bar ]
311317
312318
@@ -577,12 +583,42 @@ def lu(
577583 )
578584
579585
def _pivot_to_permutation(pivots):
    """
    Convert a sequence of row exchanges into a vector of permutation indices.

    Starting from ``arange(n)``, where ``n`` is the length of ``pivots``, entry ``i`` is swapped with entry
    ``pivots[i]`` for ``i = 0, ..., n - 1``, in order. The index vector obtained after the last swap is returned.

    As used elsewhere in this module, the result is treated as the *inverse* permutation: callers apply
    ``argsort`` to it when they need the forward permutation.

    Parameters
    ----------
    pivots
        One-dimensional, integer-valued sequence of row exchanges. It is converted with ``as_tensor_variable``.

    Returns
    -------
    The permutation index vector after all row exchanges have been applied.
    """

    def apply_swap(idx, perm, row_swaps):
        # Exchange the entries at positions ``idx`` and ``row_swaps[idx]``.
        target = row_swaps[idx]
        value_at_idx = perm[idx]
        value_at_target = perm[target]

        perm = perm[idx].set(value_at_target)
        return perm[target].set(value_at_idx)

    pivots = as_tensor_variable(pivots)
    n = pivots.shape[0]

    # scan returns every intermediate state of the permutation; only the last one is needed.
    permutation_states, _ = pytensor.scan(
        apply_swap,
        sequences=[pt.arange(n.copy())],
        outputs_info=[pt.arange(n.copy())],
        non_sequences=[pivots],
    )

    return permutation_states[-1]
611+
612+
580613class LUFactor (Op ):
581- __props__ = ("overwrite_a" , "check_finite" )
614+ __props__ = ("overwrite_a" , "check_finite" , "permutation_indices" )
582615
583- def __init__ (self , * , overwrite_a = False , check_finite = True ):
616+ def __init__ (
617+ self , * , overwrite_a = False , check_finite = True , permutation_indices = False
618+ ):
584619 self .overwrite_a = overwrite_a
585620 self .check_finite = check_finite
621+ self .permutation_indices = permutation_indices
586622 self .gufunc_signature = "(m,m)->(m,m),(m)"
587623
588624 if self .overwrite_a :
@@ -596,8 +632,9 @@ def make_node(self, A):
596632 )
597633
598634 LU = matrix (shape = A .type .shape , dtype = A .type .dtype )
599- pivots = vector (shape = (A .type .shape [0 ],), dtype = "int32" )
600- return Apply (self , [A ], [LU , pivots ])
635+ pivots_or_permutations = vector (shape = (A .type .shape [0 ],), dtype = "int32" )
636+
637+ return Apply (self , [A ], [LU , pivots_or_permutations ])
601638
602639 def infer_shape (self , fgraph , node , shapes ):
603640 n = shapes [0 ][0 ]
@@ -613,25 +650,40 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
613650
614651 def perform (self , node , inputs , outputs ):
615652 A = inputs [0 ]
616- LU , pivots = scipy_linalg .lu_factor (
617- A ,
618- overwrite_a = self .overwrite_a ,
619- check_finite = self .check_finite ,
620- )
653+
654+ if self .permutation_indices :
655+ p , L , U = cast (
656+ tuple [np .ndarray , np .ndarray , np .ndarray ],
657+ scipy_linalg .lu (
658+ A ,
659+ overwrite_a = self .overwrite_a ,
660+ check_finite = self .check_finite ,
661+ p_indices = True ,
662+ permute_l = False ,
663+ ),
664+ )
665+ LU = np .tril (L , k = - 1 ) + U
666+
667+ else :
668+ LU , p = scipy_linalg .lu_factor (
669+ A , overwrite_a = self .overwrite_a , check_finite = self .check_finite
670+ )
621671
622672 outputs [0 ][0 ] = LU
623- outputs [1 ][0 ] = pivots
673+ outputs [1 ][0 ] = p
624674
625675 def L_op (self , inputs , outputs , output_gradients ):
626- A = inputs [ 0 ]
676+ [ A ] = inputs
627677 LU_bar , _ = output_gradients
678+ LU , p_indices = outputs
628679
629- # We need the permutation matrix P, not the pivot indices. Easiest way is to just do another LU forward.
630- # Alternative is to do a scan over the pivot indices to convert them to permutation indices. I don't know if
631- # that's faster or slower.
632- P , L , U = lu (
633- A , permute_l = False , check_finite = self .check_finite , p_indices = False
634- )
680+ eye = ptb .identity_like (A )
681+ L = cast (TensorVariable , ptb .tril (LU , k = - 1 ) + eye )
682+ U = cast (TensorVariable , ptb .triu (LU ))
683+
684+ if not self .permutation_indices :
685+ p_indices_inv = _pivot_to_permutation (cast (TensorVariable , p_indices ))
686+ p_indices = pt .argsort (p_indices_inv )
635687
636688 # Split LU_bar into L_bar and U_bar. This is valid because of the triangular structure of L and U
637689 L_bar = ptb .tril (LU_bar , k = - 1 )
@@ -642,13 +694,14 @@ def L_op(self, inputs, outputs, output_gradients):
642694 x2 = ptb .triu (U_bar @ U .T )
643695
644696 LT_inv_x = solve_triangular (L .T , x1 + x2 , lower = False , unit_diagonal = True )
645- A_bar = P @ solve_triangular (U , LT_inv_x .T , lower = False ).T
697+ B_bar = solve_triangular (U , LT_inv_x .T , lower = False ).T
698+ A_bar = B_bar [p_indices ]
646699
647700 return [A_bar ]
648701
649702
def lu_factor(
    a: TensorLike, *, check_finite: bool = True, permutation_indices: bool = False
) -> tuple[TensorVariable, TensorVariable]:
    """
    LU factorization with partial pivoting.

    Parameters
    ----------
    a: TensorLike
        Matrix to be factorized
    check_finite: bool
        Whether to check that the input matrix contains only finite numbers.
    permutation_indices: bool
        If True, returns permutation indices such that L[p] @ U = A. Otherwise returns the pivot indices, which give
        a record of row swaps that occurred at each iteration of the LU factorization. Default is False, which
        matches the behavior of scipy.linalg.lu_factor.

    Returns
    -------
    LU: TensorVariable
        LU decomposition of `a`
    pivots_or_permutations: TensorVariable
        An array of integers representing either the pivot indices or permutation indices, depending on the value of
        `permutation_indices`.
    """
    core_op = LUFactor(
        check_finite=check_finite, permutation_indices=permutation_indices
    )
    # Wrap the core Op in Blockwise before applying it to the input.
    batched_op = Blockwise(core_op)

    return cast(tuple[TensorVariable, TensorVariable], batched_op(a))
675735
676736
class LUSolve(OpFromGraph):
    """Solve a system of linear equations given the LU decomposition of the matrix."""

    __props__ = ("trans", "b_ndim", "check_finite", "overwrite_b")

    def __init__(
        self,
        *args,
        trans: bool = False,
        b_ndim: int | None = None,
        check_finite: bool = False,
        overwrite_b: bool = False,
        **kwargs,
    ):
        """
        Store the solver options and forward everything else to ``OpFromGraph``.

        Parameters
        ----------
        *args, **kwargs
            Passed unchanged to ``OpFromGraph.__init__``.
        trans: bool
            Stored as ``self.trans``.
        b_ndim: int, optional
            Stored as ``self.b_ndim``.
        check_finite: bool
            Stored as ``self.check_finite``.
        overwrite_b: bool
            Stored as ``self.overwrite_b``.
        """
        # The options named in __props__ are set before the parent constructor runs.
        self.trans, self.b_ndim = trans, b_ndim
        self.check_finite, self.overwrite_b = check_finite, overwrite_b

        super().__init__(*args, **kwargs)
757+
758+
def lu_solve(
    LU_and_pivots: tuple[TensorLike, TensorLike],
    b: TensorLike,
    trans: bool = False,
    b_ndim: int | None = None,
    check_finite: bool = True,
):
    """
    Solve a system of linear equations given the LU decomposition of the matrix.

    The solve is expressed as a row permutation of the right-hand side plus two triangular solves against the
    packed LU matrix, and the resulting graph is wrapped in an `LUSolve` Op.

    Parameters
    ----------
    LU_and_pivots: tuple[TensorLike, TensorLike]
        LU decomposition of the matrix, as returned by `lu_factor`. The second element is converted with
        `_pivot_to_permutation`, i.e. it is interpreted as pivot indices.
    b: TensorLike
        Right-hand side of the equation
    trans: bool
        If True, solve A^T x = b, instead of Ax = b. Default is False
    b_ndim: int, optional
        The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix
        of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input.
    check_finite: bool
        If True, check that the input matrices contain only finite numbers. Default is True.

    Returns
    -------
    The output of the `LUSolve` Op applied to ``(LU, pivots, b)``, i.e. the solution ``x``.
    """
    b_ndim = _default_b_ndim(b, b_ndim)
    LU, pivots = LU_and_pivots

    LU, pivots, b = (pt.as_tensor_variable(value) for value in (LU, pivots, b))
    inv_permutation = _pivot_to_permutation(pivots)

    # Options common to both triangular solves.
    shared_options = {
        "trans": trans,
        "b_ndim": b_ndim,
        "check_finite": check_finite,
    }

    # Without transposition the rows of b are permuted before solving.
    if trans:
        x = b
    else:
        x = b[inv_permutation]

    # First factor: unit-diagonal lower triangle (L) when not transposed, upper triangle (U^T solve) otherwise.
    x = solve_triangular(
        LU,
        x,
        lower=not trans,
        unit_diagonal=not trans,
        **shared_options,
    )

    # Second factor: the remaining triangle.
    x = solve_triangular(
        LU,
        x,
        lower=trans,
        unit_diagonal=trans,
        **shared_options,
    )

    # With transposition the permutation is applied to the solution instead, using the argsorted indices.
    if trans:
        x = x[pt.argsort(inv_permutation)]

    solve_op = LUSolve(
        inputs=[LU, pivots, b],
        outputs=[x],
        trans=trans,
        b_ndim=b_ndim,
        check_finite=check_finite,
    )
    return solve_op(LU, pivots, b)
819+
820+
677821class SolveTriangular (SolveBase ):
678822 """Solve a system of linear equations."""
679823
@@ -688,6 +832,9 @@ class SolveTriangular(SolveBase):
688832 def __init__ (self , * , unit_diagonal = False , ** kwargs ):
689833 if kwargs .get ("overwrite_a" , False ):
690834 raise ValueError ("overwrite_a is not supported for SolverTriangulare" )
835+
836+ # There's a naming inconsistency between solve_triangular (trans) and solve (transposed). Internally, we can use
837+ # transpose everywhere, but expose the same API as scipy.linalg.solve_triangular
691838 super ().__init__ (** kwargs )
692839 self .unit_diagonal = unit_diagonal
693840
@@ -1546,4 +1693,5 @@ def block_diag(*matrices: TensorVariable):
15461693 "cho_solve" ,
15471694 "lu" ,
15481695 "lu_factor" ,
1696+ "lu_solve" ,
15491697]
0 commit comments