Commit 1182791: Blockwise optimal linear control ops
1 parent a377c22

File tree: 2 files changed (+142, -99 lines)

pytensor/tensor/slinalg.py

Lines changed: 84 additions & 43 deletions
@@ -778,6 +778,7 @@ def perform(self, node, inputs, outputs):
 
 class SolveContinuousLyapunov(Op):
     __props__ = ()
+    gufunc_signature = "(m,m),(m,m)->(m,m)"
 
     def make_node(self, A, B):
         A = as_tensor_variable(A)
@@ -814,6 +815,8 @@ def grad(self, inputs, output_grads):
 
 
 class BilinearSolveDiscreteLyapunov(Op):
+    gufunc_signature = "(m,m),(m,m)->(m,m)"
+
     def make_node(self, A, B):
         A = as_tensor_variable(A)
         B = as_tensor_variable(B)
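
Both Lyapunov Ops now declare a gufunc_signature in NumPy's generalized-ufunc notation: two (m, m) core matrices in, one (m, m) matrix out. Wrapping an Op that carries this attribute in Blockwise loops the matrix-only implementation over any leading batch dimensions. A minimal sketch of the same broadcasting semantics in plain NumPy/SciPy (illustrative only; Blockwise does this symbolically inside the PyTensor graph):

import numpy as np
from scipy import linalg

# np.vectorize with the same signature string maps a matrix-only solver
# over leading batch dimensions, mirroring what Blockwise does for the Op.
batched_lyap = np.vectorize(
    linalg.solve_continuous_lyapunov, signature="(m,m),(m,m)->(m,m)"
)

rng = np.random.default_rng(0)
A = rng.normal(size=(10, 5, 5))  # ten stacked 5 x 5 problems
Q = rng.normal(size=(10, 5, 5))
X = batched_lyap(A, Q)  # shape (10, 5, 5)
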
@@ -849,84 +852,102 @@ def grad(self, inputs, output_grads):
         return [A_bar, Q_bar]
 
 
-_solve_continuous_lyapunov = SolveContinuousLyapunov()
-_solve_bilinear_direct_lyapunov = cast(typing.Callable, BilinearSolveDiscreteLyapunov())
+_solve_continuous_lyapunov = Blockwise(SolveContinuousLyapunov())
+_solve_bilinear_direct_lyapunov = cast(
+    typing.Callable, Blockwise(BilinearSolveDiscreteLyapunov())
+)
 
 
-def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
-    A_ = as_tensor_variable(A)
-    Q_ = as_tensor_variable(Q)
+def _direct_solve_discrete_lyapunov(
+    A: TensorVariable, Q: TensorVariable
+) -> TensorVariable:
+    # By default kron acts on tensors, but we need a vectorized version over matrices for this function
+    vec_kron = pt.vectorize(kron, "(m,n),(o,p)->(q,r)")
 
-    if "complex" in A_.type.dtype:
-        AA = kron(A_, A_.conj())
+    if A.type.dtype.startswith("complex"):
+        AxA = vec_kron(A, A.conj())
     else:
-        AA = kron(A_, A_)
+        AxA = vec_kron(A, A)
+
+    eye = pt.eye(AxA.shape[-1])
+    q_shape = pt.concatenate([Q.shape[:-2], [-1]])
+
+    vec_Q = Q.reshape(q_shape)
+    vec_X = solve(eye - AxA, vec_Q, b_ndim=1)
 
-    X = solve(pt.eye(AA.shape[0]) - AA, Q_.ravel())
-    return cast(TensorVariable, reshape(X, Q_.shape))
+    return cast(TensorVariable, reshape(vec_X, A.shape))
 
 
 def solve_discrete_lyapunov(
-    A: "TensorLike", Q: "TensorLike", method: Literal["direct", "bilinear"] = "direct"
+    A: TensorVariable,
+    Q: TensorVariable,
+    method: Literal["direct", "bilinear"] = "direct",
 ) -> TensorVariable:
     """Solve the discrete Lyapunov equation :math:`A X A^H - X = Q`.
 
     Parameters
     ----------
-    A
-        Square matrix of shape N x N; must have the same shape as Q
-    Q
-        Square matrix of shape N x N; must have the same shape as A
-    method
-        Solver method used, one of ``"direct"`` or ``"bilinear"``. ``"direct"``
-        solves the problem directly via matrix inversion. This has a pure
-        PyTensor implementation and can thus be cross-compiled to supported
-        backends, and should be preferred when ``N`` is not large. The direct
-        method scales poorly with the size of ``N``, and the bilinear can be
+    A: TensorVariable
+        Square matrix of shape N x N
+    Q: TensorVariable
+        Square matrix of shape N x N
+    method: str, one of ``"direct"`` or ``"bilinear"``
+        Solver method used. ``"direct"`` solves the problem directly via matrix inversion. This has a pure
+        PyTensor implementation and can thus be cross-compiled to supported backends, and should be preferred when
+        ``N`` is not large. The direct method scales poorly with the size of ``N``, and the bilinear can be
        used in these cases.
 
     Returns
     -------
-    Square matrix of shape ``N x N``, representing the solution to the
-    Lyapunov equation
+    X: TensorVariable
+        Square matrix of shape ``N x N``. Solution to the Lyapunov equation
 
     """
     if method not in ["direct", "bilinear"]:
         raise ValueError(
             f'Parameter "method" must be one of "direct" or "bilinear", found {method}'
         )
 
+    A = as_tensor_variable(A)
+    Q = as_tensor_variable(Q)
+
     if method == "direct":
         return _direct_solve_discrete_lyapunov(A, Q)
+
     if method == "bilinear":
         return cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q))
 
 
-def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
-    """Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
+def solve_continuous_lyapunov(A: TensorVariable, Q: TensorVariable) -> TensorVariable:
+    """
+    Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
 
     Parameters
     ----------
-    A
-        Square matrix of shape ``N x N``; must have the same shape as `Q`.
-    Q
-        Square matrix of shape ``N x N``; must have the same shape as `A`.
+    A: TensorVariable
+        Square matrix of shape ``N x N``.
+    Q: TensorVariable
+        Square matrix of shape ``N x N``.
 
     Returns
     -------
-    Square matrix of shape ``N x N``, representing the solution to the
-    Lyapunov equation
+    X: TensorVariable
+        Square matrix of shape ``N x N``
 
     """
 
     return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
 
 
 class SolveDiscreteARE(pt.Op):
-    __props__ = ("enforce_Q_symmetric",)
+    __props__ = ("enforce_Q_symmetric", "use_bilinear_lyapunov")
+    gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"
 
-    def __init__(self, enforce_Q_symmetric=False):
+    def __init__(
+        self, enforce_Q_symmetric: bool = False, use_bilinear_lyapunov: bool = True
+    ):
         self.enforce_Q_symmetric = enforce_Q_symmetric
+        self.use_bilinear_lyapunov = use_bilinear_lyapunov
 
     def make_node(self, A, B, Q, R):
         A = as_tensor_variable(A)
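
The rewritten _direct_solve_discrete_lyapunov deserves a note. With C-order (row-major) raveling, vec(A X B) = (A ⊗ B^T) vec(X), so vec(A X A^H) = (A ⊗ conj(A)) vec(X) and the Lyapunov equation collapses to a single linear system of size N²: (I - A ⊗ conj(A)) vec(X) = vec(Q), whose solution satisfies X - A X A^H = Q (SciPy's convention for solve_discrete_lyapunov). Vectorizing kron and reshaping only the trailing two dimensions keeps the construction valid for batched inputs, and b_ndim=1 tells solve that only the last axis of vec_Q is a core dimension. A quick NumPy check of the identity (non-batched, real-valued; the 0.1 scaling just keeps I - kron(A, A) well conditioned):

import numpy as np

rng = np.random.default_rng(0)
N = 4
A = rng.normal(size=(N, N)) * 0.1
Q = rng.normal(size=(N, N))

# Solve (I - kron(A, conj(A))) vec(X) = vec(Q) with C-order ravel
vec_X = np.linalg.solve(np.eye(N * N) - np.kron(A, A.conj()), Q.ravel())
X = vec_X.reshape(N, N)

# The recovered X satisfies X - A X A^H = Q
np.testing.assert_allclose(X - A @ X @ A.conj().T, Q)
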
@@ -961,13 +982,20 @@ def grad(self, inputs, output_grads):
         X = self(A, B, Q, R)
 
         K_inner = R + pt.linalg.matrix_dot(B.T, X, B)
-        K_inner_inv = pt.linalg.solve(K_inner, pt.eye(R.shape[0]))
-        K = matrix_dot(K_inner_inv, B.T, X, A)
+
+        # K_inner is guaranteed to be symmetric, because X and R are symmetric
+        K_inner_inv_BT = pt.linalg.solve(K_inner, B.T, assume_a="sym")
+        K = matrix_dot(K_inner_inv_BT, X, A)
 
         A_tilde = A - B.dot(K)
 
         dX_symm = 0.5 * (dX + dX.T)
-        S = solve_discrete_lyapunov(A_tilde, dX_symm).astype(dX.type.dtype)
+        method: Literal["bilinear", "direct"] = (
+            "bilinear" if self.use_bilinear_lyapunov else "direct"
+        )
+        S = solve_discrete_lyapunov(A_tilde, dX_symm, method=method).astype(
+            dX.type.dtype
+        )
 
         A_bar = 2 * matrix_dot(X, A_tilde, S)
         B_bar = -2 * matrix_dot(X, A_tilde, S, K.T)
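
The gradient rewrite also removes an explicit matrix inverse: instead of forming (R + B^T X B)^{-1} and multiplying through, it solves the symmetric system (R + B^T X B) Y = B^T once, and assume_a="sym" lets the solver use a symmetric factorization rather than a general LU. A small NumPy/SciPy sketch of the equivalence (the names here are illustrative, not from the commit):

import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
B = rng.normal(size=(4, 3))  # m = 4 states, n = 3 controls
M = rng.normal(size=(3, 3))
K_inner = M @ M.T + 3 * np.eye(3)  # stand-in for the symmetric R + B^T X B

# Old pattern: form the inverse explicitly, then multiply
old = np.linalg.inv(K_inner) @ B.T

# New pattern: a single symmetric solve, no explicit inverse
new = linalg.solve(K_inner, B.T, assume_a="sym")

np.testing.assert_allclose(old, new)
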
@@ -977,30 +1005,43 @@ def grad(self, inputs, output_grads):
         return [A_bar, B_bar, Q_bar, R_bar]
 
 
-def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
+def solve_discrete_are(
+    A: TensorVariable,
+    B: TensorVariable,
+    Q: TensorVariable,
+    R: TensorVariable,
+    enforce_Q_symmetric: bool = False,
+    use_bilinear_lyapunov: bool = True,
+) -> TensorVariable:
     """
     Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.
 
     Parameters
     ----------
-    A: ArrayLike
+    A: TensorVariable
         Square matrix of shape M x M
-    B: ArrayLike
+    B: TensorVariable
         Square matrix of shape M x M
-    Q: ArrayLike
+    Q: TensorVariable
         Symmetric square matrix of shape M x M
-    R: ArrayLike
+    R: TensorVariable
         Square matrix of shape N x N
     enforce_Q_symmetric: bool
         If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry
+    use_bilinear_lyapunov: bool
+        If True, the bilinear method is used to solve a discrete Lyapunov equation when computing the gradients of
+        the ARE. If False, the direct method is used instead. See the docstring for ``solve_discrete_lyapunov`` for
+        details.
 
     Returns
     -------
-    X: pt.matrix
+    X: TensorVariable
         Square matrix of shape M x M, representing the solution to the DARE
     """
 
-    return cast(TensorVariable, SolveDiscreteARE(enforce_Q_symmetric)(A, B, Q, R))
+    op = SolveDiscreteARE(enforce_Q_symmetric, use_bilinear_lyapunov)
+    return cast(TensorVariable, Blockwise(op)(A, B, Q, R))
 
 
 def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
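
With SolveDiscreteARE wrapped in Blockwise, solve_discrete_are (like the Lyapunov solvers above) now accepts stacked inputs directly. A usage sketch, assuming a float64 config.floatX (shapes and variable names here are illustrative):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve_discrete_are

# A batch of DAREs: A, Q are (batch, 2, 2); B is (batch, 2, 1); R is (batch, 1, 1)
A = pt.tensor(name="A", shape=(None, 2, 2))
B = pt.tensor(name="B", shape=(None, 2, 1))
Q = pt.tensor(name="Q", shape=(None, 2, 2))
R = pt.tensor(name="R", shape=(None, 1, 1))

X = solve_discrete_are(A, B, Q, R)  # Blockwise maps the core Op over the batch axis
f = pytensor.function([A, B, Q, R], X)
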

tests/tensor/test_slinalg.py

Lines changed: 58 additions & 56 deletions
@@ -1,5 +1,6 @@
 import functools
 import itertools
+from typing import Literal
 
 import numpy as np
 import pytest
@@ -514,105 +515,106 @@ def test_expm_grad_3():
     utt.verify_grad(expm, [A], rng=rng)
 
 
-def test_solve_discrete_lyapunov_via_direct_real():
-    N = 5
-    rng = np.random.default_rng(utt.fetch_seed())
-    a = pt.dmatrix("a")
-    q = pt.dmatrix("q")
-    f = function([a, q], [solve_discrete_lyapunov(a, q, method="direct")])
-
-    A = rng.normal(size=(N, N))
-    Q = rng.normal(size=(N, N))
+def recover_Q(A, X, continuous=True):
+    if continuous:
+        return A @ X + X @ A.conj().T
+    else:
+        return X - A @ X @ A.conj().T
 
-    X = f(A, Q)
-    assert np.allclose(A @ X @ A.T - X + Q, 0.0)
 
-    utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
+vec_recover_Q = np.vectorize(recover_Q, signature="(m,m),(m,m),()->(m,m)")
 
 
+@pytest.mark.parametrize("use_complex", [False, True])
+@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
+@pytest.mark.parametrize("method", ["direct", "bilinear"])
 @pytest.mark.filterwarnings("ignore::UserWarning")
-def test_solve_discrete_lyapunov_via_direct_complex():
-    # Conj doesn't have C-op; filter the warning.
-
-    N = 5
+def test_solve_discrete_lyapunov(
+    use_complex, shape: tuple[int], method: Literal["direct", "bilinear"]
+):
     rng = np.random.default_rng(utt.fetch_seed())
-    a = pt.zmatrix()
-    q = pt.zmatrix()
-    f = function([a, q], [solve_discrete_lyapunov(a, q, method="direct")])
+    dtype = config.floatX
+    if use_complex:
+        precision = int(dtype[-2:])  # 64 or 32
+        dtype = f"complex{int(2 * precision)}"
 
-    A = rng.normal(size=(N, N)) + rng.normal(size=(N, N)) * 1j
-    Q = rng.normal(size=(N, N))
-    X = f(A, Q)
-    np.testing.assert_array_less(A @ X @ A.conj().T - X + Q, 1e-12)
+    a = pt.tensor(name="a", shape=shape, dtype=dtype)
+    q = pt.tensor(name="q", shape=shape, dtype=dtype)
 
-    # TODO: the .conj() method currently does not have a gradient; add this test when gradients are implemented.
-    # utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
+    f = function([a, q], solve_discrete_lyapunov(a, q, method=method))
 
-
-def test_solve_discrete_lyapunov_via_bilinear():
-    N = 5
-    rng = np.random.default_rng(utt.fetch_seed())
-    a = pt.dmatrix()
-    q = pt.dmatrix()
-    f = function([a, q], [solve_discrete_lyapunov(a, q, method="bilinear")])
-
-    A = rng.normal(size=(N, N))
-    Q = rng.normal(size=(N, N))
+    A = rng.normal(size=shape)
+    Q = rng.normal(size=shape)
 
     X = f(A, Q)
+    Q_recovered = vec_recover_Q(A, X, continuous=False)
+    np.testing.assert_allclose(Q_recovered, Q)
 
-    np.testing.assert_array_less(A @ X @ A.conj().T - X + Q, 1e-12)
-    utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
+    utt.verify_grad(
+        functools.partial(solve_discrete_lyapunov, method=method), pt=[A, Q], rng=rng
+    )
 
 
-def test_solve_continuous_lyapunov():
-    N = 5
+@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
+def test_solve_continuous_lyapunov(shape: tuple[int]):
     rng = np.random.default_rng(utt.fetch_seed())
-    a = pt.dmatrix()
-    q = pt.dmatrix()
+    a = pt.tensor(name="a", shape=shape)
+    q = pt.tensor(name="q", shape=shape)
     f = function([a, q], [solve_continuous_lyapunov(a, q)])
 
-    A = rng.normal(size=(N, N))
-    Q = rng.normal(size=(N, N))
+    A = rng.normal(size=shape)
+    Q = rng.normal(size=shape)
     X = f(A, Q)
 
-    Q_recovered = A @ X + X @ A.conj().T
+    Q_recovered = vec_recover_Q(A, X, continuous=True)
 
     np.testing.assert_allclose(Q_recovered.squeeze(), Q)
     utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=rng)
 
 
-def test_solve_discrete_are_forward():
+@pytest.mark.parametrize("add_batch_dim", [False, True])
+def test_solve_discrete_are_forward(add_batch_dim):
     # TEST CASE 4 : darex #1 -- taken from Scipy tests
     a, b, q, r = (
         np.array([[4, 3], [-4.5, -3.5]]),
         np.array([[1], [-1]]),
         np.array([[9, 6], [6, 4]]),
         np.array([[1]]),
     )
-    a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r])
+    if add_batch_dim:
+        a, b, q, r = (np.stack([x] * 5) for x in [a, b, q, r])
 
-    x = solve_discrete_are(a, b, q, r).eval()
-    res = a.T.dot(x.dot(a)) - x + q
-    res -= (
-        a.conj()
-        .T.dot(x.dot(b))
-        .dot(np.linalg.solve(r + b.conj().T.dot(x.dot(b)), b.T).dot(x.dot(a)))
-    )
+    a, b, q, r = (pt.as_tensor_variable(x).astype(config.floatX) for x in [a, b, q, r])
+
+    x = solve_discrete_are(a, b, q, r)
+
+    # A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q
+    def eval_fun(a, b, q, r, x):
+        term_1 = a.T @ x @ a
+        term_2 = a.T @ x @ b
+        term_3 = pt.linalg.solve(r + b.T @ x @ b, b.T) @ x @ a
+
+        return term_1 - x - term_2 @ term_3 + q
+
+    res = pt.vectorize(eval_fun, "(m,m),(m,n),(m,m),(n,n),(m,m)->(m,m)")(a, b, q, r, x)
+    res_np = res.eval()
 
     atol = 1e-4 if config.floatX == "float32" else 1e-12
-    np.testing.assert_allclose(res, np.zeros_like(res), atol=atol)
+    np.testing.assert_allclose(res_np, np.zeros_like(res_np), atol=atol)
 
 
-def test_solve_discrete_are_grad():
+@pytest.mark.parametrize("add_batch_dim", [False, True])
+def test_solve_discrete_are_grad(add_batch_dim):
     a, b, q, r = (
         np.array([[4, 3], [-4.5, -3.5]]),
         np.array([[1], [-1]]),
         np.array([[9, 6], [6, 4]]),
         np.array([[1]]),
     )
-    a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r])
+    if add_batch_dim:
+        a, b, q, r = (np.stack([x] * 5) for x in [a, b, q, r])
 
+    a, b, q, r = (x.astype(config.floatX) for x in [a, b, q, r])
     rng = np.random.default_rng(utt.fetch_seed())
 
     # TODO: Is there a "theoretically motivated" value to use here? I pulled 1e-4 out of a hat
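
Throughout these tests, utt.verify_grad checks the symbolic gradient of each (now Blockwise) Op against a numerical estimate. Conceptually it works like the central-difference sketch below (a minimal illustration of the idea, not PyTensor's actual implementation):

import numpy as np

def finite_difference_grad(f, x, eps=1e-7):
    # Central-difference estimate of df/dx for a scalar-valued f;
    # verify_grad compares the Op's symbolic gradient against this.
    g = np.zeros_like(x)
    it = np.nditer(x, flags=["multi_index"])
    for _ in it:
        i = it.multi_index
        x_hi, x_lo = x.copy(), x.copy()
        x_hi[i] += eps
        x_lo[i] -= eps
        g[i] = (f(x_hi) - f(x_lo)) / (2 * eps)
    return g
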
