File tree Expand file tree Collapse file tree 2 files changed +49
-0
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 2 files changed +49
-0
lines changed Original file line number Diff line number Diff line change 44
55from pytensor import Variable
66from pytensor import tensor as pt
7+ from pytensor .compile import optdb
78from pytensor .graph import Apply , FunctionGraph
89from pytensor .graph .rewriting .basic import (
910 copy_stack_trace ,
11+ in2out ,
1012 node_rewriter ,
1113)
1214from pytensor .scalar .basic import Mul
4143 register_stabilize ,
4244)
4345from pytensor .tensor .slinalg import (
46+ BilinearSolveDiscreteLyapunov ,
4447 BlockDiagonal ,
4548 Cholesky ,
4649 Solve ,
4750 SolveBase ,
51+ _direct_solve_discrete_lyapunov ,
4852 block_diag ,
4953 cholesky ,
5054 solve ,
@@ -966,3 +970,30 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
966970 non_eye_input = pt .shape_padaxis (non_eye_input , - 2 )
967971
968972 return [eye_input * (non_eye_input ** 0.5 )]
973+
974+
975+ @node_rewriter ([Blockwise ])
976+ def jax_bilinaer_lyapunov_to_direct (fgraph : FunctionGraph , node : Apply ):
977+ """
978+ Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX
979+ """
980+
981+ # Check if the op is BilinearSolveDiscreteLyapunov
982+ if not isinstance (node .op .core_op , BilinearSolveDiscreteLyapunov ):
983+ return None
984+
985+ # Extract the inputs
986+ (A , B ) = node .inputs
987+
988+ # Compute the result
989+ result = _direct_solve_discrete_lyapunov (A , B )
990+
991+ return [result ]
992+
993+
994+ optdb .register (
995+ "jax_bilinaer_lyapunov_to_direct" ,
996+ in2out (jax_bilinaer_lyapunov_to_direct ),
997+ "jax" ,
998+ position = 100 ,
999+ )
Original file line number Diff line number Diff line change 1+ from typing import Literal
2+
13import numpy as np
24import pytest
35
@@ -194,3 +196,19 @@ def test_jax_eigvalsh(lower):
194196 None ,
195197 ],
196198 )
199+
200+
201+ @pytest .mark .parametrize ("method" , ["direct" , "bilinear" ])
202+ def test_jax_solve_discrete_lyapunov (method : Literal ["direct" , "bilinear" ]):
203+ A = matrix ("A" )
204+ B = matrix ("B" )
205+ out = pt_slinalg .solve_discrete_lyapunov (A , B , method = method )
206+ out_fg = FunctionGraph ([A , B ], [out ])
207+
208+ compare_jax_and_py (
209+ out_fg ,
210+ [
211+ np .random .normal (size = (5 , 5 )).astype (config .floatX ),
212+ np .random .normal (size = (5 , 5 )).astype (config .floatX ),
213+ ],
214+ )
You can’t perform that action at this time.
0 commit comments