Skip to content

Commit 89d5fd0

Browse files
Set dtype of Op outputs
1 parent cb809c1 commit 89d5fd0

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

pytensor/tensor/slinalg.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,9 @@ def make_node(self, A, B):
797797
def perform(self, node, inputs, output_storage):
798798
(A, B) = inputs
799799
X = output_storage[0]
800-
X[0] = scipy.linalg.solve_continuous_lyapunov(A, B)
800+
801+
out_dtype = node.outputs[0].type.dtype
802+
X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)
801803

802804
def infer_shape(self, fgraph, node, shapes):
803805
return [shapes[0]]
@@ -866,7 +868,10 @@ def perform(self, node, inputs, output_storage):
866868
(A, B) = inputs
867869
X = output_storage[0]
868870

869-
X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear")
871+
out_dtype = node.outputs[0].type.dtype
872+
X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
873+
out_dtype
874+
)
870875

871876
def infer_shape(self, fgraph, node, shapes):
872877
return [shapes[0]]
@@ -964,11 +969,8 @@ class SolveDiscreteARE(Op):
964969
__props__ = ("enforce_Q_symmetric",)
965970
gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"
966971

967-
def __init__(
968-
self, enforce_Q_symmetric: bool = False, use_bilinear_lyapunov: bool = True
969-
):
972+
def __init__(self, enforce_Q_symmetric: bool = False):
970973
self.enforce_Q_symmetric = enforce_Q_symmetric
971-
self.use_bilinear_lyapunov = use_bilinear_lyapunov
972974

973975
def make_node(self, A, B, Q, R):
974976
A = as_tensor_variable(A)
@@ -988,7 +990,8 @@ def perform(self, node, inputs, output_storage):
988990
if self.enforce_Q_symmetric:
989991
Q = 0.5 * (Q + Q.T)
990992

991-
X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R)
993+
out_dtype = node.outputs[0].type.dtype
994+
X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)
992995

993996
def infer_shape(self, fgraph, node, shapes):
994997
return [shapes[0]]
@@ -1000,16 +1003,16 @@ def grad(self, inputs, output_grads):
10001003
(dX,) = output_grads
10011004
X = self(A, B, Q, R)
10021005

1003-
K_inner = R + pt.linalg.matrix_dot(B.T, X, B)
1006+
K_inner = R + matrix_dot(B.T, X, B)
10041007

10051008
# K_inner is guaranteed to be symmetric, because X and R are symmetric
1006-
K_inner_inv_BT = pt.linalg.solve(K_inner, B.T, assume_a="sym")
1009+
K_inner_inv_BT = solve(K_inner, B.T, assume_a="sym")
10071010
K = matrix_dot(K_inner_inv_BT, X, A)
10081011

10091012
A_tilde = A - B.dot(K)
10101013

10111014
dX_symm = 0.5 * (dX + dX.T)
1012-
S = solve_discrete_lyapunov(A_tilde, dX_symm).astype(dX.type.dtype)
1015+
S = solve_discrete_lyapunov(A_tilde, dX_symm)
10131016

10141017
A_bar = 2 * matrix_dot(X, A_tilde, S)
10151018
B_bar = -2 * matrix_dot(X, A_tilde, S, K.T)

0 commit comments

Comments
 (0)