4343 register_stabilize ,
4444)
4545from pytensor .tensor .slinalg import (
46- BilinearSolveDiscreteLyapunov ,
4746 BlockDiagonal ,
4847 Cholesky ,
4948 Solve ,
5049 SolveBase ,
51- _direct_solve_discrete_lyapunov ,
50+ _solve_bilinear_discrete_lyapunov ,
5251 block_diag ,
5352 cholesky ,
5453 solve ,
54+ solve_discrete_lyapunov ,
5555 solve_triangular ,
5656)
5757
@@ -972,21 +972,13 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
972972 return [eye_input * (non_eye_input ** 0.5 )]
973973
974974
975- @node_rewriter ([Blockwise ])
975+ @node_rewriter ([_solve_bilinear_discrete_lyapunov ])
976976def jax_bilinaer_lyapunov_to_direct (fgraph : FunctionGraph , node : Apply ):
977977 """
978978 Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX
979979 """
980-
981- # Check if the op is BilinearSolveDiscreteLyapunov
982- if not isinstance (node .op .core_op , BilinearSolveDiscreteLyapunov ):
983- return None
984-
985- # Extract the inputs
986980 A , B = (cast (TensorVariable , x ) for x in node .inputs )
987-
988- # Compute the result
989- result = _direct_solve_discrete_lyapunov (A , B )
981+ result = solve_discrete_lyapunov (A , B , method = "direct" )
990982
991983 return [result ]
992984
@@ -995,5 +987,5 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
995987 "jax_bilinaer_lyapunov_to_direct" ,
996988 in2out (jax_bilinaer_lyapunov_to_direct ),
997989 "jax" ,
998- position = 100 ,
990+ position = 0.9 , # Run before canonicalization
999991)
0 commit comments