55from typing import Literal , cast
66
77import numpy as np
8+ import scipy
89import scipy .linalg as scipy_linalg
910from numpy .exceptions import ComplexWarning
11+ from packaging .version import parse as parse_version
1012
1113import pytensor
1214import pytensor .tensor as pt
1517from pytensor .tensor import TensorLike , as_tensor_variable
1618from pytensor .tensor import basic as ptb
1719from pytensor .tensor import math as ptm
20+ from pytensor .tensor .basic import diagonal
1821from pytensor .tensor .blockwise import Blockwise
1922from pytensor .tensor .nlinalg import kron , matrix_dot
2023from pytensor .tensor .shape import reshape
@@ -260,10 +263,10 @@ def make_node(self, A, b):
260263 raise ValueError (f"`b` must have { self .b_ndim } dims; got { b .type } instead." )
261264
262265 # Infer dtype by solving the most simple case with 1x1 matrices
263- inp_arr = [ np . eye ( 1 ). astype ( A . dtype ), np . eye ( 1 ). astype ( b . dtype )]
264- out_arr = [[ None ]]
265- self . perform ( None , inp_arr , out_arr )
266- o_dtype = out_arr [ 0 ][ 0 ] .dtype
266+ o_dtype = scipy_linalg . solve (
267+ np . eye ( 1 ). astype ( A . dtype ),
268+ np . eye ( 1 ). astype ( b . dtype ),
269+ ) .dtype
267270 x = tensor (dtype = o_dtype , shape = b .type .shape )
268271 return Apply (self , [A , b ], [x ])
269272
@@ -315,7 +318,7 @@ def _default_b_ndim(b, b_ndim):
315318
316319 b = as_tensor_variable (b )
317320 if b_ndim is None :
318- return min (b .ndim , 2 ) # By default assume the core case is a matrix
321+ return min (b .ndim , 2 ) # By default, assume the core case is a matrix
319322
320323
321324class CholeskySolve (SolveBase ):
@@ -332,6 +335,19 @@ def __init__(self, **kwargs):
332335 kwargs .setdefault ("lower" , True )
333336 super ().__init__ (** kwargs )
334337
def make_node(self, *inputs):
    """Build the Apply node for a Cholesky solve.

    Input validation is delegated to the base class ``make_node``; only the
    output dtype is recomputed here, by running ``scipy.linalg.cho_solve``
    on trivial 1x1 inputs cast to the input dtypes.

    Parameters
    ----------
    *inputs
        Forwarded unchanged to the base class ``make_node``.

    Returns
    -------
    Apply
        Node over the validated inputs with a single output that keeps the
        base-class output shape but uses the dtype produced by ``cho_solve``.
    """
    # Allow base class to do input validation
    super_apply = super().make_node(*inputs)
    A, b = super_apply.inputs
    [super_out] = super_apply.outputs
    # The dtype of chol_solve does not match solve, which the base class checks
    # NOTE: cho_solve takes the factor as a ``(c, lower)`` tuple, not a bare
    # array; passing the array alone fails when scipy unpacks it. The value of
    # the ``lower`` flag is irrelevant for a 1x1 identity used only to probe
    # the result dtype.
    dtype = scipy_linalg.cho_solve(
        (np.eye(1).astype(A.dtype), True),
        np.eye(1).astype(b.dtype),
    ).dtype
    out = tensor(dtype=dtype, shape=super_out.type.shape)
    return Apply(self, [A, b], [out])
350+
335351 def perform (self , node , inputs , output_storage ):
336352 C , b = inputs
337353 rval = scipy_linalg .cho_solve (
@@ -499,8 +515,32 @@ class Solve(SolveBase):
499515 )
500516
def __init__(self, *, assume_a="gen", **kwargs):
    """Validate and normalize ``assume_a``, then initialize the base class.

    Parameters
    ----------
    assume_a : str
        Matrix structure hint (case-insensitive). Accepted values are
        ``"gen"``, ``"sym"``, ``"her"``, ``"pos"``, ``"tridiagonal"`` and
        ``"banded"``, plus the long spellings ``"general"``,
        ``"symmetric"``, ``"hermitian"`` and ``"positive definite"``, which
        are mapped to their short forms.
    **kwargs
        Forwarded to the base class ``__init__``.

    Raises
    ------
    ValueError
        If ``assume_a`` is not one of the accepted values.

    Warns
    -----
    UserWarning
        If ``"tridiagonal"`` or ``"banded"`` is requested but the installed
        scipy is older than 1.15.0; ``assume_a`` then falls back to ``"gen"``.
    """
    # Triangular and diagonal are handled outside of Solve
    valid_options = ["gen", "sym", "her", "pos", "tridiagonal", "banded"]
    # We use the old names as the different dispatches are more likely to support them
    long_to_short = {
        "general": "gen",
        "symmetric": "sym",
        "hermitian": "her",
        "positive definite": "pos",
    }

    assume_a = assume_a.lower()
    assume_a = long_to_short.get(assume_a, assume_a)
    if assume_a not in valid_options:
        raise ValueError(
            f"Invalid assume_a: {assume_a}. It must be one of {valid_options}"
        )

    # Single source of truth for the version, so the check and the message
    # cannot drift apart (the message previously said 1.5.0).
    min_scipy_version = "1.15.0"
    if assume_a in ("tridiagonal", "banded") and parse_version(
        scipy.__version__
    ) < parse_version(min_scipy_version):
        warnings.warn(
            f"assume_a={assume_a} requires scipy>={min_scipy_version}. "
            "Defaulting to assume_a='gen'.",
            UserWarning,
        )
        assume_a = "gen"

    super().__init__(**kwargs)
    self.assume_a = assume_a
@@ -536,10 +576,12 @@ def solve(
536576 a ,
537577 b ,
538578 * ,
539- assume_a = "gen" ,
540- lower = False ,
541- transposed = False ,
542- check_finite = True ,
579+ lower : bool = False ,
580+ overwrite_a : bool = False ,
581+ overwrite_b : bool = False ,
582+ check_finite : bool = True ,
583+ assume_a : str = "gen" ,
584+ transposed : bool = False ,
543585 b_ndim : int | None = None ,
544586):
545587 """Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
@@ -548,14 +590,19 @@ def solve(
548590 corresponding string to ``assume_a`` key chooses the dedicated solver.
549591 The available options are
550592
551- =================== ========
552- generic matrix 'gen'
553- symmetric 'sym'
554- hermitian 'her'
555- positive definite 'pos'
556- =================== ========
593+ =================== ================================
594+ diagonal 'diagonal'
595+ tridiagonal 'tridiagonal'
596+ banded 'banded'
597+ upper triangular 'upper triangular'
598+ lower triangular 'lower triangular'
599+ symmetric 'symmetric' (or 'sym')
600+ hermitian 'hermitian' (or 'her')
601+ positive definite 'positive definite' (or 'pos')
602+ general 'general' (or 'gen')
603+ =================== ================================
557604
558- If omitted, ``'gen '`` is the default structure.
605+ If omitted, ``'general '`` is the default structure.
559606
560607 The datatype of the arrays define which solver is called regardless
561608 of the values. In other words, even when the complex array entries have
@@ -568,23 +615,52 @@ def solve(
568615 Square input data
569616 b : (..., N, NRHS) array_like
570617 Input data for the right hand side.
571- lower : bool, optional
572- If True, use only the data contained in the lower triangle of `a`. Default
573- is to use upper triangle. (ignored for ``'gen'``)
574- transposed: bool, optional
575- If True, solves the system A^T x = b. Default is False.
618+ lower : bool, default False
619+ Ignored unless ``assume_a`` is one of ``'sym'``, ``'her'``, or ``'pos'``.
620+ If True, the calculation uses only the data in the lower triangle of `a`;
621+ entries above the diagonal are ignored. If False (default), the
622+ calculation uses only the data in the upper triangle of `a`; entries
623+ below the diagonal are ignored.
624+ overwrite_a : bool
625+ Ignored argument. PyTensor will perform the operation in-place if possible.
626+ overwrite_b : bool
627+ Ignored argument. PyTensor will perform the operation in-place if possible.
576628 check_finite : bool, optional
577629 Whether to check that the input matrices contain only finite numbers.
578630 Disabling may give a performance gain, but may result in problems
579631 (crashes, non-termination) if the inputs do contain infinities or NaNs.
580632 assume_a : str, optional
581633 Valid entries are explained above.
634+ transposed: bool, default False
635+ If True, solves the system A^T x = b. Default is False.
582636 b_ndim : int
583637 Whether the core case of b is a vector (1) or matrix (2).
584638 This will influence how batched dimensions are interpreted.
639+ By default, b_ndim is 2 if b.ndim > 1, else 1.
585640 """
641+ assume_a = assume_a .lower ()
642+
643+ if assume_a in ("lower triangular" , "upper triangular" ):
644+ lower = "lower" in assume_a
645+ return solve_triangular (
646+ a ,
647+ b ,
648+ lower = lower ,
649+ trans = transposed ,
650+ check_finite = check_finite ,
651+ b_ndim = b_ndim ,
652+ )
653+
586654 b_ndim = _default_b_ndim (b , b_ndim )
587655
656+ if assume_a == "diagonal" :
657+ a_diagonal = diagonal (a , axis1 = - 2 , axis2 = - 1 )
658+ b_transposed = b [None , :] if b_ndim == 1 else b .mT
659+ x = (b_transposed / pt .expand_dims (a_diagonal , - 2 )).mT
660+ if b_ndim == 1 :
661+ x = x .squeeze (- 1 )
662+ return x
663+
588664 if transposed :
589665 a = a .mT
590666 lower = not lower
0 commit comments