Skip to content

Commit bdb989c

Browse files
Apply JAX rewrite before canonicalization
1 parent ec7bcce commit bdb989c

File tree

1 file changed

+5
-13
lines changed

1 file changed

+5
-13
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,15 @@
4343
register_stabilize,
4444
)
4545
from pytensor.tensor.slinalg import (
46-
BilinearSolveDiscreteLyapunov,
4746
BlockDiagonal,
4847
Cholesky,
4948
Solve,
5049
SolveBase,
51-
_direct_solve_discrete_lyapunov,
50+
_solve_bilinear_discrete_lyapunov,
5251
block_diag,
5352
cholesky,
5453
solve,
54+
solve_discrete_lyapunov,
5555
solve_triangular,
5656
)
5757

@@ -972,21 +972,13 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
972972
return [eye_input * (non_eye_input**0.5)]
973973

974974

975-
@node_rewriter([Blockwise])
975+
@node_rewriter([_solve_bilinear_discrete_lyapunov])
976976
def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
977977
"""
978978
Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX
979979
"""
980-
981-
# Check if the op is BilinearSolveDiscreteLyapunov
982-
if not isinstance(node.op.core_op, BilinearSolveDiscreteLyapunov):
983-
return None
984-
985-
# Extract the inputs
986980
A, B = (cast(TensorVariable, x) for x in node.inputs)
987-
988-
# Compute the result
989-
result = _direct_solve_discrete_lyapunov(A, B)
981+
result = solve_discrete_lyapunov(A, B, method="direct")
990982

991983
return [result]
992984

@@ -995,5 +987,5 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
995987
"jax_bilinaer_lyapunov_to_direct",
996988
in2out(jax_bilinaer_lyapunov_to_direct),
997989
"jax",
998-
position=100,
990+
position=0.9, # Run before canonicalization
999991
)

0 commit comments

Comments
 (0)