@@ -709,10 +709,11 @@ class GaussianStateSpace(TransformedDistribution):
709709
710710 .. math::
711711 \mathbf{z}_{t} &= \mathbf{A} \mathbf{z}_{t - 1} + \boldsymbol{\epsilon}_t\\
712- &= \sum_{k=1} \mathbf{A}^{t-k} \boldsymbol{\epsilon}_t,
712+ &= \mathbf{A}^t \mathbf{z}_0 + \sum_{k=1}^{t} \mathbf{A}^{t-k} \boldsymbol{\epsilon}_k,
713713
714714 where :math:`\mathbf{z}_t` is the state vector at step :math:`t`, :math:`\mathbf{A}`
715- is the transition matrix, and :math:`\boldsymbol\epsilon` is the innovation noise.
715+ is the transition matrix, :math:`\mathbf{z}_0` is the initial value, and
716+ :math:`\boldsymbol\epsilon` is the innovation noise.
716717
717718
718719 :param num_steps: Number of steps.
@@ -723,15 +724,19 @@ class GaussianStateSpace(TransformedDistribution):
723724 :math:`\boldsymbol\epsilon`.
724725 :param scale_tril: Scale matrix of the innovation noise
725726 :math:`\boldsymbol\epsilon`.
727+ :param initial_value: Initial state vector :math:`\mathbf{z}_0`. If ``None``,
728+ defaults to zero.
726729 """
727730
728731 arg_constraints = {
729732 "covariance_matrix" : constraints .positive_definite ,
730733 "precision_matrix" : constraints .positive_definite ,
731734 "scale_tril" : constraints .lower_cholesky ,
732735 "transition_matrix" : constraints .real_matrix ,
736+ "initial_value" : constraints .real_vector ,
733737 }
734738 support = constraints .real_matrix
739+ pytree_data_fields = ("transition_matrix" , "_initial_value" , "scale_tril" )
735740 pytree_aux_fields = ("num_steps" ,)
736741
737742 def __init__ (
@@ -741,6 +746,7 @@ def __init__(
741746 covariance_matrix : Optional [Array ] = None ,
742747 precision_matrix : Optional [Array ] = None ,
743748 scale_tril : Optional [Array ] = None ,
749+ initial_value : Optional [Array ] = None ,
744750 * ,
745751 validate_args : Optional [bool ] = None ,
746752 ) -> None :
@@ -752,6 +758,7 @@ def __init__(
752758 "`transition_matrix` argument should be a square matrix"
753759 )
754760 self .transition_matrix = transition_matrix
761+ self ._initial_value = initial_value
755762 # Expand the covariance/precision/scale matrices to the right number of steps.
756763 args = {
757764 "covariance_matrix" : covariance_matrix ,
@@ -766,23 +773,51 @@ def __init__(
766773 base_distribution = MultivariateNormal (** args )
767774 self .scale_tril = base_distribution .scale_tril [..., 0 , :, :]
768775 base_distribution = base_distribution .to_event (1 )
769- transform = RecursiveLinearTransform (transition_matrix )
776+
777+ # The base distribution must have at least the same batch shape as the initial
778+ # value.
779+ if initial_value is not None :
780+ batch_shape = initial_value .shape [:- 1 ]
781+ base_distribution = base_distribution .expand (batch_shape )
782+
783+ transform = RecursiveLinearTransform (
784+ transition_matrix , initial_value = initial_value
785+ )
770786 super ().__init__ (base_distribution , transform , validate_args = validate_args )
771787
788+ @property
789+ def initial_value (self ) -> Array :
790+ if self ._initial_value is None :
791+ return jnp .zeros (self .transition_matrix .shape [- 1 :])
792+ return self ._initial_value
793+
772794 @property
773795 def mean (self ) -> ArrayLike :
774- # The mean of the base distribution is zero and it has the right shape.
775- return self .base_dist .mean
796+ # If there's no initial value, the mean is zero (base distribution mean).
797+ if self ._initial_value is None :
798+ return self .base_dist .mean
799+
800+ # Otherwise, compute A^t @ z_0 for each time step t.
801+ # z_t = A @ z_{t-1} for the deterministic part with z_0 = initial_value
802+ def propagate (z , _ ):
803+ z_next = jnp .einsum ("...ij,...j->...i" , self .transition_matrix , z )
804+ return z_next , z_next
805+
806+ _ , means = scan (propagate , self .initial_value , jnp .arange (self .num_steps ))
807+ # means has shape (num_steps, ..., state_dim)
808+ # We need to move num_steps to axis -2 to match base_dist.mean shape
809+ return jnp .moveaxis (means , 0 , - 2 )
776810
777811 @property
778812 def variance (self ) -> ArrayLike :
779- # Given z_t = \sum_{k=1}^t A^{t-k} \epsilon_t, the covariance of the state
813+ # Given z_t = A^t z_0 + \sum_{k=1}^t A^{t-k} \epsilon_k, the covariance of the state
780814 # vector at step t is E[z_t transpose(z_t)] = \sum_{k,k'}^t A^{t-k}
781815 # E[\epsilon_k transpose(\epsilon_{k'})] transpose(A^{t-k'}). We only have
782816 # contributions for k = k' because innovations at different steps are
783817 # independent such that E[z_t transpose(z_t)] = \sum_k^t A^{t-k} @
784- # @ covariance_matrix @ transpose(A^{t-k}). Using `scan` is an easy way to
785- # evaluate this expression.
818+ # covariance_matrix @ transpose(A^{t-k}). The initial value is deterministic, so
819+ # it shifts the mean but contributes nothing to the covariance. Using `scan` is
820+ # an easy way to evaluate this expression.
786821 _ , scale_tril = scan (
787822 lambda carry , _ : (self .transition_matrix @ carry , carry ),
788823 self .scale_tril ,
0 commit comments