@@ -1,14 +1,14 @@
 import logging
-import typing
 import warnings
 from functools import reduce
-from typing import Literal, cast
+from typing import Literal, cast, Sequence
 
 import numpy as np
-import scipy.linalg
+import scipy
 
 import pytensor
 import pytensor.tensor as pt
+from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.tensor import TensorLike, as_tensor_variable
@@ -302,6 +302,7 @@ def L_op(self, inputs, outputs, output_gradients):
             }
         )
         b_bar = trans_solve_op(A.T, c_bar)
+
         # force outer product if vector second input
         A_bar = -ptm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
 
@@ -369,7 +370,7 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
         Whether to check that the input matrices contain only finite numbers.
         Disabling may give a performance gain, but may result in problems
         (crashes, non-termination) if the inputs do contain infinities or NaNs.
-    b_ndim: int
+    b_ndim : int
         Whether the core case of b is a vector (1) or matrix (2).
         This will influence how batched dimensions are interpreted.
     """
@@ -380,6 +381,186 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
     )(A, b)
 
 
+class LU(Op):
+    """Decompose a matrix into lower and upper triangular matrices."""
+
+    __props__ = ("permute_l", "overwrite_a", "check_finite", "p_indices")
+
+    def __init__(
+        self, *, permute_l=False, overwrite_a=False, check_finite=True, p_indices=False
+    ):
+        self.permute_l = permute_l
+        self.check_finite = check_finite
+        self.p_indices = p_indices
+        self.overwrite_a = overwrite_a
+
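+        # The gufunc signature records the core dims of each output so Blockwise
+        # can broadcast this Op over batched (leading) dimensions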
+        if self.permute_l:
+            # permute_l overrides p_indices in the scipy function. We can copy that behavior
+            self.gufunc_signature = "(m,m)->(m,m),(m,m)"
+        elif self.p_indices:
+            self.gufunc_signature = "(m,m)->(m),(m,m),(m,m)"
+        else:
+            self.gufunc_signature = "(m,m)->(m,m),(m,m),(m,m)"
+
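+        # destroy_map declares that output 0 may reuse (destroy) the buffer of input 0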
+        if self.overwrite_a:
+            self.destroy_map = {0: [0]}
+
+    def infer_shape(self, fgraph, node, shapes):
+        n = shapes[0][0]
+        if self.permute_l:
+            return [(n, n), (n, n)]
+        elif self.p_indices:
+            return [(n,), (n, n), (n, n)]
+        else:
+            return [(n, n), (n, n), (n, n)]
+
+    def make_node(self, x):
+        x = as_tensor_variable(x)
+        if x.type.ndim != 2:
+            raise TypeError(
+                f"LU only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
+            )
+
+        real_dtype = "f" if np.dtype(x.type.dtype).char in "fF" else "d"
+        p_dtype = "int32" if self.p_indices else np.dtype(real_dtype)
+
+        L = tensor(shape=x.type.shape, dtype=real_dtype)
+        U = tensor(shape=x.type.shape, dtype=real_dtype)
+
+        if self.permute_l:
+            # In this case, L is actually P @ L
+            return Apply(self, inputs=[x], outputs=[L, U])
+        elif self.p_indices:
+            p = tensor(shape=(x.type.shape[0],), dtype=p_dtype)
+            return Apply(self, inputs=[x], outputs=[p, L, U])
+        else:
+            P = tensor(shape=x.type.shape, dtype=p_dtype)
+            return Apply(self, inputs=[x], outputs=[P, L, U])
+
+    def perform(self, node, inputs, outputs):
+        [A] = inputs
+
+        out = scipy.linalg.lu(
+            A,
+            permute_l=self.permute_l,
+            overwrite_a=self.overwrite_a,
+            check_finite=self.check_finite,
+            p_indices=self.p_indices,
+        )
+
+        outputs[0][0] = out[0]
+        outputs[1][0] = out[1]
+
+        if not self.permute_l:
+            # In all cases except permute_l, there are three returns
+            outputs[2][0] = out[2]
+
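+    # Hook used by pytensor's inplace rewrites: when input 0 may be overwritten,
+    # return a destructive (overwrite_a=True) variant of this Op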
+    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
+        if 0 in allowed_inplace_inputs:
+            new_props = self._props_dict()  # type: ignore
+            new_props["overwrite_a"] = True
+            return type(self)(**new_props)
+        else:
+            return self
+
+    def L_op(
+        self,
+        inputs: Sequence[ptb.Variable],
+        outputs: Sequence[ptb.Variable],
+        output_grads: Sequence[ptb.Variable],
+    ) -> list[ptb.Variable]:
+        r"""
+        Derivation follows "Differentiation of Matrix Functionals Using Triangular
+        Factorization" by F. R. De Hoog, R. S. Anderssen, and M. A. Lukas.
+        """
+        [A] = inputs
+        A = cast(TensorVariable, A)
+
+        if self.permute_l:
+            PL_bar, U_bar = output_grads
+
+            # TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
+            P, L, U = lu(  # type: ignore
+                A, permute_l=False, check_finite=self.check_finite, p_indices=False
+            )
+
+            # Permutation matrix is orthogonal
+            L_bar = (
+                P.T @ PL_bar
+                if not isinstance(PL_bar.type, DisconnectedType)
+                else pt.zeros_like(A)
+            )
+
+        elif self.p_indices:
+            p, L, U = outputs
+
+            # TODO: rewrite to p_indices = False for graphs where we need to compute the gradient
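+            # Build a dense permutation matrix from the index vector so the same
+            # gradient formula below works in all three modes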
+            P = pt.eye(A.shape[0])[p]
+            _, L_bar, U_bar = output_grads
+        else:
+            P, L, U = outputs
+            _, L_bar, U_bar = output_grads
+
+        L_bar = (
+            L_bar if not isinstance(L_bar.type, DisconnectedType) else pt.zeros_like(A)
+        )
+        U_bar = (
+            U_bar if not isinstance(U_bar.type, DisconnectedType) else pt.zeros_like(A)
+        )
+
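+        # From the cited derivation:
+        #   A_bar = P @ L^{-T} @ (tril(L^T @ L_bar, k=-1) + triu(U_bar @ U^T)) @ U^{-T}
+        # evaluated below with two triangular solves instead of explicit inverses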
+        x1 = ptb.tril(L.T @ L_bar, k=-1)
+        x2 = ptb.triu(U_bar @ U.T)
+
+        L_inv_x = solve_triangular(L.T, x1 + x2, lower=False, unit_diagonal=True)
+        A_bar = P @ solve_triangular(U, L_inv_x.T, lower=False).T
+
+        return [A_bar]
+
+
+def lu(
+    a: TensorLike, permute_l=False, check_finite=True, p_indices=False
+) -> (
+    tuple[TensorVariable, TensorVariable, TensorVariable]
+    | tuple[TensorVariable, TensorVariable]
+):
+    """
+    Factorize a matrix as the product of a unit lower triangular matrix and an upper triangular matrix:
+
+    .. math::
+
+        A = P L U
+
+    where P is a permutation matrix, L is lower triangular with unit diagonal elements, and U is upper triangular.
+
+    Parameters
+    ----------
+    a: TensorLike
+        Matrix to be factorized
+    permute_l: bool
+        If True, L is a product of permutation and unit lower triangular matrices. Only two values, PL and U, will
+        be returned in this case, and PL will not be lower triangular.
+    check_finite: bool
+        Whether to check that the input matrix contains only finite numbers.
+    p_indices: bool
+        If True, return integer indices for the permutation matrix. Otherwise, return the permutation matrix
+        itself.
+
+    Returns
+    -------
+    P: TensorVariable
+        Permutation matrix, or array of integer indices for the permutation matrix. Not returned if permute_l is True.
+    L: TensorVariable
+        Lower triangular matrix, or product of permutation and unit lower triangular matrices if permute_l is True.
+    U: TensorVariable
+        Upper triangular matrix
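+
+    Examples
+    --------
+    A minimal usage sketch (assuming this module is importable as ``pytensor.tensor.slinalg``):
+
+    .. code-block:: python
+
+        import pytensor
+        import pytensor.tensor as pt
+        from pytensor.tensor.slinalg import lu
+
+        A = pt.dmatrix("A")
+        P, L, U = lu(A)
+
+        # The factors multiply back to the input: A == P @ L @ U
+        reconstruct = pytensor.function([A], P @ L @ U)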
+    """
+    return cast(
+        tuple[TensorVariable, TensorVariable, TensorVariable]
+        | tuple[TensorVariable, TensorVariable],
+        LU(permute_l=permute_l, check_finite=check_finite, p_indices=p_indices)(a),
+    )
+
+
 class SolveTriangular(SolveBase):
     """Solve a system of linear equations."""
 
@@ -1064,7 +1245,7 @@ def solve_discrete_are(
     )
 
 
-def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
+def _largest_common_dtype(tensors: Sequence[TensorVariable]) -> np.dtype:
     return reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors])
 
 
@@ -1175,4 +1356,5 @@ def block_diag(*matrices: TensorVariable):
     "solve_discrete_are",
     "solve_triangular",
     "block_diag",
+    "cho_solve",
 ]