@@ -577,6 +577,103 @@ def lu(
577577 )
578578
579579
class LUFactor(Op):
    """Compute the LU factorization of a square matrix with partial pivoting.

    Wraps ``scipy.linalg.lu_factor``: the two outputs are the combined LU
    matrix (L below the diagonal with implicit unit diagonal, U on and above
    it) and the int32 pivot-index vector, matching SciPy's return convention.
    """

    __props__ = ("overwrite_a", "check_finite")

    def __init__(self, *, overwrite_a=False, check_finite=True):
        # overwrite_a: allow scipy to destroy the input buffer (in-place mode)
        # check_finite: forward to scipy's finiteness validation of the input
        self.overwrite_a = overwrite_a
        self.check_finite = check_finite
        # Square input -> square LU matrix plus a length-m pivot vector;
        # used by Blockwise to vectorize over leading batch dimensions.
        self.gufunc_signature = "(m,m)->(m,m),(m)"

        if self.overwrite_a:
            # Output 0 reuses input 0's storage when in-place is enabled.
            self.destroy_map = {0: [0]}

    def make_node(self, A):
        """Build the Apply node; `A` must be a 2-D tensor."""
        A = as_tensor_variable(A)
        if A.type.ndim != 2:
            raise TypeError(
                f"LU only allowed on matrix (2-D) inputs, got {A.type.ndim}-D input"
            )

        # LU shares A's static shape and dtype; pivots follow scipy's int32.
        LU = matrix(shape=A.type.shape, dtype=A.type.dtype)
        pivots = vector(shape=(A.type.shape[0],), dtype="int32")
        return Apply(self, [A], [LU, pivots])

    def infer_shape(self, fgraph, node, shapes):
        # Square input (n, n) -> LU is (n, n), pivots is (n,).
        n = shapes[0][0]
        return [(n, n), (n,)]

    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
        """Return a destructive variant of this Op if input 0 may be overwritten."""
        if 0 in allowed_inplace_inputs:
            new_props = self._props_dict()  # type: ignore
            new_props["overwrite_a"] = True
            return type(self)(**new_props)
        else:
            return self

    def perform(self, node, inputs, outputs):
        A = inputs[0]
        LU, pivots = scipy_linalg.lu_factor(
            A,
            overwrite_a=self.overwrite_a,
            check_finite=self.check_finite,
        )

        outputs[0][0] = LU
        outputs[1][0] = pivots

    def L_op(self, inputs, outputs, output_gradients):
        """Gradient of the LU factorization w.r.t. the input matrix.

        Only the cotangent of the combined LU output is used; the pivot
        output is integer-valued and carries no gradient.
        """
        A = inputs[0]
        LU_bar, _ = output_gradients

        # We need the permutation matrix P, not the pivot indices. Easiest way is to just do another LU forward.
        # Alternative is to do a scan over the pivot indices to convert them to permutation indices. I don't know if
        # that's faster or slower.
        P, L, U = lu(
            A, permute_l=False, check_finite=self.check_finite, p_indices=False
        )

        # Split LU_bar into L_bar and U_bar. This is valid because of the triangular structure of L and U
        L_bar = ptb.tril(LU_bar, k=-1)
        U_bar = ptb.triu(LU_bar)

        # From here we're in the same situation as the LU gradient derivation
        x1 = ptb.tril(L.T @ L_bar, k=-1)
        x2 = ptb.triu(U_bar @ U.T)

        LT_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
        A_bar = P @ solve_triangular(U, LT_inv_x.T, lower=False).T

        return [A_bar]
649+
def lu_factor(
    a: TensorLike, *, check_finite=True
) -> tuple[TensorVariable, TensorVariable]:
    """
    Compute the LU factorization of `a` with partial pivoting.

    Parameters
    ----------
    a: TensorLike
        Matrix to be factorized
    check_finite: bool
        Whether to check that the input matrix contains only finite numbers.

    Returns
    -------
    LU: TensorVariable
        LU decomposition of `a`
    pivots: TensorVariable
        Permutation indices
    """
    # Blockwise lets the core 2-D Op broadcast over leading batch dimensions.
    core_op = LUFactor(check_finite=check_finite)
    outputs = Blockwise(core_op)(a)
    return cast(tuple[TensorVariable, TensorVariable], outputs)
675+
676+
580677class SolveTriangular (SolveBase ):
581678 """Solve a system of linear equations."""
582679
@@ -1448,4 +1545,5 @@ def block_diag(*matrices: TensorVariable):
14481545 "block_diag" ,
14491546 "cho_solve" ,
14501547 "lu" ,
1548+ "lu_factor" ,
14511549]
0 commit comments