1515from pytensor .tensor import TensorLike , as_tensor_variable
1616from pytensor .tensor import basic as ptb
1717from pytensor .tensor import math as ptm
18+ from pytensor .tensor .basic import diagonal
1819from pytensor .tensor .blockwise import Blockwise
1920from pytensor .tensor .nlinalg import kron , matrix_dot
2021from pytensor .tensor .shape import reshape
@@ -260,10 +261,10 @@ def make_node(self, A, b):
260261 raise ValueError (f"`b` must have { self .b_ndim } dims; got { b .type } instead." )
261262
262263 # Infer dtype by solving the most simple case with 1x1 matrices
263- inp_arr = [ np . eye ( 1 ). astype ( A . dtype ), np . eye ( 1 ). astype ( b . dtype )]
264- out_arr = [[ None ]]
265- self . perform ( None , inp_arr , out_arr )
266- o_dtype = out_arr [ 0 ][ 0 ] .dtype
264+ o_dtype = scipy_linalg . solve (
265+ np . ones (( 1 , 1 ), dtype = A . dtype ),
266+ np . ones (( 1 ,), dtype = b . dtype ),
267+ ) .dtype
267268 x = tensor (dtype = o_dtype , shape = b .type .shape )
268269 return Apply (self , [A , b ], [x ])
269270
@@ -315,7 +316,7 @@ def _default_b_ndim(b, b_ndim):
315316
316317 b = as_tensor_variable (b )
317318 if b_ndim is None :
318- return min (b .ndim , 2 ) # By default assume the core case is a matrix
319+ return min (b .ndim , 2 ) # By default, assume the core case is a matrix
319320
320321
321322class CholeskySolve (SolveBase ):
@@ -499,8 +500,33 @@ class Solve(SolveBase):
499500 )
500501
def __init__(self, *, assume_a: str = "gen", **kwargs):
    """Validate and normalize ``assume_a``, then initialize the base Op.

    Parameters
    ----------
    assume_a : str, default "gen"
        Assumed structure of the ``A`` matrix. Matching is case-insensitive.
        Accepted values are ``"gen"``, ``"sym"``, ``"her"``, ``"pos"``,
        ``"tridiagonal"`` and ``"banded"``, plus the long aliases
        ``"general"``, ``"symmetric"``, ``"hermitian"`` and
        ``"positive definite"``, which are stored under their short names.
        Triangular and diagonal structures are handled outside of ``Solve``
        and are rejected here.
    **kwargs
        Forwarded unchanged to the parent class ``__init__``.

    Raises
    ------
    ValueError
        If ``assume_a`` is not one of the accepted values.

    Warns
    -----
    UserWarning
        If ``assume_a`` is ``"tridiagonal"`` or ``"banded"`` and the
        installed scipy is older than 1.15; ``assume_a`` then falls back
        to ``"gen"``.
    """
    # Triangular and diagonal are handled outside of Solve
    valid_options = ["gen", "sym", "her", "pos", "tridiagonal", "banded"]

    # We use the old (short) names as the different dispatches are more
    # likely to support them
    long_to_short = {
        "general": "gen",
        "symmetric": "sym",
        "hermitian": "her",
        "positive definite": "pos",
    }

    assume_a = assume_a.lower()
    assume_a = long_to_short.get(assume_a, assume_a)

    if assume_a not in valid_options:
        raise ValueError(
            f"Invalid assume_a: {assume_a}. It must be one of {valid_options} or {list(long_to_short)}"
        )

    if assume_a in ("tridiagonal", "banded"):
        from scipy import __version__ as sp_version

        # Compare only (major, minor). Taking the first two components keeps
        # pre-release/dev suffixes on later components out of the int() parse
        # and does not drop the minor number of a two-component version.
        major_minor = tuple(int(part) for part in sp_version.split(".")[:2])
        if major_minor < (1, 15):
            warnings.warn(
                f"assume_a={assume_a} requires scipy>=1.15.0. Defaulting to assume_a='gen'.",
                UserWarning,
            )
            assume_a = "gen"

    super().__init__(**kwargs)
    self.assume_a = assume_a
@@ -536,10 +562,12 @@ def solve(
536562 a ,
537563 b ,
538564 * ,
539- assume_a = "gen" ,
540- lower = False ,
541- transposed = False ,
542- check_finite = True ,
565+ lower : bool = False ,
566+ overwrite_a : bool = False ,
567+ overwrite_b : bool = False ,
568+ check_finite : bool = True ,
569+ assume_a : str = "gen" ,
570+ transposed : bool = False ,
543571 b_ndim : int | None = None ,
544572):
545573 """Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
@@ -548,14 +576,19 @@ def solve(
548576 corresponding string to ``assume_a`` key chooses the dedicated solver.
549577 The available options are
550578
551- =================== ========
552- generic matrix 'gen'
553- symmetric 'sym'
554- hermitian 'her'
555- positive definite 'pos'
556- =================== ========
579+ =================== ================================
580+ diagonal 'diagonal'
581+ tridiagonal 'tridiagonal'
582+ banded 'banded'
583+ upper triangular 'upper triangular'
584+ lower triangular 'lower triangular'
585+ symmetric 'symmetric' (or 'sym')
586+ hermitian 'hermitian' (or 'her')
587+ positive definite 'positive definite' (or 'pos')
588+ general 'general' (or 'gen')
589+ =================== ================================
557590
558- If omitted, ``'gen '`` is the default structure.
591+ If omitted, ``'general '`` is the default structure.
559592
560593 The datatype of the arrays define which solver is called regardless
561594 of the values. In other words, even when the complex array entries have
@@ -568,23 +601,52 @@ def solve(
568601 Square input data
569602 b : (..., N, NRHS) array_like
570603 Input data for the right hand side.
571- lower : bool, optional
572- If True, use only the data contained in the lower triangle of `a`. Default
573- is to use upper triangle. (ignored for ``'gen'``)
574- transposed: bool, optional
575- If True, solves the system A^T x = b. Default is False.
604+ lower : bool, default False
605+ Ignored unless ``assume_a`` is one of ``'sym'``, ``'her'``, or ``'pos'``.
606+ If True, the calculation uses only the data in the lower triangle of `a`;
607+ entries above the diagonal are ignored. If False (default), the
608+ calculation uses only the data in the upper triangle of `a`; entries
609+ below the diagonal are ignored.
610+ overwrite_a : bool
611+ Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
612+ overwrite_b : bool
613+ Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
576614 check_finite : bool, optional
577615 Whether to check that the input matrices contain only finite numbers.
578616 Disabling may give a performance gain, but may result in problems
579617 (crashes, non-termination) if the inputs do contain infinities or NaNs.
580618 assume_a : str, optional
581619 Valid entries are explained above.
620+ transposed: bool, default False
621+ If True, solves the system A^T x = b. Default is False.
582622 b_ndim : int
583623 Whether the core case of b is a vector (1) or matrix (2).
584624 This will influence how batched dimensions are interpreted.
625+ By default, b_ndim is min(b.ndim, 2), i.e. b is treated as a matrix whenever it has more than one dimension.
585626 """
627+ assume_a = assume_a .lower ()
628+
629+ if assume_a in ("lower triangular" , "upper triangular" ):
630+ lower = "lower" in assume_a
631+ return solve_triangular (
632+ a ,
633+ b ,
634+ lower = lower ,
635+ trans = transposed ,
636+ check_finite = check_finite ,
637+ b_ndim = b_ndim ,
638+ )
639+
586640 b_ndim = _default_b_ndim (b , b_ndim )
587641
642+ if assume_a == "diagonal" :
643+ a_diagonal = diagonal (a , axis1 = - 2 , axis2 = - 1 )
644+ b_transposed = b [None , :] if b_ndim == 1 else b .mT
645+ x = (b_transposed / pt .expand_dims (a_diagonal , - 2 )).mT
646+ if b_ndim == 1 :
647+ x = x .squeeze (- 1 )
648+ return x
649+
588650 if transposed :
589651 a = a .mT
590652 lower = not lower
0 commit comments