Skip to content

Commit 15fb803

Browse files
committed
Enable new assume_a in Solve
1 parent 7ffaae7 commit 15fb803

File tree

3 files changed

+129
-33
lines changed

3 files changed

+129
-33
lines changed

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
import jax
24

35
from pytensor.link.jax.dispatch.basic import jax_funcify
@@ -39,13 +41,28 @@ def cholesky(a, lower=lower):
3941

4042
@jax_funcify.register(Solve)
4143
def jax_funcify_Solve(op, **kwargs):
42-
if op.assume_a != "gen" and op.lower:
43-
lower = True
44+
assume_a = op.assume_a
45+
lower = op.lower
46+
47+
if assume_a == "tridiagonal":
48+
# jax.scipy.linalg.solve does not yet support tridiagonal matrices
49+
# But there's a jax.lax.linalg.tridiagonal_solve we can use instead.
50+
def solve(a, b):
51+
dl = jax.numpy.diagonal(a, offset=-1, axis1=-2, axis2=-1)
52+
d = jax.numpy.diagonal(a, offset=0, axis1=-2, axis2=-1)
53+
du = jax.numpy.diagonal(a, offset=1, axis1=-2, axis2=-1)
54+
return jax.lax.linalg.tridiagonal_solve(dl, d, du, b, lower=lower)
55+
4456
else:
45-
lower = False
57+
if assume_a not in ("gen", "sym", "her", "pos"):
58+
warnings.warn(
59+
f"JAX solve does not support assume_a={op.assume_a}. Defaulting to assume_a='gen'.",
60+
UserWarning,
61+
)
62+
assume_a = "gen"
4663

47-
def solve(a, b, lower=lower):
48-
return jax.scipy.linalg.solve(a, b, lower=lower)
64+
def solve(a, b):
65+
return jax.scipy.linalg.solve(a, b, lower=lower, assume_a=assume_a)
4966

5067
return solve
5168

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from collections.abc import Callable
23

34
import numba
@@ -1070,14 +1071,16 @@ def numba_funcify_Solve(op, node, **kwargs):
10701071
elif assume_a == "sym":
10711072
solve_fn = _solve_symmetric
10721073
elif assume_a == "her":
1073-
raise NotImplementedError(
1074-
'Use assume_a = "sym" for symmetric real matrices. If you need compelx support, '
1075-
"please open an issue on github."
1076-
)
1074+
# We already ruled out complex inputs
1075+
solve_fn = _solve_symmetric
10771076
elif assume_a == "pos":
10781077
solve_fn = _solve_psd
10791078
else:
1080-
raise NotImplementedError(f"Assumption {assume_a} not supported in Numba mode")
1079+
warnings.warn(
1080+
f"Numba assume_a={assume_a} not implemented. Falling back to general solve.",
1081+
UserWarning,
1082+
)
1083+
solve_fn = _solve_gen
10811084

10821085
@numba_basic.numba_njit(inline="always")
10831086
def solve(a, b):

pytensor/tensor/slinalg.py

Lines changed: 99 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
from typing import Literal, cast
66

77
import numpy as np
8+
import scipy
89
import scipy.linalg as scipy_linalg
910
from numpy.exceptions import ComplexWarning
11+
from packaging.version import parse as parse_version
1012

1113
import pytensor
1214
import pytensor.tensor as pt
@@ -15,6 +17,7 @@
1517
from pytensor.tensor import TensorLike, as_tensor_variable
1618
from pytensor.tensor import basic as ptb
1719
from pytensor.tensor import math as ptm
20+
from pytensor.tensor.basic import diagonal
1821
from pytensor.tensor.blockwise import Blockwise
1922
from pytensor.tensor.nlinalg import kron, matrix_dot
2023
from pytensor.tensor.shape import reshape
@@ -260,10 +263,10 @@ def make_node(self, A, b):
260263
raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.")
261264

262265
# Infer dtype by solving the most simple case with 1x1 matrices
263-
inp_arr = [np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)]
264-
out_arr = [[None]]
265-
self.perform(None, inp_arr, out_arr)
266-
o_dtype = out_arr[0][0].dtype
266+
o_dtype = scipy_linalg.solve(
267+
np.eye(1).astype(A.dtype),
268+
np.eye(1).astype(b.dtype),
269+
).dtype
267270
x = tensor(dtype=o_dtype, shape=b.type.shape)
268271
return Apply(self, [A, b], [x])
269272

@@ -315,7 +318,7 @@ def _default_b_ndim(b, b_ndim):
315318

316319
b = as_tensor_variable(b)
317320
if b_ndim is None:
318-
return min(b.ndim, 2) # By default assume the core case is a matrix
321+
return min(b.ndim, 2) # By default, assume the core case is a matrix
319322

320323

321324
class CholeskySolve(SolveBase):
@@ -332,6 +335,19 @@ def __init__(self, **kwargs):
332335
kwargs.setdefault("lower", True)
333336
super().__init__(**kwargs)
334337

338+
def make_node(self, *inputs):
339+
# Allow base class to do input validation
340+
super_apply = super().make_node(*inputs)
341+
A, b = super_apply.inputs
342+
[super_out] = super_apply.outputs
343+
# The dtype of chol_solve does not match solve, which the base class checks
344+
dtype = scipy_linalg.cho_solve(
345+
np.eye(1).astype(A.dtype),
346+
np.eye(1).astype(b.dtype),
347+
).dtype
348+
out = tensor(dtype=dtype, shape=super_out.type.shape)
349+
return Apply(self, [A, b], [out])
350+
335351
def perform(self, node, inputs, output_storage):
336352
C, b = inputs
337353
rval = scipy_linalg.cho_solve(
@@ -499,8 +515,32 @@ class Solve(SolveBase):
499515
)
500516

501517
def __init__(self, *, assume_a="gen", **kwargs):
502-
if assume_a not in ("gen", "sym", "her", "pos"):
503-
raise ValueError(f"{assume_a} is not a recognized matrix structure")
518+
# Triangular and diagonal are handled outside of Solve
519+
valid_options = ["gen", "sym", "her", "pos", "tridiagonal", "banded"]
520+
521+
assume_a = assume_a.lower()
522+
# We use the old names as the different dispatches are more likely to support them
523+
if assume_a == "general":
524+
assume_a = "gen"
525+
elif assume_a == "symmetric":
526+
assume_a = "sym"
527+
elif assume_a == "hermitian":
528+
assume_a = "her"
529+
elif assume_a == "positive definite":
530+
assume_a = "pos"
531+
if assume_a not in valid_options:
532+
raise ValueError(
533+
f"Invalid assume_a: {assume_a}. It must be one of {valid_options}"
534+
)
535+
536+
if assume_a in ("tridiagonal", "banded") and parse_version(
537+
scipy.__version__
538+
) < parse_version("1.15.0"):
539+
warnings.warn(
540+
f"assume_a={assume_a} requires scipy>=1.15.0. Defaulting to assume_a='gen'.",
541+
UserWarning,
542+
)
543+
assume_a = "gen"
504544

505545
super().__init__(**kwargs)
506546
self.assume_a = assume_a
@@ -536,10 +576,12 @@ def solve(
536576
a,
537577
b,
538578
*,
539-
assume_a="gen",
540-
lower=False,
541-
transposed=False,
542-
check_finite=True,
579+
lower: bool = False,
580+
overwrite_a: bool = False,
581+
overwrite_b: bool = False,
582+
check_finite: bool = True,
583+
assume_a: str = "gen",
584+
transposed: bool = False,
543585
b_ndim: int | None = None,
544586
):
545587
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
@@ -548,14 +590,19 @@ def solve(
548590
corresponding string to ``assume_a`` key chooses the dedicated solver.
549591
The available options are
550592
551-
=================== ========
552-
generic matrix 'gen'
553-
symmetric 'sym'
554-
hermitian 'her'
555-
positive definite 'pos'
556-
=================== ========
593+
=================== ================================
594+
diagonal 'diagonal'
595+
tridiagonal 'tridiagonal'
596+
banded 'banded'
597+
upper triangular 'upper triangular'
598+
lower triangular 'lower triangular'
599+
symmetric 'symmetric' (or 'sym')
600+
hermitian 'hermitian' (or 'her')
601+
positive definite 'positive definite' (or 'pos')
602+
general 'general' (or 'gen')
603+
=================== ================================
557604
558-
If omitted, ``'gen'`` is the default structure.
605+
If omitted, ``'general'`` is the default structure.
559606
560607
The datatype of the arrays define which solver is called regardless
561608
of the values. In other words, even when the complex array entries have
@@ -568,23 +615,52 @@ def solve(
568615
Square input data
569616
b : (..., N, NRHS) array_like
570617
Input data for the right hand side.
571-
lower : bool, optional
572-
If True, use only the data contained in the lower triangle of `a`. Default
573-
is to use upper triangle. (ignored for ``'gen'``)
574-
transposed: bool, optional
575-
If True, solves the system A^T x = b. Default is False.
618+
lower : bool, default False
619+
Ignored unless ``assume_a`` is one of ``'sym'``, ``'her'``, or ``'pos'``.
620+
If True, the calculation uses only the data in the lower triangle of `a`;
621+
entries above the diagonal are ignored. If False (default), the
622+
calculation uses only the data in the upper triangle of `a`; entries
623+
below the diagonal are ignored.
624+
overwrite_a : bool
625+
Ignored argument. PyTensor will perform the operation in-place if possible.
626+
overwrite_b : bool
627+
Ignored argument. PyTensor will perform the operation in-place if possible.
576628
check_finite : bool, optional
577629
Whether to check that the input matrices contain only finite numbers.
578630
Disabling may give a performance gain, but may result in problems
579631
(crashes, non-termination) if the inputs do contain infinities or NaNs.
580632
assume_a : str, optional
581633
Valid entries are explained above.
634+
transposed: bool, default False
635+
If True, solves the system A^T x = b. Default is False.
582636
b_ndim : int
583637
Whether the core case of b is a vector (1) or matrix (2).
584638
This will influence how batched dimensions are interpreted.
639+
By default, b_ndim is 2 if b.ndim > 1, otherwise 1 (i.e. ``min(b.ndim, 2)``).
585640
"""
641+
assume_a = assume_a.lower()
642+
643+
if assume_a in ("lower triangular", "upper triangular"):
644+
lower = "lower" in assume_a
645+
return solve_triangular(
646+
a,
647+
b,
648+
lower=lower,
649+
trans=transposed,
650+
check_finite=check_finite,
651+
b_ndim=b_ndim,
652+
)
653+
586654
b_ndim = _default_b_ndim(b, b_ndim)
587655

656+
if assume_a == "diagonal":
657+
a_diagonal = diagonal(a, axis1=-2, axis2=-1)
658+
b_transposed = b[None, :] if b_ndim == 1 else b.mT
659+
x = (b_transposed / pt.expand_dims(a_diagonal, -2)).mT
660+
if b_ndim == 1:
661+
x = x.squeeze(-1)
662+
return x
663+
588664
if transposed:
589665
a = a.mT
590666
lower = not lower

0 commit comments

Comments
 (0)