Skip to content

Commit c51738a

Browse files
Add jax rewrite to eliminate BilinearSolveDiscreteLyapunov
1 parent 1182791 commit c51738a

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
from pytensor import Variable
66
from pytensor import tensor as pt
7+
from pytensor.compile import optdb
78
from pytensor.graph import Apply, FunctionGraph
89
from pytensor.graph.rewriting.basic import (
910
copy_stack_trace,
11+
in2out,
1012
node_rewriter,
1113
)
1214
from pytensor.scalar.basic import Mul
@@ -41,10 +43,12 @@
4143
register_stabilize,
4244
)
4345
from pytensor.tensor.slinalg import (
46+
BilinearSolveDiscreteLyapunov,
4447
BlockDiagonal,
4548
Cholesky,
4649
Solve,
4750
SolveBase,
51+
_direct_solve_discrete_lyapunov,
4852
block_diag,
4953
cholesky,
5054
solve,
@@ -966,3 +970,30 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
966970
non_eye_input = pt.shape_padaxis(non_eye_input, -2)
967971

968972
return [eye_input * (non_eye_input**0.5)]
973+
974+
975+
@node_rewriter([Blockwise])
976+
def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
977+
"""
978+
Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX
979+
"""
980+
981+
# Check if the op is BilinearSolveDiscreteLyapunov
982+
if not isinstance(node.op.core_op, BilinearSolveDiscreteLyapunov):
983+
return None
984+
985+
# Extract the inputs
986+
(A, B) = node.inputs
987+
988+
# Compute the result
989+
result = _direct_solve_discrete_lyapunov(A, B)
990+
991+
return [result]
992+
993+
994+
optdb.register(
995+
"jax_bilinaer_lyapunov_to_direct",
996+
in2out(jax_bilinaer_lyapunov_to_direct),
997+
"jax",
998+
position=100,
999+
)

tests/link/jax/test_slinalg.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Literal
2+
13
import numpy as np
24
import pytest
35

@@ -194,3 +196,19 @@ def test_jax_eigvalsh(lower):
194196
None,
195197
],
196198
)
199+
200+
201+
@pytest.mark.parametrize("method", ["direct", "bilinear"])
202+
def test_jax_solve_discrete_lyapunov(method: Literal["direct", "bilinear"]):
203+
A = matrix("A")
204+
B = matrix("B")
205+
out = pt_slinalg.solve_discrete_lyapunov(A, B, method=method)
206+
out_fg = FunctionGraph([A, B], [out])
207+
208+
compare_jax_and_py(
209+
out_fg,
210+
[
211+
np.random.normal(size=(5, 5)).astype(config.floatX),
212+
np.random.normal(size=(5, 5)).astype(config.floatX),
213+
],
214+
)

0 commit comments

Comments
 (0)