
Commit 6a6e54a

Infer dtype from node.outputs.type.dtype
Parent: 1c222c3

4 files changed: +32 additions, −20 deletions

pytensor/tensor/rewriting/linalg.py

Lines changed: 5 additions & 2 deletions

@@ -43,11 +43,11 @@
     register_stabilize,
 )
 from pytensor.tensor.slinalg import (
+    BilinearSolveDiscreteLyapunov,
     BlockDiagonal,
     Cholesky,
     Solve,
     SolveBase,
-    _solve_bilinear_discrete_lyapunov,
     block_diag,
     cholesky,
     solve,

@@ -972,11 +972,14 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
     return [eye_input * (non_eye_input**0.5)]


-@node_rewriter([_solve_bilinear_discrete_lyapunov])  # type: ignore
+@node_rewriter([Blockwise])  # type: ignore
 def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
     """
     Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX
     """
+    if not isinstance(node.op.core_op, BilinearSolveDiscreteLyapunov):
+        return None
+
     A, B = (cast(TensorVariable, x) for x in node.inputs)
     result = solve_discrete_lyapunov(A, B, method="direct")
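Note: the rewrite above now has to register on the generic Blockwise op, because the bilinear solver graph is built by wrapping the core op in Blockwise at call time and there is no longer a module-level callable to match on. A small sketch of what that wrapping looks like from the user side (illustrative only, not part of this commit):

import pytensor.tensor as pt
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import BilinearSolveDiscreteLyapunov, solve_discrete_lyapunov

A = pt.matrix("A")
Q = pt.matrix("Q")
X = solve_discrete_lyapunov(A, Q, method="bilinear")

node = X.owner
print(isinstance(node.op, Blockwise))                              # True
print(isinstance(node.op.core_op, BilinearSolveDiscreteLyapunov))  # True: the check the rewriter performs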

pytensor/tensor/slinalg.py

Lines changed: 7 additions & 12 deletions

@@ -797,7 +797,8 @@ def make_node(self, A, B):
     def perform(self, node, inputs, output_storage):
         (A, B) = inputs
         X = output_storage[0]
-        out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
+        out_dtype = node.outputs[0].type.dtype
+
         X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)

     def infer_shape(self, fgraph, node, shapes):

@@ -843,9 +844,9 @@ def perform(self, node, inputs, output_storage):
         (A, B) = inputs
         X = output_storage[0]

-        dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
+        out_dtype = node.outputs[0].type.dtype
         X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
-            dtype
+            out_dtype
         )

     def infer_shape(self, fgraph, node, shapes):

@@ -868,12 +869,6 @@ def grad(self, inputs, output_grads):
         return [A_bar, Q_bar]


-_solve_continuous_lyapunov = Blockwise(SolveContinuousLyapunov())
-_solve_bilinear_discrete_lyapunov = cast(
-    typing.Callable, Blockwise(BilinearSolveDiscreteLyapunov())
-)
-
-
 def _direct_solve_discrete_lyapunov(
     A: TensorVariable, Q: TensorVariable
 ) -> TensorVariable:

@@ -937,7 +932,7 @@ def solve_discrete_lyapunov(
         return cast(TensorVariable, X)

     elif method == "bilinear":
-        return cast(TensorVariable, _solve_bilinear_discrete_lyapunov(A, Q))
+        return cast(TensorVariable, Blockwise(BilinearSolveDiscreteLyapunov())(A, Q))

     else:
         raise ValueError(f"Unknown method {method}")

@@ -961,10 +956,10 @@ def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:

     """

-    return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
+    return cast(TensorVariable, Blockwise(SolveContinuousLyapunov())(A, Q))


-class SolveDiscreteARE(pt.Op):
+class SolveDiscreteARE(Op):
     __props__ = ("enforce_Q_symmetric",)
     gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"
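The dtype change in perform() follows from the output dtype already being fixed when the node is created (make_node upcasts the input dtypes), so perform() can simply read it back instead of recomputing the upcast. A small sketch of that behavior, assuming the upcasting in make_node is unchanged by this commit:

import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve_continuous_lyapunov

A = pt.matrix("A", dtype="float32")
Q = pt.matrix("Q", dtype="float64")
X = solve_continuous_lyapunov(A, Q)
print(X.type.dtype)  # "float64": the value node.outputs[0].type.dtype reports inside perform()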

tests/link/jax/test_slinalg.py

Lines changed: 3 additions & 0 deletions

@@ -1,3 +1,4 @@
+from functools import partial
 from typing import Literal

 import numpy as np

@@ -208,11 +209,13 @@ def test_jax_solve_discrete_lyapunov(
     out = pt_slinalg.solve_discrete_lyapunov(A, B, method=method)
     out_fg = FunctionGraph([A, B], [out])

+    atol = rtol = 1e-8 if config.floatX == "float64" else 1e-3
     compare_jax_and_py(
         out_fg,
         [
             np.random.normal(size=shape).astype(config.floatX),
             np.random.normal(size=shape).astype(config.floatX),
         ],
         jax_mode="JAX",
+        assert_fn=partial(np.testing.assert_allclose, atol=atol, rtol=rtol),
     )
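The new assert_fn loosens the comparison tolerance when running at single precision, where the solve accumulates more rounding error. A stand-alone sketch of the scheme, with floatX hard-coded instead of read from pytensor.config:

from functools import partial
import numpy as np

floatX = "float32"  # stand-in for pytensor.config.floatX
atol = rtol = 1e-8 if floatX == "float64" else 1e-3
assert_fn = partial(np.testing.assert_allclose, atol=atol, rtol=rtol)
assert_fn(1.0005, 1.0)  # passes under the float32 tolerance; would fail under the float64 one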

tests/tensor/test_slinalg.py

Lines changed: 17 additions & 6 deletions

@@ -537,8 +537,8 @@ def test_solve_discrete_lyapunov(
         precision = int(dtype[-2:])  # 64 or 32
         dtype = f"complex{int(2 * precision)}"

-    A1, A2 = rng.normal(size=(2, *shape)).astype(dtype)
-    Q1, Q2 = rng.normal(size=(2, *shape)).astype(dtype)
+    A1, A2 = rng.normal(size=(2, *shape))
+    Q1, Q2 = rng.normal(size=(2, *shape))

     if use_complex:
         A = A1 + 1j * A2

@@ -547,6 +547,8 @@ def test_solve_discrete_lyapunov(
         A = A1
         Q = Q1

+    A, Q = A.astype(dtype), Q.astype(dtype)
+
     a = pt.tensor(name="a", shape=shape, dtype=dtype)
     q = pt.tensor(name="q", shape=shape, dtype=dtype)

@@ -585,15 +587,20 @@ def test_solve_discrete_lyapunov_gradient(
 @pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
 @pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
 def test_solve_continuous_lyapunov(shape: tuple[int], use_complex: bool):
+    dtype = config.floatX
+    if use_complex and dtype == "float32":
+        pytest.skip(
+            "Not enough precision in complex64 to do schur decomposition "
+            "(ill-conditioned matrix errors arise)"
+        )
     rng = np.random.default_rng(utt.fetch_seed())

-    dtype = config.floatX
     if use_complex:
         precision = int(dtype[-2:])  # 64 or 32
         dtype = f"complex{int(2 * precision)}"

-    A1, A2 = rng.normal(size=(2, *shape)).astype(dtype)
-    Q1, Q2 = rng.normal(size=(2, *shape)).astype(dtype)
+    A1, A2 = rng.normal(size=(2, *shape))
+    Q1, Q2 = rng.normal(size=(2, *shape))

     if use_complex:
         A = A1 + 1j * A2

@@ -602,9 +609,13 @@ def test_solve_continuous_lyapunov(shape: tuple[int], use_complex: bool):
         A = A1
         Q = Q1

+    A, Q = A.astype(dtype), Q.astype(dtype)
+
     a = pt.tensor(name="a", shape=shape, dtype=dtype)
     q = pt.tensor(name="q", shape=shape, dtype=dtype)
-    f = function([a, q], [solve_continuous_lyapunov(a, q)])
+    x = solve_continuous_lyapunov(a, q)
+
+    f = function([a, q], x)

     X = f(A, Q)
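Both tests now draw real normal samples, combine them into complex matrices where requested, and cast to the target dtype in a single place at the end. A compact sketch of that data setup, with illustrative constants in place of config.floatX and utt.fetch_seed():

import numpy as np

floatX = "float64"                 # stand-in for pytensor.config.floatX
precision = int(floatX[-2:])       # 64 or 32
dtype = f"complex{2 * precision}"  # complex128 or complex64

rng = np.random.default_rng(0)
shape = (5, 5)
A1, A2 = rng.normal(size=(2, *shape))  # real samples
A = (A1 + 1j * A2).astype(dtype)       # cast once, at the end
print(A.dtype)  # complex128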
