Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pytensor/tensor/rewriting/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ def local_blockwise_alloc(fgraph, node):
value, *shape = inp.owner.inputs

# Check what to do with the value of the Alloc
squeezed_value = _squeeze_left(value, batch_ndim)
missing_ndim = len(shape) - value.type.ndim
missing_ndim = inp.type.ndim - value.type.ndim
squeezed_value = _squeeze_left(value, (batch_ndim - missing_ndim))
if (
(((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:])
!= inp.type.broadcastable[batch_ndim:]
Expand Down
23 changes: 23 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

from pytensor import Variable
from pytensor import tensor as pt
from pytensor.compile import optdb
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import (
copy_stack_trace,
in2out,
node_rewriter,
)
from pytensor.scalar.basic import Mul
Expand Down Expand Up @@ -45,9 +47,11 @@
Cholesky,
Solve,
SolveBase,
_bilinear_solve_discrete_lyapunov,
block_diag,
cholesky,
solve,
solve_discrete_lyapunov,
solve_triangular,
)

Expand Down Expand Up @@ -966,3 +970,22 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
non_eye_input = pt.shape_padaxis(non_eye_input, -2)

return [eye_input * (non_eye_input**0.5)]


@node_rewriter([_bilinear_solve_discrete_lyapunov])
def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
    """
    Swap a bilinear discrete Lyapunov solve for the direct formulation.

    The rewrite tracks ``_bilinear_solve_discrete_lyapunov`` nodes and rebuilds
    the same solve with ``solve_discrete_lyapunov(..., method="direct")``, a
    computation that is supported by JAX.

    Parameters
    ----------
    fgraph: FunctionGraph
        Graph being rewritten (not inspected by this rewrite).
    node: Apply
        The matched node; it must have exactly two inputs.

    Returns
    -------
    list
        Single-element list holding the replacement output.
    """
    lhs, rhs = node.inputs
    replacement = solve_discrete_lyapunov(
        cast(TensorVariable, lhs),
        cast(TensorVariable, rhs),
        method="direct",
    )
    return [replacement]


# Register the rewriter defined above with ``optdb`` under the "jax" tag,
# wrapped by ``in2out``.
# NOTE(review): the "bilinaer" misspelling is part of the registered name and
# of the function name; fixing it means renaming both together.
optdb.register(
    "jax_bilinaer_lyapunov_to_direct",
    in2out(jax_bilinaer_lyapunov_to_direct),
    "jax",
    position=0.9,  # Run before canonicalization
)
194 changes: 127 additions & 67 deletions pytensor/tensor/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import typing
import warnings
from functools import reduce
from typing import TYPE_CHECKING, Literal, cast
from typing import Literal, cast

import numpy as np
import scipy.linalg
Expand All @@ -11,7 +11,7 @@
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import as_tensor_variable
from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
from pytensor.tensor.blockwise import Blockwise
Expand All @@ -21,9 +21,6 @@
from pytensor.tensor.variable import TensorVariable


if TYPE_CHECKING:
from pytensor.tensor import TensorLike

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -777,7 +774,16 @@


class SolveContinuousLyapunov(Op):
"""
Solves a continuous Lyapunov equation, :math:`AX + XA^H = B`, for :math:`X`.

Continuous time Lyapunov equations are special cases of Sylvester equations, :math:`AX + XB = C`, and can be solved
efficiently using the Bartels-Stewart algorithm. For more details, see the docstring for
scipy.linalg.solve_continuous_lyapunov
"""

__props__ = ()
gufunc_signature = "(m,m),(m,m)->(m,m)"

def make_node(self, A, B):
A = as_tensor_variable(A)
Expand All @@ -792,7 +798,8 @@
(A, B) = inputs
X = output_storage[0]

X[0] = scipy.linalg.solve_continuous_lyapunov(A, B)
out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)

def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
Expand All @@ -813,7 +820,41 @@
return [A_bar, Q_bar]


_solve_continuous_lyapunov = Blockwise(SolveContinuousLyapunov())


def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
    """
    Solve the continuous Lyapunov equation :math:`A X + X A^H = Q`.

    The work is delegated to the blockwise ``SolveContinuousLyapunov`` Op, whose
    ``perform`` calls ``scipy.linalg.solve_continuous_lyapunov(A, Q)``; the
    equation above follows that function's sign convention.

    Parameters
    ----------
    A: TensorLike
        Square matrix of shape ``N x N``.
    Q: TensorLike
        Square matrix of shape ``N x N``.

    Returns
    -------
    X: TensorVariable
        Square matrix of shape ``N x N``

    """

    return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))


class BilinearSolveDiscreteLyapunov(Op):
"""
Solves a discrete Lyapunov equation, :math:`AXA^H - X + Q = 0`, for :math:`X`.

The solution is computed by first transforming the discrete-time problem into a continuous-time form. The continuous
time Lyapunov equation is a special case of a Sylvester equation, and can be efficiently solved. For more details, see the
docstring for scipy.linalg.solve_discrete_lyapunov
"""

gufunc_signature = "(m,m),(m,m)->(m,m)"

def make_node(self, A, B):
A = as_tensor_variable(A)
B = as_tensor_variable(B)
Expand All @@ -827,7 +868,10 @@
(A, B) = inputs
X = output_storage[0]

X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear")
out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
out_dtype
)

def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
Expand All @@ -849,83 +893,83 @@
return [A_bar, Q_bar]


_solve_continuous_lyapunov = SolveContinuousLyapunov()
_solve_bilinear_direct_lyapunov = cast(typing.Callable, BilinearSolveDiscreteLyapunov())
_bilinear_solve_discrete_lyapunov = Blockwise(BilinearSolveDiscreteLyapunov())


def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
A_ = as_tensor_variable(A)
Q_ = as_tensor_variable(Q)
def _direct_solve_discrete_lyapunov(
A: TensorVariable, Q: TensorVariable
) -> TensorVariable:
r"""
Directly solve the discrete Lyapunov equation :math:`A X A^H - X + Q = 0` using the Kronecker method of Magnus and
Neudecker.

This involves constructing and inverting an intermediate matrix :math:`A \otimes A`, with shape :math:`N^2 x N^2`.
As a result, this method scales poorly with the size of :math:`N`, and should be avoided for large :math:`N`.
"""

if "complex" in A_.type.dtype:
AA = kron(A_, A_.conj())
if A.type.dtype.startswith("complex"):
AxA = kron(A, A.conj())
else:
AA = kron(A_, A_)
AxA = kron(A, A)

eye = pt.eye(AxA.shape[-1])

X = solve(pt.eye(AA.shape[0]) - AA, Q_.ravel())
return cast(TensorVariable, reshape(X, Q_.shape))
vec_Q = Q.ravel()
vec_X = solve(eye - AxA, vec_Q, b_ndim=1)

return cast(TensorVariable, reshape(vec_X, A.shape))


def solve_discrete_lyapunov(
A: "TensorLike", Q: "TensorLike", method: Literal["direct", "bilinear"] = "direct"
A: TensorLike,
Q: TensorLike,
method: Literal["direct", "bilinear"] = "bilinear",
) -> TensorVariable:
"""Solve the discrete Lyapunov equation :math:`A X A^H - X + Q = 0`.

Parameters
----------
A
Square matrix of shape N x N; must have the same shape as Q
Q
Square matrix of shape N x N; must have the same shape as A
method
Solver method used, one of ``"direct"`` or ``"bilinear"``. ``"direct"``
solves the problem directly via matrix inversion. This has a pure
PyTensor implementation and can thus be cross-compiled to supported
backends, and should be preferred when ``N`` is not large. The direct
method scales poorly with the size of ``N``, and the bilinear can be
A: TensorLike
Square matrix of shape N x N
Q: TensorLike
Square matrix of shape N x N
method: str, one of ``"direct"`` or ``"bilinear"``
Solver method used. ``"direct"`` solves the problem directly via matrix inversion. This has a pure
PyTensor implementation and can thus be cross-compiled to supported backends, and should be preferred when
``N`` is not large. The direct method scales poorly with the size of ``N``, and the bilinear can be
used in these cases.

Returns
-------
Square matrix of shape ``N x N``, representing the solution to the
Lyapunov equation
X: TensorVariable
Square matrix of shape ``N x N``. Solution to the Lyapunov equation

"""
if method not in ["direct", "bilinear"]:
raise ValueError(
f'Parameter "method" must be one of "direct" or "bilinear", found {method}'
)

if method == "direct":
return _direct_solve_discrete_lyapunov(A, Q)
if method == "bilinear":
return cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q))


def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
"""Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.

Parameters
----------
A
Square matrix of shape ``N x N``; must have the same shape as `Q`.
Q
Square matrix of shape ``N x N``; must have the same shape as `A`.
A = as_tensor_variable(A)
Q = as_tensor_variable(Q)

Returns
-------
Square matrix of shape ``N x N``, representing the solution to the
Lyapunov equation
if method == "direct":
signature = BilinearSolveDiscreteLyapunov.gufunc_signature
X = pt.vectorize(_direct_solve_discrete_lyapunov, signature=signature)(A, Q)
return cast(TensorVariable, X)

"""
elif method == "bilinear":
return cast(TensorVariable, _bilinear_solve_discrete_lyapunov(A, Q))

return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
else:
raise ValueError(f"Unknown method {method}")

Check warning on line 965 in pytensor/tensor/slinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/slinalg.py#L965

Added line #L965 was not covered by tests


class SolveDiscreteARE(pt.Op):
class SolveDiscreteARE(Op):
__props__ = ("enforce_Q_symmetric",)
gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"

def __init__(self, enforce_Q_symmetric=False):
def __init__(self, enforce_Q_symmetric: bool = False):
self.enforce_Q_symmetric = enforce_Q_symmetric

def make_node(self, A, B, Q, R):
Expand All @@ -946,9 +990,8 @@
if self.enforce_Q_symmetric:
Q = 0.5 * (Q + Q.T)

X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(
node.outputs[0].type.dtype
)
out_dtype = node.outputs[0].type.dtype
X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)

def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
Expand All @@ -960,14 +1003,16 @@
(dX,) = output_grads
X = self(A, B, Q, R)

K_inner = R + pt.linalg.matrix_dot(B.T, X, B)
K_inner_inv = pt.linalg.solve(K_inner, pt.eye(R.shape[0]))
K = matrix_dot(K_inner_inv, B.T, X, A)
K_inner = R + matrix_dot(B.T, X, B)

# K_inner is guaranteed to be symmetric, because X and R are symmetric
K_inner_inv_BT = solve(K_inner, B.T, assume_a="sym")
K = matrix_dot(K_inner_inv_BT, X, A)

A_tilde = A - B.dot(K)

dX_symm = 0.5 * (dX + dX.T)
S = solve_discrete_lyapunov(A_tilde, dX_symm).astype(dX.type.dtype)
S = solve_discrete_lyapunov(A_tilde, dX_symm)

A_bar = 2 * matrix_dot(X, A_tilde, S)
B_bar = -2 * matrix_dot(X, A_tilde, S, K.T)
Expand All @@ -977,30 +1022,45 @@
return [A_bar, B_bar, Q_bar, R_bar]


def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
def solve_discrete_are(
A: TensorLike,
B: TensorLike,
Q: TensorLike,
R: TensorLike,
enforce_Q_symmetric: bool = False,
) -> TensorVariable:
"""
Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.

Discrete-time Algebraic Riccati equations arise in the context of optimal control and filtering problems, as the
solution to Linear-Quadratic Regulators (LQR), Linear-Quadratic-Gaussian (LQG) control problems, and as the
steady-state covariance of the Kalman Filter.

Such problems typically have many solutions, but we are generally only interested in the unique *stabilizing*
solution. This stable solution, if it exists, will be returned by this function.

Parameters
----------
A: ArrayLike
A: TensorLike
Square matrix of shape M x M
B: ArrayLike
B: TensorLike
Matrix of shape M x N
Q: ArrayLike
Q: TensorLike
Symmetric square matrix of shape M x M
R: ArrayLike
R: TensorLike
Square matrix of shape N x N
enforce_Q_symmetric: bool
If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry

Returns
-------
X: pt.matrix
X: TensorVariable
Square matrix of shape M x M, representing the solution to the DARE
"""

return cast(TensorVariable, SolveDiscreteARE(enforce_Q_symmetric)(A, B, Q, R))
return cast(
TensorVariable, Blockwise(SolveDiscreteARE(enforce_Q_symmetric))(A, B, Q, R)
)


def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
Expand Down
Loading
Loading