|
1 | 1 | import logging
|
2 | 2 | import warnings
|
3 |
| -from typing import Union |
| 3 | +from typing import TYPE_CHECKING, Union |
4 | 4 |
|
5 | 5 | import numpy as np
|
6 | 6 | import scipy.linalg
|
| 7 | +from typing_extensions import Literal |
7 | 8 |
|
8 |
| -import pytensor.tensor |
| 9 | +import pytensor |
| 10 | +import pytensor.tensor as pt |
9 | 11 | from pytensor.graph.basic import Apply
|
10 | 12 | from pytensor.graph.op import Op
|
11 | 13 | from pytensor.tensor import as_tensor_variable
|
12 | 14 | from pytensor.tensor import basic as at
|
13 | 15 | from pytensor.tensor import math as atm
|
| 16 | +from pytensor.tensor.shape import reshape |
14 | 17 | from pytensor.tensor.type import matrix, tensor, vector
|
15 | 18 | from pytensor.tensor.var import TensorVariable
|
16 | 19 |
|
17 | 20 |
|
| 21 | +if TYPE_CHECKING: |
| 22 | + from pytensor.tensor import TensorLike |
| 23 | + |
| 24 | + |
18 | 25 | logger = logging.getLogger(__name__)
|
19 | 26 |
|
20 | 27 |
|
@@ -735,6 +742,159 @@ def perform(self, node, inputs, outputs):
|
735 | 742 |
|
736 | 743 | expm = Expm()
|
737 | 744 |
|
| 745 | + |
class SolveContinuousLyapunov(Op):
    """Solve the continuous Lyapunov equation ``A X + X A^H = B`` via SciPy.

    Wraps ``scipy.linalg.solve_continuous_lyapunov``. The gradient follows
    Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf, adjusted
    for scipy's sign convention (the paper writes ``A X + X A^H + Q = 0``).
    """

    __props__ = ()

    def make_node(self, A, B):
        # Promote both inputs to tensor variables; the output matrix carries
        # the upcast of the two input dtypes.
        A = as_tensor_variable(A)
        B = as_tensor_variable(B)
        out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
        return pytensor.graph.basic.Apply(
            self, [A, B], [pytensor.tensor.matrix(dtype=out_dtype)]
        )

    def perform(self, node, inputs, output_storage):
        A, B = inputs
        output_storage[0][0] = scipy.linalg.solve_continuous_lyapunov(A, B)

    def infer_shape(self, fgraph, node, shapes):
        # The solution X has the same (N x N) shape as A.
        return [shapes[0]]

    def grad(self, inputs, output_grads):
        # Gradients from Kao and Hennequin (2020). They write the equation as
        # A X + X A.H + Q = 0 while scipy uses A X + X A^H = Q, so the signs
        # below are adjusted accordingly.
        A, Q = inputs
        (dX,) = output_grads

        X = self(A, Q)
        S = self(A.conj().T, -dX)  # Eq 31, adjusted

        A_bar = S.dot(X.conj().T) + S.conj().T.dot(X)
        return [A_bar, -S]  # Q_bar = -S (Eq 29, adjusted)
| 781 | + |
| 782 | + |
class BilinearSolveDiscreteLyapunov(Op):
    """Solve the discrete Lyapunov equation with scipy's bilinear method.

    Wraps ``scipy.linalg.solve_discrete_lyapunov(..., method="bilinear")``
    (see the scipy docs for the exact sign convention). Gradients follow
    Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf.
    """

    # Empty __props__ so that parameterless instances of this Op compare
    # equal and hash identically (enabling graph merging), consistent with
    # SolveContinuousLyapunov.
    __props__ = ()

    def make_node(self, A, B):
        # Promote both inputs to tensor variables; the output matrix carries
        # the upcast of the two input dtypes.
        A = as_tensor_variable(A)
        B = as_tensor_variable(B)

        out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
        X = pytensor.tensor.matrix(dtype=out_dtype)

        return pytensor.graph.basic.Apply(self, [A, B], [X])

    def perform(self, node, inputs, output_storage):
        (A, B) = inputs
        X = output_storage[0]

        X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear")

    def infer_shape(self, fgraph, node, shapes):
        # The solution X has the same (N x N) shape as A.
        return [shapes[0]]

    def grad(self, inputs, output_grads):
        # Gradient computations come from Kao and Hennequin (2020),
        # https://arxiv.org/pdf/2011.11430.pdf
        A, Q = inputs
        (dX,) = output_grads

        X = self(A, Q)

        # Eq 41, note that it is not written as a proper Lyapunov equation
        S = self(A.conj().T, dX)

        A_bar = pytensor.tensor.linalg.matrix_dot(
            S, A, X.conj().T
        ) + pytensor.tensor.linalg.matrix_dot(S.conj().T, A, X)
        Q_bar = S
        return [A_bar, Q_bar]
| 817 | + |
| 818 | + |
# Module-level singleton Op instances; callers should use the public
# solve_*_lyapunov helper functions rather than these directly.
_solve_continuous_lyapunov = SolveContinuousLyapunov()
_solve_bilinear_direct_lyapunov = BilinearSolveDiscreteLyapunov()
| 821 | + |
| 822 | + |
def iscomplexobj(x):
    """Return True when the tensor variable `x` has a complex dtype."""
    return "complex" in x.type.dtype
| 827 | + |
| 828 | + |
def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
    """Solve the discrete Lyapunov equation directly by vectorization.

    Uses the identity ``vec(A X A^H) = (A kron conj(A)) vec(X)`` to turn the
    Lyapunov equation into an ``N^2 x N^2`` linear system in ``vec(X)``. Built
    entirely from PyTensor ops, so it compiles to any supported backend, but
    it scales poorly with ``N``.
    """
    A_ = as_tensor_variable(A)
    Q_ = as_tensor_variable(Q)

    # Use the iscomplexobj helper (defined above) instead of re-implementing
    # the dtype check inline; only complex inputs need the conjugated
    # Kronecker factor.
    if iscomplexobj(A_):
        AA = kron(A_, A_.conj())
    else:
        AA = kron(A_, A_)

    X = solve(pt.eye(AA.shape[0]) - AA, Q_.ravel())
    return reshape(X, Q_.shape)
| 840 | + |
| 841 | + |
def solve_discrete_lyapunov(
    A: "TensorLike", Q: "TensorLike", method: Literal["direct", "bilinear"] = "direct"
) -> TensorVariable:
    """Solve the discrete Lyapunov equation :math:`A X A^H - X = Q`.

    Parameters
    ----------
    A
        Square matrix of shape N x N; must have the same shape as Q
    Q
        Square matrix of shape N x N; must have the same shape as A
    method
        Solver to use, either ``"direct"`` or ``"bilinear"``. The direct
        solver works via matrix inversion in pure PyTensor, so it can be
        cross-compiled to supported backends; prefer it when ``N`` is small.
        It scales poorly with ``N``, in which case the bilinear (SciPy-backed)
        method can be used instead.

    Returns
    -------
    Square matrix of shape ``N x N``, representing the solution to the
    Lyapunov equation

    Raises
    ------
    ValueError
        If `method` is not one of the two supported solvers.
    """
    if method == "direct":
        return _direct_solve_discrete_lyapunov(A, Q)
    if method == "bilinear":
        return _solve_bilinear_direct_lyapunov(A, Q)

    raise ValueError(
        f'Parameter "method" must be one of "direct" or "bilinear", found {method}'
    )
| 876 | + |
| 877 | + |
def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
    """Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.

    Parameters
    ----------
    A
        Square matrix of shape ``N x N``; must have the same shape as `Q`.
    Q
        Square matrix of shape ``N x N``; must have the same shape as `A`.

    Returns
    -------
    Square matrix of shape ``N x N``, representing the solution to the
    Lyapunov equation

    """
    # Delegate to the module-level SolveContinuousLyapunov Op instance.
    return _solve_continuous_lyapunov(A, Q)
| 896 | + |
| 897 | + |
738 | 898 | __all__ = [
|
739 | 899 | "cholesky",
|
740 | 900 | "solve",
|
|
0 commit comments