Skip to content

Commit cb809c1

Browse files
Don't manually set dtype of output
Revert change to `_solve_discrete_lyapunov`
1 parent fb35d92 commit cb809c1

File tree

2 files changed

+33
-37
lines changed

2 files changed

+33
-37
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@
4343
register_stabilize,
4444
)
4545
from pytensor.tensor.slinalg import (
46-
BilinearSolveDiscreteLyapunov,
4746
BlockDiagonal,
4847
Cholesky,
4948
Solve,
5049
SolveBase,
50+
_bilinear_solve_discrete_lyapunov,
5151
block_diag,
5252
cholesky,
5353
solve,
@@ -972,14 +972,11 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
972972
return [eye_input * (non_eye_input**0.5)]
973973

974974

975-
@node_rewriter([Blockwise])
975+
@node_rewriter([_bilinear_solve_discrete_lyapunov])
976976
def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
977977
"""
978978
Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX
979979
"""
980-
if not isinstance(node.op.core_op, BilinearSolveDiscreteLyapunov):
981-
return None
982-
983980
A, B = (cast(TensorVariable, x) for x in node.inputs)
984981
result = solve_discrete_lyapunov(A, B, method="direct")
985982

pytensor/tensor/slinalg.py

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -797,9 +797,7 @@ def make_node(self, A, B):
797797
def perform(self, node, inputs, output_storage):
798798
(A, B) = inputs
799799
X = output_storage[0]
800-
out_dtype = node.outputs[0].type.dtype
801-
802-
X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)
800+
X[0] = scipy.linalg.solve_continuous_lyapunov(A, B)
803801

804802
def infer_shape(self, fgraph, node, shapes):
805803
return [shapes[0]]
@@ -820,6 +818,30 @@ def grad(self, inputs, output_grads):
820818
return [A_bar, Q_bar]
821819

822820

821+
_solve_continuous_lyapunov = Blockwise(SolveContinuousLyapunov())
822+
823+
824+
def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
825+
"""
826+
Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
827+
828+
Parameters
829+
----------
830+
A: TensorLike
831+
Square matrix of shape ``N x N``.
832+
Q: TensorLike
833+
Square matrix of shape ``N x N``.
834+
835+
Returns
836+
-------
837+
X: TensorVariable
838+
Square matrix of shape ``N x N``
839+
840+
"""
841+
842+
return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
843+
844+
823845
class BilinearSolveDiscreteLyapunov(Op):
824846
"""
825847
Solves a discrete lyapunov equation, :math:`AXA^H - X = Q`, for :math:`X.
@@ -844,10 +866,7 @@ def perform(self, node, inputs, output_storage):
844866
(A, B) = inputs
845867
X = output_storage[0]
846868

847-
out_dtype = node.outputs[0].type.dtype
848-
X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
849-
out_dtype
850-
)
869+
X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear")
851870

852871
def infer_shape(self, fgraph, node, shapes):
853872
return [shapes[0]]
@@ -869,6 +888,9 @@ def grad(self, inputs, output_grads):
869888
return [A_bar, Q_bar]
870889

871890

891+
_bilinear_solve_discrete_lyapunov = Blockwise(BilinearSolveDiscreteLyapunov())
892+
893+
872894
def _direct_solve_discrete_lyapunov(
873895
A: TensorVariable, Q: TensorVariable
874896
) -> TensorVariable:
@@ -932,33 +954,12 @@ def solve_discrete_lyapunov(
932954
return cast(TensorVariable, X)
933955

934956
elif method == "bilinear":
935-
return cast(TensorVariable, Blockwise(BilinearSolveDiscreteLyapunov())(A, Q))
957+
return cast(TensorVariable, _bilinear_solve_discrete_lyapunov(A, Q))
936958

937959
else:
938960
raise ValueError(f"Unknown method {method}")
939961

940962

941-
def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
942-
"""
943-
Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
944-
945-
Parameters
946-
----------
947-
A: TensorLike
948-
Square matrix of shape ``N x N``.
949-
Q: TensorLike
950-
Square matrix of shape ``N x N``.
951-
952-
Returns
953-
-------
954-
X: TensorVariable
955-
Square matrix of shape ``N x N``
956-
957-
"""
958-
959-
return cast(TensorVariable, Blockwise(SolveContinuousLyapunov())(A, Q))
960-
961-
962963
class SolveDiscreteARE(Op):
963964
__props__ = ("enforce_Q_symmetric",)
964965
gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"
@@ -987,9 +988,7 @@ def perform(self, node, inputs, output_storage):
987988
if self.enforce_Q_symmetric:
988989
Q = 0.5 * (Q + Q.T)
989990

990-
X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(
991-
node.outputs[0].type.dtype
992-
)
991+
X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R)
993992

994993
def infer_shape(self, fgraph, node, shapes):
995994
return [shapes[0]]

0 commit comments

Comments
 (0)