11import logging
22import typing
33import warnings
4- from collections .abc import Sequence
54from functools import reduce
65from typing import Literal , cast
76
109
1110import pytensor
1211import pytensor .tensor as pt
13- from pytensor import Variable
14- from pytensor .gradient import DisconnectedType
1512from pytensor .graph .basic import Apply
1613from pytensor .graph .op import Op
1714from pytensor .tensor import TensorLike , as_tensor_variable
2825
2926
3027class Cholesky (Op ):
28+ # TODO: LAPACK wrapper with in-place behavior, for solve also
29+
3130 __props__ = ("lower" , "check_finite" , "on_error" , "overwrite_a" )
3231 gufunc_signature = "(m,m)->(m,m)"
3332
@@ -397,186 +396,6 @@ def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
397396 )(A , b )
398397
399398
400- class LU (Op ):
401- """Decompose a matrix into lower and upper triangular matrices."""
402-
403- __props__ = ("permute_l" , "overwrite_a" , "check_finite" , "p_indices" )
404-
405- def __init__ (
406- self , * , permute_l = False , overwrite_a = False , check_finite = True , p_indices = False
407- ):
408- self .permute_l = permute_l
409- self .check_finite = check_finite
410- self .p_indices = p_indices
411- self .overwrite_a = overwrite_a
412-
413- if self .permute_l :
414- # permute_l overrides p_indices in the scipy function. We can copy that behavior
415- self .gufunc_signature = "(m,m)->(m,m),(m,m)"
416- elif self .p_indices :
417- self .gufunc_signature = "(m,m)->(m),(m,m),(m,m)"
418- else :
419- self .gufunc_signature = "(m,m)->(m,m),(m,m),(m,m)"
420-
421- if self .overwrite_a :
422- self .destroy_map = {0 : [0 ]}
423-
424- def infer_shape (self , fgraph , node , shapes ):
425- n = shapes [0 ][0 ]
426- if self .permute_l :
427- return [(n , n ), (n , n )]
428- elif self .p_indices :
429- return [(n ,), (n , n ), (n , n )]
430- else :
431- return [(n , n ), (n , n ), (n , n )]
432-
433- def make_node (self , x ):
434- x = as_tensor_variable (x )
435- if x .type .ndim != 2 :
436- raise TypeError (
437- f"LU only allowed on matrix (2-D) inputs, got { x .type .ndim } -D input"
438- )
439-
440- real_dtype = "f" if np .dtype (x .type .dtype ).char in "fF" else "d"
441- p_dtype = "int32" if self .p_indices else np .dtype (real_dtype )
442-
443- L = tensor (shape = x .type .shape , dtype = real_dtype )
444- U = tensor (shape = x .type .shape , dtype = real_dtype )
445-
446- if self .permute_l :
447- # In this case, L is actually P @ L
448- return Apply (self , inputs = [x ], outputs = [L , U ])
449- elif self .p_indices :
450- p = tensor (shape = (x .type .shape [0 ],), dtype = p_dtype )
451- return Apply (self , inputs = [x ], outputs = [p , L , U ])
452- else :
453- P = tensor (shape = x .type .shape , dtype = p_dtype )
454- return Apply (self , inputs = [x ], outputs = [P , L , U ])
455-
456- def perform (self , node , inputs , outputs ):
457- [A ] = inputs
458-
459- out = scipy .linalg .lu (
460- A ,
461- permute_l = self .permute_l ,
462- overwrite_a = self .overwrite_a ,
463- check_finite = self .check_finite ,
464- p_indices = self .p_indices ,
465- )
466-
467- outputs [0 ][0 ] = out [0 ]
468- outputs [1 ][0 ] = out [1 ]
469-
470- if not self .permute_l :
471- # In all cases except permute_l, there are three returns
472- outputs [2 ][0 ] = out [2 ]
473-
474- def inplace_on_inputs (self , allowed_inplace_inputs : list [int ]) -> "Op" :
475- if 0 in allowed_inplace_inputs :
476- new_props = self ._props_dict () # type: ignore
477- new_props ["overwrite_a" ] = True
478- return type (self )(** new_props )
479- else :
480- return self
481-
482- def L_op (
483- self ,
484- inputs : Sequence [Variable ],
485- outputs : Sequence [Variable ],
486- output_grads : Sequence [Variable ],
487- ) -> list [Variable ]:
488- r"""
489- Derivation is due to Differentiation of Matrix Functionals Using Triangular Factorization
490- F. R. De Hoog, R.S. Anderssen, M. A. Lukas
491- """
492- [A ] = inputs
493- A = cast (TensorVariable , A )
494-
495- if self .permute_l :
496- PL_bar , U_bar = output_grads
497-
498- # TODO: Rewrite into permute_l = False for graphs where we need to compute the gradient
499- P , L , U = lu ( # type: ignore
500- A , permute_l = False , check_finite = self .check_finite , p_indices = False
501- )
502-
503- # Permutation matrix is orthogonal
504- L_bar = (
505- P .T @ PL_bar
506- if not isinstance (PL_bar .type , DisconnectedType )
507- else pt .zeros_like (A )
508- )
509-
510- elif self .p_indices :
511- p , L , U = outputs
512-
513- # TODO: rewrite to p_indices = False for graphs where we need to compute the gradient
514- P = pt .eye (A .shape [0 ])[p ]
515- _ , L_bar , U_bar = output_grads
516- else :
517- P , L , U = outputs
518- _ , L_bar , U_bar = output_grads
519-
520- L_bar = (
521- L_bar if not isinstance (L_bar .type , DisconnectedType ) else pt .zeros_like (A )
522- )
523- U_bar = (
524- U_bar if not isinstance (U_bar .type , DisconnectedType ) else pt .zeros_like (A )
525- )
526-
527- x1 = ptb .tril (L .T @ L_bar , k = - 1 )
528- x2 = ptb .triu (U_bar @ U .T )
529-
530- L_inv_x = solve_triangular (L .T , x1 + x2 , lower = False , unit_diagonal = True )
531- A_bar = P @ solve_triangular (U , L_inv_x .T , lower = False ).T
532-
533- return [A_bar ]
534-
535-
536- def lu (
537- a : TensorLike , permute_l = False , check_finite = True , p_indices = False
538- ) -> (
539- tuple [TensorVariable , TensorVariable , TensorVariable ]
540- | tuple [TensorVariable , TensorVariable ]
541- ):
542- """
543- Factorize a matrix as the product of a unit lower triangular matrix and an upper triangular matrix:
544-
545- ... math::
546-
547- A = P L U
548-
549- Where P is a permutation matrix, L is lower triangular with unit diagonal elements, and U is upper triangular.
550-
551- Parameters
552- ----------
553- a: TensorLike
554- Matrix to be factorized
555- permute_l: bool
556- If True, L is a product of permutation and unit lower triangular matrices. Only two values, PL and U, will
557- be returned in this case, and PL will not be lower triangular.
558- check_finite: bool
559- Whether to check that the input matrix contains only finite numbers.
560- p_indices: bool
561- If True, return integer matrix indices for the permutation matrix. Otherwise, return the permutation matrix
562- itself.
563-
564- Returns
565- -------
566- P: TensorVariable
567- Permutation matrix, or array of integer indices for permutation matrix. Not returned if permute_l is True.
568- L: TensorVariable
569- Lower triangular matrix, or product of permutation and unit lower triangular matrices if permute_l is True.
570- U: TensorVariable
571- Upper triangular matrix
572- """
573- return cast (
574- tuple [TensorVariable , TensorVariable , TensorVariable ]
575- | tuple [TensorVariable , TensorVariable ],
576- LU (permute_l = permute_l , check_finite = check_finite , p_indices = p_indices )(a ),
577- )
578-
579-
580399class SolveTriangular (SolveBase ):
581400 """Solve a system of linear equations."""
582401
@@ -734,7 +553,6 @@ def solve(
734553 assume_a = "gen" ,
735554 lower = False ,
736555 check_finite = True ,
737- transposed = False ,
738556 b_ndim : int | None = None ,
739557):
740558 """Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
@@ -772,8 +590,6 @@ def solve(
772590 (crashes, non-termination) if the inputs do contain infinities or NaNs.
773591 assume_a : str, optional
774592 Valid entries are explained above.
775- transposed: bool, optional
776- If True, solve ``A.T @ x = b``
777593 b_ndim : int
778594 Whether the core case of b is a vector (1) or matrix (2).
779595 This will influence how batched dimensions are interpreted.
@@ -785,7 +601,6 @@ def solve(
785601 check_finite = check_finite ,
786602 assume_a = assume_a ,
787603 b_ndim = b_ndim ,
788- transposed = transposed ,
789604 )
790605 )(a , b )
791606
0 commit comments