11import logging
22import warnings
33from collections .abc import Sequence
4- from functools import reduce
4+ from functools import partial , reduce
55from typing import Literal , cast
66
77import numpy as np
@@ -589,6 +589,7 @@ def lu(
589589
590590
591591class PivotToPermutations (Op ):
592+ gufunc_signature = "(x)->(x)"
592593 __props__ = ("inverse" ,)
593594
594595 def __init__ (self , inverse = True ):
@@ -723,40 +724,22 @@ def lu_factor(
723724 )
724725
725726
726- def lu_solve (
727- LU_and_pivots : tuple [TensorLike , TensorLike ],
727+ def _lu_solve (
728+ LU : TensorLike ,
729+ pivots : TensorLike ,
728730 b : TensorLike ,
729731 trans : bool = False ,
730732 b_ndim : int | None = None ,
731733 check_finite : bool = True ,
732- overwrite_b : bool = False ,
733734):
734- """
735- Solve a system of linear equations given the LU decomposition of the matrix.
736-
737- Parameters
738- ----------
739- LU_and_pivots: tuple[TensorLike, TensorLike]
740- LU decomposition of the matrix, as returned by `lu_factor`
741- b: TensorLike
742- Right-hand side of the equation
743- trans: bool
744- If True, solve A^T x = b, instead of Ax = b. Default is False
745- b_ndim: int, optional
746- The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix
747- of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input.
748- check_finite: bool
749- If True, check that the input matrices contain only finite numbers. Default is True.
750- overwrite_b: bool
751- Ignored by Pytensor. Pytensor will always compute inplace when possible.
752- """
753735 b_ndim = _default_b_ndim (b , b_ndim )
754- LU , pivots = LU_and_pivots
755736
756737 LU , pivots , b = map (pt .as_tensor_variable , [LU , pivots , b ])
757- inv_permutation = pivot_to_permutation (pivots , inverse = True )
758738
739+ inv_permutation = pivot_to_permutation (pivots , inverse = True )
759740 x = b [inv_permutation ] if not trans else b
741+ # TODO: Use PermuteRows on b
742+ # x = permute_rows(b, pivots) if not trans else b
760743
761744 x = solve_triangular (
762745 LU ,
@@ -777,11 +760,52 @@ def lu_solve(
777760 b_ndim = b_ndim ,
778761 check_finite = check_finite ,
779762 )
780- x = x [pt .argsort (inv_permutation )] if trans else x
781763
764+ # TODO: Use PermuteRows(inverse=True) on x
765+ # if trans:
766+ # x = permute_rows(x, pivots, inverse=True)
767+ x = x [pt .argsort (inv_permutation )] if trans else x
782768 return x
783769
784770
771+ def lu_solve (
772+ LU_and_pivots : tuple [TensorLike , TensorLike ],
773+ b : TensorLike ,
774+ trans : bool = False ,
775+ b_ndim : int | None = None ,
776+ check_finite : bool = True ,
777+ overwrite_b : bool = False ,
778+ ):
779+ """
780+ Solve a system of linear equations given the LU decomposition of the matrix.
781+
782+ Parameters
783+ ----------
784+ LU_and_pivots: tuple[TensorLike, TensorLike]
785+ LU decomposition of the matrix, as returned by `lu_factor`
786+ b: TensorLike
787+ Right-hand side of the equation
788+ trans: bool
789+ If True, solve A^T x = b, instead of Ax = b. Default is False
790+ b_ndim: int, optional
791+ The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix
792+ of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input.
793+ check_finite: bool
794+ If True, check that the input matrices contain only finite numbers. Default is True.
795+ overwrite_b: bool
796+ Ignored by Pytensor. Pytensor will always compute inplace when possible.
797+ """
798+ b_ndim = _default_b_ndim (b , b_ndim )
799+ if b_ndim == 1 :
800+ signature = "(m,m),(m),(m)->(m)"
801+ else :
802+ signature = "(m,m),(m),(m,n)->(m,n)"
803+ partialled_func = partial (
804+ _lu_solve , trans = trans , b_ndim = b_ndim , check_finite = check_finite
805+ )
806+ return pt .vectorize (partialled_func , signature = signature )(* LU_and_pivots , b )
807+
808+
785809class SolveTriangular (SolveBase ):
786810 """Solve a system of linear equations."""
787811
0 commit comments