@@ -565,6 +565,103 @@ def lu(
    )


+class LUFactor(Op):
+    __props__ = ("overwrite_a", "check_finite")
+
+    def __init__(self, *, overwrite_a=False, check_finite=True):
+        self.overwrite_a = overwrite_a
+        self.check_finite = check_finite
+        self.gufunc_signature = "(m,m)->(m,m),(m)"
+
+        if self.overwrite_a:
+            self.destroy_map = {0: [0]}
+
+    def make_node(self, A):
+        A = as_tensor_variable(A)
+        if A.type.ndim != 2:
+            raise TypeError(
+                f"LU only allowed on matrix (2-D) inputs, got {A.type.ndim}-D input"
+            )
+
+        LU = matrix(shape=A.type.shape, dtype=A.type.dtype)
+        pivots = vector(shape=(A.type.shape[0],), dtype="int32")
+        return Apply(self, [A], [LU, pivots])
+
+    def infer_shape(self, fgraph, node, shapes):
+        n = shapes[0][0]
+        return [(n, n), (n,)]
+
+    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
+        if 0 in allowed_inplace_inputs:
+            new_props = self._props_dict()  # type: ignore
+            new_props["overwrite_a"] = True
+            return type(self)(**new_props)
+        else:
+            return self
+
+    def perform(self, node, inputs, outputs):
+        A = inputs[0]
+        LU, pivots = scipy_linalg.lu_factor(
+            A,
+            overwrite_a=self.overwrite_a,
+            check_finite=self.check_finite,
+        )
+
+        outputs[0][0] = LU
+        outputs[1][0] = pivots
+
+    def L_op(self, inputs, outputs, output_gradients):
+        A = inputs[0]
+        LU_bar, _ = output_gradients
+
+        # We need the permutation matrix P, not the pivot indices, so the easiest route is to run
+        # another LU forward pass. An alternative would be a scan over the pivot indices to build the
+        # permutation (as sketched below); it is not obvious which approach is faster.
+        P, L, U = lu(
+            A, permute_l=False, check_finite=self.check_finite, p_indices=False
+        )
+
+        # Split LU_bar into L_bar and U_bar. This is valid because of the triangular structure of L and U
+        L_bar = ptb.tril(LU_bar, k=-1)
+        U_bar = ptb.triu(LU_bar)
+
+        # From here we're in the same situation as the LU gradient derivation
+        x1 = ptb.tril(L.T @ L_bar, k=-1)
+        x2 = ptb.triu(U_bar @ U.T)
+
+        LT_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
+        A_bar = P @ solve_triangular(U, LT_inv_x.T, lower=False).T
+
+        return [A_bar]
+
+
+def lu_factor(
+    a: TensorLike, *, check_finite=True
+) -> tuple[TensorVariable, TensorVariable]:
+    """
+    LU factorization with partial pivoting.
+
+    Parameters
+    ----------
+    a: TensorLike
+        Matrix to be factorized
+    check_finite: bool
+        Whether to check that the input matrix contains only finite numbers.
+
+    Returns
+    -------
+    LU: TensorVariable
+        LU decomposition of `a`
655+ pivots: TensorVariable
656+ Permutation indices
657+ """
658+
659+ return cast (
660+ tuple [TensorVariable , TensorVariable ],
661+ Blockwise (LUFactor (check_finite = check_finite ))(a ),
662+ )
663+
664+
 class SolveTriangular(SolveBase):
     """Solve a system of linear equations."""

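The `L_op` comment above mentions deriving the permutation from the pivot indices instead of rerunning `lu`. For reference, this is roughly what that conversion looks like in plain NumPy; `pivots_to_permutation` is a hypothetical helper, not part of this patch, and the checks only illustrate the `scipy.linalg.lu_factor` convention the new Op relies on.

```python
import numpy as np
from scipy.linalg import lu_factor


def pivots_to_permutation(piv):
    # Hypothetical helper: apply the row interchanges recorded by getrf, in order,
    # to the identity permutation. Afterwards A[perm] == L @ U.
    perm = np.arange(len(piv))
    for i, p in enumerate(piv):
        perm[i], perm[p] = perm[p], perm[i]
    return perm


A = np.random.default_rng(0).normal(size=(4, 4))
LU, piv = lu_factor(A)

n = A.shape[0]
L = np.tril(LU, k=-1) + np.eye(n)   # unit lower-triangular factor
U = np.triu(LU)                     # upper-triangular factor
perm = pivots_to_permutation(piv)
P = np.eye(n)[:, perm]              # permutation matrix with P @ L @ U == A

np.testing.assert_allclose(A[perm], L @ U)
np.testing.assert_allclose(P @ L @ U, A)
```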
@@ -1362,4 +1459,5 @@ def block_diag(*matrices: TensorVariable):
13621459 "block_diag" ,
13631460 "cho_solve" ,
13641461 "lu" ,
1462+ "lu_factor" ,
13651463]
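For completeness, a minimal usage sketch of the newly exported `lu_factor`, assuming the usual import path for this module (`pytensor.tensor.slinalg`); the variable names and the comparison against SciPy are illustrative only, not part of the commit.

```python
import numpy as np
import scipy.linalg
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import lu_factor  # assumed import path for this module

A = pt.matrix("A")
LU, pivots = lu_factor(A, check_finite=True)
fn = pytensor.function([A], [LU, pivots])

A_val = np.random.default_rng(42).normal(size=(3, 3))
LU_val, piv_val = fn(A_val)

# The outputs follow scipy.linalg.lu_factor: L (unit diagonal) and U packed into one
# matrix, plus the pivot index vector that scipy.linalg.lu_solve understands.
LU_ref, piv_ref = scipy.linalg.lu_factor(A_val)
np.testing.assert_allclose(LU_val, LU_ref)
```

Since `LUFactor` is wrapped in `Blockwise` with signature `"(m,m)->(m,m),(m)"`, batched inputs should also work, with the factorization applied over the trailing two dimensions.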