Skip to content

Commit 931d67d

Browse files
Add initial value to Gaussian state space distribution (fixes #2098). (#2104)
1 parent 301117c commit 931d67d

File tree

3 files changed

+71
-8
lines changed

3 files changed

+71
-8
lines changed

numpyro/distributions/constraints.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
"multinomial",
4646
"nonnegative",
4747
"nonnegative_integer",
48+
"ordered_vector",
4849
"positive",
4950
"positive_definite",
5051
"positive_definite_circulant_vector",

numpyro/distributions/continuous.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -709,10 +709,11 @@ class GaussianStateSpace(TransformedDistribution):
709709
710710
.. math::
711711
\mathbf{z}_{t} &= \mathbf{A} \mathbf{z}_{t - 1} + \boldsymbol{\epsilon}_t\\
712-
&=\sum_{k=1} \mathbf{A}^{t-k} \boldsymbol{\epsilon}_t,
712+
&= \mathbf{A}^t \mathbf{z}_0 + \sum_{k=1}^{t} \mathbf{A}^{t-k} \boldsymbol{\epsilon}_k,
713713
714714
where :math:`\mathbf{z}_t` is the state vector at step :math:`t`, :math:`\mathbf{A}`
715-
is the transition matrix, and :math:`\boldsymbol\epsilon` is the innovation noise.
715+
is the transition matrix, :math:`\mathbf{z}_0` is the initial value, and
716+
:math:`\boldsymbol\epsilon` is the innovation noise.
716717
717718
718719
:param num_steps: Number of steps.
@@ -723,15 +724,19 @@ class GaussianStateSpace(TransformedDistribution):
723724
:math:`\boldsymbol\epsilon`.
724725
:param scale_tril: Scale matrix of the innovation noise
725726
:math:`\boldsymbol\epsilon`.
727+
:param initial_value: Initial state vector :math:`\mathbf{z}_0`. If ``None``,
728+
defaults to zero.
726729
"""
727730

728731
arg_constraints = {
729732
"covariance_matrix": constraints.positive_definite,
730733
"precision_matrix": constraints.positive_definite,
731734
"scale_tril": constraints.lower_cholesky,
732735
"transition_matrix": constraints.real_matrix,
736+
"initial_value": constraints.real_vector,
733737
}
734738
support = constraints.real_matrix
739+
pytree_data_fields = ("transition_matrix", "_initial_value", "scale_tril")
735740
pytree_aux_fields = ("num_steps",)
736741

737742
def __init__(
@@ -741,6 +746,7 @@ def __init__(
741746
covariance_matrix: Optional[Array] = None,
742747
precision_matrix: Optional[Array] = None,
743748
scale_tril: Optional[Array] = None,
749+
initial_value: Optional[Array] = None,
744750
*,
745751
validate_args: Optional[bool] = None,
746752
) -> None:
@@ -752,6 +758,7 @@ def __init__(
752758
"`transition_matrix` argument should be a square matrix"
753759
)
754760
self.transition_matrix = transition_matrix
761+
self._initial_value = initial_value
755762
# Expand the covariance/precision/scale matrices to the right number of steps.
756763
args = {
757764
"covariance_matrix": covariance_matrix,
@@ -766,23 +773,51 @@ def __init__(
766773
base_distribution = MultivariateNormal(**args)
767774
self.scale_tril = base_distribution.scale_tril[..., 0, :, :]
768775
base_distribution = base_distribution.to_event(1)
769-
transform = RecursiveLinearTransform(transition_matrix)
776+
777+
# The base distribution must have at least the same batch shape as the initial
778+
# value.
779+
if initial_value is not None:
780+
batch_shape = initial_value.shape[:-1]
781+
base_distribution = base_distribution.expand(batch_shape)
782+
783+
transform = RecursiveLinearTransform(
784+
transition_matrix, initial_value=initial_value
785+
)
770786
super().__init__(base_distribution, transform, validate_args=validate_args)
771787

788+
@property
789+
def initial_value(self) -> Array:
790+
if self._initial_value is None:
791+
return jnp.zeros(self.transition_matrix.shape[-1:])
792+
return self._initial_value
793+
772794
@property
773795
def mean(self) -> ArrayLike:
774-
# The mean of the base distribution is zero and it has the right shape.
775-
return self.base_dist.mean
796+
# If there's no initial value, the mean is zero (base distribution mean).
797+
if self._initial_value is None:
798+
return self.base_dist.mean
799+
800+
# Otherwise, compute A^t @ z_0 for each time step t.
801+
# z_t = A @ z_{t-1} for the deterministic part with z_0 = initial_value
802+
def propagate(z, _):
803+
z_next = jnp.einsum("...ij,...j->...i", self.transition_matrix, z)
804+
return z_next, z_next
805+
806+
_, means = scan(propagate, self.initial_value, jnp.arange(self.num_steps))
807+
# means has shape (num_steps, ..., state_dim)
808+
# We need to move num_steps to axis -2 to match base_dist.mean shape
809+
return jnp.moveaxis(means, 0, -2)
776810

777811
@property
778812
def variance(self) -> ArrayLike:
779-
# Given z_t = \sum_{k=1}^t A^{t-k} \epsilon_t, the covariance of the state
813+
# Given z_t = z_0 + \sum_{k=1}^t A^{t-k} \epsilon_t, the covariance of the state
780814
# vector at step t is E[z_t transpose(z_t)] = \sum_{k,k'}^t A^{t-k}
781815
# E[\epsilon_k transpose(\epsilon_{k'})] transpose(A^{t-k'}). We only have
782816
# contributions for k = k' because innovations at different steps are
783817
# independent such that E[z_t transpose(z_t)] = \sum_k^t A^{t-k} @
784-
# @ covariance_matrix @ transpose(A^{t-k}). Using `scan` is an easy way to
785-
# evaluate this expression.
818+
# @ covariance_matrix @ transpose(A^{t-k}). The initial value is deterministic,
819+
# and we don't need to consider it here. Using `scan` is an easy way to evaluate
820+
# this expression.
786821
_, scale_tril = scan(
787822
lambda carry, _: (self.transition_matrix @ carry, carry),
788823
self.scale_tril,

test/test_distributions.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,33 @@ def get_sp_dist(jax_dist):
631631
np.array([[0.8, 0.2], [-0.1, 1.1]]),
632632
np.array([0.1, 0.3, 0.25])[:, None, None] * np.array([[0.8, 0.2], [0.2, 0.7]]),
633633
),
634+
T(
635+
dist.GaussianStateSpace,
636+
5,
637+
np.array([[0.8, 0.1], [-0.1, 0.9]]),
638+
None,
639+
None,
640+
np.array([[0.5, 0.0], [0.0, 0.5]]),
641+
np.array([1.0, 2.0]),
642+
),
643+
T(
644+
dist.GaussianStateSpace,
645+
5,
646+
np.array([[0.8, 0.1], [-0.1, 0.9]]),
647+
None,
648+
None,
649+
np.array([[0.5, 0.0], [0.0, 0.5]]),
650+
np.array([[1.0, 2.0], [0.5, 1.5], [-1.0, 0.0]]),
651+
),
652+
T(
653+
dist.GaussianStateSpace,
654+
4,
655+
np.array([[0.9, 0.0], [0.0, 0.9]]),
656+
None,
657+
None,
658+
np.array([[0.3, 0.0], [0.0, 0.3]]),
659+
np.array([[[1.0, 0.0]], [[0.0, 1.0]]]),
660+
),
634661
pytest.param(
635662
*T(
636663
dist.GaussianCopulaBeta,

0 commit comments

Comments
 (0)