Skip to content

Commit 86e6a7e

Browse files
committed
Enable new assume_a in Solve
1 parent 9a4da3f commit 86e6a7e

File tree

4 files changed

+151
-35
lines changed

4 files changed

+151
-35
lines changed

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
import jax
24

35
from pytensor.link.jax.dispatch.basic import jax_funcify
@@ -39,13 +41,29 @@ def cholesky(a, lower=lower):
3941

4042
@jax_funcify.register(Solve)
4143
def jax_funcify_Solve(op, **kwargs):
42-
if op.assume_a != "gen" and op.lower:
43-
lower = True
44+
assume_a = op.assume_a
45+
lower = op.lower
46+
47+
if assume_a == "tridiagonal":
48+
# jax.scipy.solve does not yet support tridiagonal matrices
49+
# But there's a jax.lax.linalg.tridiaonal_solve we can use instead.
50+
def solve(a, b):
51+
dl = jax.numpy.diagonal(a, offset=-1, axis1=-2, axis2=-1)
52+
d = jax.numpy.diagonal(a, offset=0, axis1=-2, axis2=-1)
53+
du = jax.numpy.diagonal(a, offset=1, axis1=-2, axis2=-1)
54+
return jax.lax.linalg.tridiagonal_solve(dl, d, du, b, lower=lower)
55+
4456
else:
45-
lower = False
57+
if assume_a not in ("gen", "sym", "her", "pos"):
58+
warnings.warn(
59+
f"JAX solve does not support assume_a={op.assume_a}. Defaulting to assume_a='gen'.\n"
60+
f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', 'her' or 'tridiagonal' to improve performance.",
61+
UserWarning,
62+
)
63+
assume_a = "gen"
4664

47-
def solve(a, b, lower=lower):
48-
return jax.scipy.linalg.solve(a, b, lower=lower)
65+
def solve(a, b):
66+
return jax.scipy.linalg.solve(a, b, lower=lower, assume_a=assume_a)
4967

5068
return solve
5169

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from collections.abc import Callable
23

34
import numba
@@ -1070,14 +1071,17 @@ def numba_funcify_Solve(op, node, **kwargs):
10701071
elif assume_a == "sym":
10711072
solve_fn = _solve_symmetric
10721073
elif assume_a == "her":
1073-
raise NotImplementedError(
1074-
'Use assume_a = "sym" for symmetric real matrices. If you need compelx support, '
1075-
"please open an issue on github."
1076-
)
1074+
# We already ruled out complex inputs
1075+
solve_fn = _solve_symmetric
10771076
elif assume_a == "pos":
10781077
solve_fn = _solve_psd
10791078
else:
1080-
raise NotImplementedError(f"Assumption {assume_a} not supported in Numba mode")
1079+
warnings.warn(
1080+
f"Numba assume_a={assume_a} not implemented. Falling back to general solve.\n"
1081+
f"If appropriate, you may want to set assume_a to one of 'sym', 'pos', or 'her' to improve performance.",
1082+
UserWarning,
1083+
)
1084+
solve_fn = _solve_gen
10811085

10821086
@numba_basic.numba_njit(inline="always")
10831087
def solve(a, b):

pytensor/tensor/slinalg.py

Lines changed: 85 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pytensor.tensor import TensorLike, as_tensor_variable
1616
from pytensor.tensor import basic as ptb
1717
from pytensor.tensor import math as ptm
18+
from pytensor.tensor.basic import diagonal
1819
from pytensor.tensor.blockwise import Blockwise
1920
from pytensor.tensor.nlinalg import kron, matrix_dot
2021
from pytensor.tensor.shape import reshape
@@ -260,10 +261,10 @@ def make_node(self, A, b):
260261
raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.")
261262

262263
# Infer dtype by solving the most simple case with 1x1 matrices
263-
inp_arr = [np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)]
264-
out_arr = [[None]]
265-
self.perform(None, inp_arr, out_arr)
266-
o_dtype = out_arr[0][0].dtype
264+
o_dtype = scipy_linalg.solve(
265+
np.ones((1, 1), dtype=A.dtype),
266+
np.ones((1,), dtype=b.dtype),
267+
).dtype
267268
x = tensor(dtype=o_dtype, shape=b.type.shape)
268269
return Apply(self, [A, b], [x])
269270

@@ -315,7 +316,7 @@ def _default_b_ndim(b, b_ndim):
315316

316317
b = as_tensor_variable(b)
317318
if b_ndim is None:
318-
return min(b.ndim, 2) # By default assume the core case is a matrix
319+
return min(b.ndim, 2) # By default, assume the core case is a matrix
319320

320321

321322
class CholeskySolve(SolveBase):
@@ -499,8 +500,33 @@ class Solve(SolveBase):
499500
)
500501

501502
def __init__(self, *, assume_a="gen", **kwargs):
502-
if assume_a not in ("gen", "sym", "her", "pos"):
503-
raise ValueError(f"{assume_a} is not a recognized matrix structure")
503+
# Triangular and diagonal are handled outside of Solve
504+
valid_options = ["gen", "sym", "her", "pos", "tridiagonal", "banded"]
505+
506+
assume_a = assume_a.lower()
507+
# We use the old names as the different dispatches are more likely to support them
508+
long_to_short = {
509+
"general": "gen",
510+
"symmetric": "sym",
511+
"hermitian": "her",
512+
"positive definite": "pos",
513+
}
514+
assume_a = long_to_short.get(assume_a, assume_a)
515+
516+
if assume_a not in valid_options:
517+
raise ValueError(
518+
f"Invalid assume_a: {assume_a}. It must be one of {valid_options} or {list(long_to_short.keys())}"
519+
)
520+
521+
if assume_a in ("tridiagonal", "banded"):
522+
from scipy import __version__ as sp_version
523+
524+
if tuple(map(int, sp_version.split(".")[:-1])) < (1, 15):
525+
warnings.warn(
526+
f"assume_a={assume_a} requires scipy>=1.5.0. Defaulting to assume_a='gen'.",
527+
UserWarning,
528+
)
529+
assume_a = "gen"
504530

505531
super().__init__(**kwargs)
506532
self.assume_a = assume_a
@@ -536,10 +562,12 @@ def solve(
536562
a,
537563
b,
538564
*,
539-
assume_a="gen",
540-
lower=False,
541-
transposed=False,
542-
check_finite=True,
565+
lower: bool = False,
566+
overwrite_a: bool = False,
567+
overwrite_b: bool = False,
568+
check_finite: bool = True,
569+
assume_a: str = "gen",
570+
transposed: bool = False,
543571
b_ndim: int | None = None,
544572
):
545573
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
@@ -548,14 +576,19 @@ def solve(
548576
corresponding string to ``assume_a`` key chooses the dedicated solver.
549577
The available options are
550578
551-
=================== ========
552-
generic matrix 'gen'
553-
symmetric 'sym'
554-
hermitian 'her'
555-
positive definite 'pos'
556-
=================== ========
579+
=================== ================================
580+
diagonal 'diagonal'
581+
tridiagonal 'tridiagonal'
582+
banded 'banded'
583+
upper triangular 'upper triangular'
584+
lower triangular 'lower triangular'
585+
symmetric 'symmetric' (or 'sym')
586+
hermitian 'hermitian' (or 'her')
587+
positive definite 'positive definite' (or 'pos')
588+
general 'general' (or 'gen')
589+
=================== ================================
557590
558-
If omitted, ``'gen'`` is the default structure.
591+
If omitted, ``'general'`` is the default structure.
559592
560593
The datatype of the arrays define which solver is called regardless
561594
of the values. In other words, even when the complex array entries have
@@ -568,23 +601,52 @@ def solve(
568601
Square input data
569602
b : (..., N, NRHS) array_like
570603
Input data for the right hand side.
571-
lower : bool, optional
572-
If True, use only the data contained in the lower triangle of `a`. Default
573-
is to use upper triangle. (ignored for ``'gen'``)
574-
transposed: bool, optional
575-
If True, solves the system A^T x = b. Default is False.
604+
lower : bool, default False
605+
Ignored unless ``assume_a`` is one of ``'sym'``, ``'her'``, or ``'pos'``.
606+
If True, the calculation uses only the data in the lower triangle of `a`;
607+
entries above the diagonal are ignored. If False (default), the
608+
calculation uses only the data in the upper triangle of `a`; entries
609+
below the diagonal are ignored.
610+
overwrite_a : bool
611+
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
612+
overwrite_b : bool
613+
Unused by PyTensor. PyTensor will always perform the operation in-place if possible.
576614
check_finite : bool, optional
577615
Whether to check that the input matrices contain only finite numbers.
578616
Disabling may give a performance gain, but may result in problems
579617
(crashes, non-termination) if the inputs do contain infinities or NaNs.
580618
assume_a : str, optional
581619
Valid entries are explained above.
620+
transposed: bool, default False
621+
If True, solves the system A^T x = b. Default is False.
582622
b_ndim : int
583623
Whether the core case of b is a vector (1) or matrix (2).
584624
This will influence how batched dimensions are interpreted.
625+
By default, we assume b_ndim = b.ndim is 2 if b.ndim > 1, else 1.
585626
"""
627+
assume_a = assume_a.lower()
628+
629+
if assume_a in ("lower triangular", "upper triangular"):
630+
lower = "lower" in assume_a
631+
return solve_triangular(
632+
a,
633+
b,
634+
lower=lower,
635+
trans=transposed,
636+
check_finite=check_finite,
637+
b_ndim=b_ndim,
638+
)
639+
586640
b_ndim = _default_b_ndim(b, b_ndim)
587641

642+
if assume_a == "diagonal":
643+
a_diagonal = diagonal(a, axis1=-2, axis2=-1)
644+
b_transposed = b[None, :] if b_ndim == 1 else b.mT
645+
x = (b_transposed / pt.expand_dims(a_diagonal, -2)).mT
646+
if b_ndim == 1:
647+
x = x.squeeze(-1)
648+
return x
649+
588650
if transposed:
589651
a = a.mT
590652
lower = not lower

tests/tensor/test_slinalg.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from pytensor import function, grad
1111
from pytensor import tensor as pt
1212
from pytensor.configdefaults import config
13+
from pytensor.graph.basic import equal_computations
14+
from pytensor.tensor import TensorVariable
1315
from pytensor.tensor.slinalg import (
1416
Cholesky,
1517
CholeskySolve,
@@ -209,8 +211,8 @@ def test__repr__(self):
209211
)
210212

211213

212-
def test_solve_raises_on_invalid_A():
213-
with pytest.raises(ValueError, match="is not a recognized matrix structure"):
214+
def test_solve_raises_on_invalid_assume_a():
215+
with pytest.raises(ValueError, match="Invalid assume_a: test. It must be one of"):
214216
Solve(assume_a="test", b_ndim=2)
215217

216218

@@ -223,6 +225,10 @@ def test_solve_raises_on_invalid_A():
223225
("pos", False, False),
224226
("pos", True, False),
225227
("pos", True, True),
228+
("diagonal", False, False),
229+
("diagonal", False, True),
230+
("tridiagonal", False, False),
231+
("tridiagonal", False, True),
226232
]
227233
solve_test_ids = [
228234
f'{assume_a}_{"lower" if lower else "upper"}_{"A^T" if transposed else "A"}'
@@ -237,6 +243,16 @@ def A_func(x, assume_a):
237243
return x @ x.T
238244
elif assume_a == "sym":
239245
return (x + x.T) / 2
246+
elif assume_a == "diagonal":
247+
eye_fn = pt.eye if isinstance(x, TensorVariable) else np.eye
248+
return x * eye_fn(x.shape[1])
249+
elif assume_a == "tridiagonal":
250+
eye_fn = pt.eye if isinstance(x, TensorVariable) else np.eye
251+
return x * (
252+
eye_fn(x.shape[1], k=0)
253+
+ eye_fn(x.shape[1], k=-1)
254+
+ eye_fn(x.shape[1], k=1)
255+
)
240256
else:
241257
return x
242258

@@ -344,6 +360,22 @@ def test_solve_gradient(
344360
lambda A, b: solve_op(A_func(A), b), [A_val, b_val], 3, rng, eps=eps
345361
)
346362

363+
def test_solve_tringular_indirection(self):
364+
a = pt.matrix("a")
365+
b = pt.vector("b")
366+
367+
indirect = solve(a, b, assume_a="lower triangular")
368+
direct = solve_triangular(a, b, lower=True, trans=False)
369+
assert equal_computations([indirect], [direct])
370+
371+
indirect = solve(a, b, assume_a="upper triangular")
372+
direct = solve_triangular(a, b, lower=False, trans=False)
373+
assert equal_computations([indirect], [direct])
374+
375+
indirect = solve(a, b, assume_a="upper triangular", transposed=True)
376+
direct = solve_triangular(a, b, lower=False, trans=True)
377+
assert equal_computations([indirect], [direct])
378+
347379

348380
class TestSolveTriangular(utt.InferShapeTester):
349381
@staticmethod

0 commit comments

Comments
 (0)