Skip to content

Commit ec7bcce

Browse files
Use pt.vectorize on base solve_discrete_lyapunov case
1 parent d8fb574 commit ec7bcce

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

pytensor/tensor/slinalg.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -853,26 +853,20 @@ def grad(self, inputs, output_grads):
853853

854854

855855
_solve_continuous_lyapunov = Blockwise(SolveContinuousLyapunov())
856-
_solve_bilinear_direct_lyapunov = cast(
856+
_solve_bilinear_discrete_lyapunov = cast(
857857
typing.Callable, Blockwise(BilinearSolveDiscreteLyapunov())
858858
)
859859

860860

861-
def _direct_solve_discrete_lyapunov(
862-
A: TensorVariable, Q: TensorVariable
863-
) -> TensorVariable:
864-
# By default kron acts on tensors, but we need a vectorized version over matrices for this function
865-
vec_kron = pt.vectorize(kron, "(m,n),(o,p)->(q,r)")
866-
861+
def _direct_solve_discrete_lyapunov(A, Q) -> TensorVariable:
867862
if A.type.dtype.startswith("complex"):
868-
AxA = vec_kron(A, A.conj())
863+
AxA = kron(A, A.conj())
869864
else:
870-
AxA = vec_kron(A, A)
865+
AxA = kron(A, A)
871866

872867
eye = pt.eye(AxA.shape[-1])
873-
q_shape = pt.concatenate([Q.shape[:-2], [-1]])
874868

875-
vec_Q = Q.reshape(q_shape)
869+
vec_Q = Q.ravel()
876870
vec_X = solve(eye - AxA, vec_Q, b_ndim=1)
877871

878872
return cast(TensorVariable, reshape(vec_X, A.shape))
@@ -912,10 +906,11 @@ def solve_discrete_lyapunov(
912906
Q = as_tensor_variable(Q)
913907

914908
if method == "direct":
915-
return _direct_solve_discrete_lyapunov(A, Q)
909+
signature = BilinearSolveDiscreteLyapunov.gufunc_signature
910+
return pt.vectorize(_direct_solve_discrete_lyapunov, signature=signature)(A, Q)
916911

917912
elif method == "bilinear":
918-
return cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q))
913+
return cast(TensorVariable, _solve_bilinear_discrete_lyapunov(A, Q))
919914

920915
else:
921916
raise ValueError(f"Unknown method {method}")

0 commit comments

Comments
 (0)