Skip to content

Commit 288a3f3

Browse files
Add Op corresponding to scipy.linalg.solve_discrete_are (#417)
* Add pytensor function corresponding to `scipy.linalg.solve_discrete_are` * Add pytensor function corresponding to `scipy.linalg.solve_discrete_are` * Cast numpy to node output dtype, rather than depending on config.floatX. Change output type hint back to `TensorVariable` * Cast numpy to node output dtype, rather than depending on config.floatX. Change output type hint back to `TensorVariable` * Use `np.testing.assert_allclose` rather than `utt.assert_allclose` for output equality tests
1 parent 34eaaa5 commit 288a3f3

File tree

2 files changed

+148
-20
lines changed

2 files changed

+148
-20
lines changed

pytensor/tensor/slinalg.py

Lines changed: 95 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import typing
23
import warnings
34
from typing import TYPE_CHECKING, Literal, Union
45

@@ -12,6 +13,7 @@
1213
from pytensor.tensor import as_tensor_variable
1314
from pytensor.tensor import basic as at
1415
from pytensor.tensor import math as atm
16+
from pytensor.tensor.nlinalg import matrix_dot
1517
from pytensor.tensor.shape import reshape
1618
from pytensor.tensor.type import matrix, tensor, vector
1719
from pytensor.tensor.var import TensorVariable
@@ -321,9 +323,6 @@ def L_op(self, inputs, outputs, output_gradients):
321323
return res
322324

323325

324-
solvetriangular = SolveTriangular()
325-
326-
327326
def solve_triangular(
328327
a: TensorVariable,
329328
b: TensorVariable,
@@ -397,9 +396,6 @@ def perform(self, node, inputs, outputs):
397396
)
398397

399398

400-
solve = Solve()
401-
402-
403399
def solve(a, b, assume_a="gen", lower=False, check_finite=True):
404400
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
405401
@@ -748,13 +744,9 @@ def grad(self, inputs, output_grads):
748744

749745

750746
_solve_continuous_lyapunov = SolveContinuousLyapunov()
751-
_solve_bilinear_direct_lyapunov = BilinearSolveDiscreteLyapunov()
752-
753-
754-
def iscomplexobj(x):
755-
type_ = x.type
756-
dtype = type_.dtype
757-
return "complex" in dtype
747+
_solve_bilinear_direct_lyapunov = typing.cast(
748+
typing.Callable, BilinearSolveDiscreteLyapunov()
749+
)
758750

759751

760752
def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
@@ -767,7 +759,7 @@ def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorV
767759
AA = kron(A_, A_)
768760

769761
X = solve(pt.eye(AA.shape[0]) - AA, Q_.ravel())
770-
return reshape(X, Q_.shape)
762+
return typing.cast(TensorVariable, reshape(X, Q_.shape))
771763

772764

773765
def solve_discrete_lyapunov(
@@ -803,7 +795,7 @@ def solve_discrete_lyapunov(
803795
if method == "direct":
804796
return _direct_solve_discrete_lyapunov(A, Q)
805797
if method == "bilinear":
806-
return _solve_bilinear_direct_lyapunov(A, Q)
798+
return typing.cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q))
807799

808800

809801
def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
@@ -823,7 +815,90 @@ def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariabl
823815
824816
"""
825817

826-
return _solve_continuous_lyapunov(A, Q)
818+
return typing.cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
819+
820+
821+
class SolveDiscreteARE(pt.Op):
    """Op that solves the discrete algebraic Riccati equation (DARE).

    The numerical work is delegated to ``scipy.linalg.solve_discrete_are``
    in :meth:`perform`; :meth:`grad` builds the symbolic reverse-mode
    gradient with respect to all four inputs.

    Parameters
    ----------
    enforce_Q_symmetric: bool
        If True, ``Q`` is replaced by ``0.5 * (Q + Q.T)`` before it is handed
        to SciPy.
    """

    __props__ = ("enforce_Q_symmetric",)

    def __init__(self, enforce_Q_symmetric=False):
        self.enforce_Q_symmetric = enforce_Q_symmetric

    def make_node(self, A, B, Q, R):
        """Build the Apply node; the output is a matrix of the upcast input dtype."""
        A = as_tensor_variable(A)
        B = as_tensor_variable(B)
        Q = as_tensor_variable(Q)
        R = as_tensor_variable(R)

        out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype, Q.dtype, R.dtype)
        X = pytensor.tensor.matrix(dtype=out_dtype)

        return pytensor.graph.basic.Apply(self, [A, B, Q, R], [X])

    def perform(self, node, inputs, output_storage):
        """Evaluate the DARE solution numerically and store it in the output."""
        A, B, Q, R = inputs
        X = output_storage[0]

        if self.enforce_Q_symmetric:
            Q = 0.5 * (Q + Q.T)

        # Cast to the dtype declared by the node so the result does not depend
        # on whatever dtype SciPy happens to return.
        X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(
            node.outputs[0].type.dtype
        )

    def infer_shape(self, fgraph, node, shapes):
        """The solution has the same shape as the first input, ``A``."""
        return [shapes[0]]

    def grad(self, inputs, output_grads):
        """Return the gradients ``[A_bar, B_bar, Q_bar, R_bar]`` of the DARE solution."""
        # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
        A, B, Q, R = inputs

        (dX,) = output_grads
        X = self(A, B, Q, R)

        # Gain K = (R + B^T X B)^{-1} B^T X A.  Solve the linear system
        # directly instead of first forming an explicit inverse against an
        # identity matrix: this saves a matrix product and avoids introducing
        # an identity whose dtype is chosen independently of the inputs.
        K_inner = R + matrix_dot(B.T, X, B)
        K = pt.linalg.solve(K_inner, matrix_dot(B.T, X, A))

        # Closed-loop matrix A - B K.
        A_tilde = A - B.dot(K)

        # Only the symmetric part of the incoming gradient is propagated.
        dX_symm = 0.5 * (dX + dX.T)
        S = solve_discrete_lyapunov(A_tilde, dX_symm).astype(dX.type.dtype)

        A_bar = 2 * matrix_dot(X, A_tilde, S)
        B_bar = -2 * matrix_dot(X, A_tilde, S, K.T)
        Q_bar = S
        R_bar = matrix_dot(K, S, K.T)

        return [A_bar, B_bar, Q_bar, R_bar]
874+
875+
876+
def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
    """
    Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.

    Parameters
    ----------
    A: ArrayLike
        Square matrix of shape M x M
    B: ArrayLike
        Matrix of shape M x N
    Q: ArrayLike
        Symmetric square matrix of shape M x M
    R: ArrayLike
        Square matrix of shape N x N
    enforce_Q_symmetric: bool
        If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry

    Returns
    -------
    X: pt.matrix
        Square matrix of shape M x M, representing the solution to the DARE
    """

    # The Op call is cast only for the benefit of static type checkers.
    dare_op = SolveDiscreteARE(enforce_Q_symmetric)
    return typing.cast(TensorVariable, dare_op(A, B, Q, R))
827902

828903

829904
__all__ = [
@@ -832,4 +907,8 @@ def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariabl
832907
"eigvalsh",
833908
"kron",
834909
"expm",
910+
"solve_discrete_lyapunov",
911+
"solve_continuous_lyapunov",
912+
"solve_discrete_are",
913+
"solve_triangular",
835914
]

tests/tensor/test_slinalg.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
kron,
2323
solve,
2424
solve_continuous_lyapunov,
25+
solve_discrete_are,
2526
solve_discrete_lyapunov,
2627
solve_triangular,
2728
)
@@ -532,7 +533,7 @@ def test_perform(self):
532533
scipy_val = scipy.linalg.kron(a[np.newaxis, :], b).flatten()
533534
else:
534535
scipy_val = scipy.linalg.kron(a, b)
535-
utt.assert_allclose(out, scipy_val)
536+
np.testing.assert_allclose(out, scipy_val)
536537

537538
def test_numpy_2d(self):
538539
for shp0 in [(2, 3)]:
@@ -564,7 +565,10 @@ def test_solve_discrete_lyapunov_via_direct_real():
564565
utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
565566

566567

568+
@pytest.mark.filterwarnings("ignore::UserWarning")
567569
def test_solve_discrete_lyapunov_via_direct_complex():
570+
# Conj doesn't have C-op; filter the warning.
571+
568572
N = 5
569573
rng = np.random.default_rng(utt.fetch_seed())
570574
a = pt.zmatrix()
@@ -574,7 +578,7 @@ def test_solve_discrete_lyapunov_via_direct_complex():
574578
A = rng.normal(size=(N, N)) + rng.normal(size=(N, N)) * 1j
575579
Q = rng.normal(size=(N, N))
576580
X = f(A, Q)
577-
assert np.allclose(A @ X @ A.conj().T - X + Q, 0.0)
581+
np.testing.assert_array_less(A @ X @ A.conj().T - X + Q, 1e-12)
578582

579583
# TODO: the .conj() method currently does not have a gradient; add this test when gradients are implemented.
580584
# utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
@@ -591,8 +595,8 @@ def test_solve_discrete_lyapunov_via_bilinear():
591595
Q = rng.normal(size=(N, N))
592596

593597
X = f(A, Q)
594-
assert np.allclose(A @ X @ A.conj().T - X + Q, 0.0)
595598

599+
np.testing.assert_array_less(A @ X @ A.conj().T - X + Q, 1e-12)
596600
utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
597601

598602

@@ -607,6 +611,51 @@ def test_solve_continuous_lyapunov():
607611
Q = rng.normal(size=(N, N))
608612
X = f(A, Q)
609613

610-
assert np.allclose(A @ X + X @ A.conj().T, Q)
614+
Q_recovered = A @ X + X @ A.conj().T
611615

616+
np.testing.assert_allclose(Q_recovered.squeeze(), Q)
612617
utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=rng)
618+
619+
620+
def test_solve_discrete_are_forward():
    """Check that the computed X makes the DARE residual vanish."""
    # TEST CASE 4 : darex #1 -- taken from Scipy tests
    raw_inputs = [
        np.array([[4, 3], [-4.5, -3.5]]),
        np.array([[1], [-1]]),
        np.array([[9, 6], [6, 4]]),
        np.array([[1]]),
    ]
    a, b, q, r = (mat.astype(config.floatX) for mat in raw_inputs)

    x = solve_discrete_are(a, b, q, r).eval()

    # Shared sub-products of the Riccati residual.
    xa = x.dot(a)
    xb = x.dot(b)

    # (A^T X B) (R + B^T X B)^{-1} (B^T X A)
    gain_system = r + b.conj().T.dot(xb)
    correction = a.conj().T.dot(xb).dot(np.linalg.solve(gain_system, b.T).dot(xa))

    residual = a.T.dot(xa) - x + q
    residual = residual - correction

    if config.floatX == "float32":
        atol = 1e-4
    else:
        atol = 1e-12
    np.testing.assert_allclose(residual, np.zeros_like(residual), atol=atol)
640+
641+
642+
def test_solve_discrete_are_grad():
    """Numerically verify the symbolic gradient of ``solve_discrete_are``."""
    raw_inputs = [
        np.array([[4, 3], [-4.5, -3.5]]),
        np.array([[1], [-1]]),
        np.array([[9, 6], [6, 4]]),
        np.array([[1]]),
    ]
    test_points = [mat.astype(config.floatX) for mat in raw_inputs]

    # TODO: Is there a "theoretically motivated" value to use here? I pulled 1e-4 out of a hat
    if config.floatX == "float32":
        atol = 1e-4
    else:
        atol = 1e-12

    # Q is symmetrized inside the Op for this check.
    symmetric_dare = functools.partial(solve_discrete_are, enforce_Q_symmetric=True)

    utt.verify_grad(
        symmetric_dare,
        pt=test_points,
        rng=np.random.default_rng(utt.fetch_seed()),
        abs_tol=atol,
    )

0 commit comments

Comments
 (0)