Skip to content

Commit 50a9748

Browse files
committed
Vectorized full state space model
1 parent a91f959 commit 50a9748

File tree

5 files changed

+78
-6
lines changed

5 files changed

+78
-6
lines changed

pymc_extras/statespace/core/representation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ class PytensorRepresentation:
6060
6161
.. math::
6262
\begin{align}
63-
x_t &= A_t x_{t-1} + c_t + R_t \varepsilon_t \tag{1} \\
63+
x_t &= T_t x_{t-1} + c_t + R_t \varepsilon_t \tag{1} \\
6464
y_t &= Z_t x_t + d_t + \eta_t \tag{2} \\
6565
\end{align}
6666
6767
Where :math:`\{x_t\}_{t=0}^T` is a trajectory of hidden states, and :math:`\{y_t\}_{t=0}^T` is a trajectory of
68-
observable states. Equation 1 is known as the "state transition equation", while describes how the system evolves
68+
observable states. Equation 1 is known as the "state transition equation", which describes how the system evolves
6969
over time. Equation 2 is the "observation equation", and maps the latent state processes to observed data.
7070
The system is Gaussian when the innovations, :math:`\varepsilon_t`, and the measurement errors, :math:`\eta_t`,
7171
are normally distributed. The definition is completed by specification of these distributions, as

pymc_extras/statespace/core/statespace.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pymc.util import RandomState
1616
from pytensor import Variable, graph_replace
1717
from pytensor.compile import get_mode
18+
from pytensor.graph.replace import vectorize_graph
1819
from rich.box import SIMPLE_HEAD
1920
from rich.console import Console
2021
from rich.table import Table
@@ -37,6 +38,7 @@
3738
FILTER_OUTPUT_DIMS,
3839
FILTER_OUTPUT_TYPES,
3940
JITTER_DEFAULT,
41+
LONG_MATRIX_NAMES,
4042
MATRIX_DIMS,
4143
MATRIX_NAMES,
4244
OBS_STATE_DIM,
@@ -46,7 +48,6 @@
4648
VECTOR_VALUED,
4749
)
4850
from pymc_extras.statespace.utils.data_tools import register_data_with_pymc
49-
from pytensor.graph.replace import vectorize_graph
5051

5152
_log = logging.getLogger("pymc.experimental.statespace")
5253

@@ -61,7 +62,7 @@
6162
def _validate_filter_arg(filter_arg):
6263
if filter_arg.lower() not in FILTER_OUTPUT_TYPES:
6364
raise ValueError(
64-
f'filter_output should be one of {", ".join(FILTER_OUTPUT_TYPES)}, received {filter_arg}'
65+
f"filter_output should be one of {', '.join(FILTER_OUTPUT_TYPES)}, received {filter_arg}"
6566
)
6667

6768

@@ -728,6 +729,8 @@ def _insert_random_variables(self):
728729

729730
replacement_dict = {var: pymc_model[name] for name, var in self._name_to_variable.items()}
730731
self.subbed_ssm = vectorize_graph(matrices, replace=replacement_dict)
732+
for name, matrix in zip(LONG_MATRIX_NAMES, self.subbed_ssm):
733+
matrix.name = name
731734

732735
def _insert_data_variables(self):
733736
"""

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
split_vars_into_seq_and_nonseq,
1818
stabilize,
1919
)
20-
from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL, ALL_KF_OUTPUT_NAMES
20+
from pymc_extras.statespace.utils.constants import ALL_KF_OUTPUT_NAMES, JITTER_DEFAULT, MISSING_FILL
2121

2222
MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))
2323
PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"]
@@ -85,14 +85,23 @@ def _make_gufunc_signature(self, inputs):
8585
"data": (time, obs),
8686
"a0": (states,),
8787
"x0": (states,),
88+
"initial_state": (states,),
8889
"P0": (states, states),
90+
"initial_state_cov": (states, states),
8991
"c": (states,),
92+
"state_intercept": (states,),
9093
"d": (obs,),
94+
"obs_intercept": (obs,),
9195
"T": (states, states),
96+
"transition": (states, states),
9297
"Z": (obs, states),
98+
"design": (obs, states),
9399
"R": (states, exog),
100+
"selection": (states, exog),
94101
"H": (obs, obs),
102+
"obs_cov": (obs, obs),
95103
"Q": (exog, exog),
104+
"state_cov": (exog, exog),
96105
"filtered_states": (time, states),
97106
"filtered_covariances": (time, states, states),
98107
"predicted_states": (time, states),
@@ -322,6 +331,7 @@ def build_graph(
322331
cov_jitter=cov_jitter,
323332
)
324333
filter_outputs = pt.vectorize(fn, signature=signature)(data, a0, P0, c, d, T, Z, R, H, Q)
334+
# filter_outputs = fn(data, a0, P0, c, d, T, Z, R, H, Q)
325335
for output, name in zip(filter_outputs, ALL_KF_OUTPUT_NAMES):
326336
output.name = name
327337

pymc_extras/statespace/filters/kalman_smoother.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from functools import partial
2+
13
import pytensor
24
import pytensor.tensor as pt
3-
from functools import partial
5+
46
from pytensor.compile import get_mode
57
from pytensor.tensor.nlinalg import matrix_dot
8+
69
from pymc_extras.statespace.filters.utilities import (
710
quad_form_sym,
811
split_vars_into_seq_and_nonseq,
@@ -74,14 +77,23 @@ def _make_gufunc_signature(self, inputs):
7477
"data": (time, obs),
7578
"a0": (states,),
7679
"x0": (states,),
80+
"initial_state": (states,),
7781
"P0": (states, states),
82+
"initial_state_cov": (states, states),
7883
"c": (states,),
84+
"state_intercept": (states,),
7985
"d": (obs,),
86+
"obs_intercept": (obs,),
8087
"T": (states, states),
88+
"transition": (states, states),
8189
"Z": (obs, states),
90+
"design": (obs, states),
8291
"R": (states, exog),
92+
"selection": (states, exog),
8393
"H": (obs, obs),
94+
"obs_cov": (obs, obs),
8495
"Q": (exog, exog),
96+
"state_cov": (exog, exog),
8597
"filtered_states": (time, states),
8698
"filtered_covariances": (time, states, states),
8799
"predicted_states": (time, states),
@@ -166,6 +178,7 @@ def build_graph(
166178
cov_jitter=cov_jitter,
167179
)
168180
return pt.vectorize(fn, signature=signature)(T, R, Q, filtered_states, filtered_covariances)
181+
# return fn(T, R, Q, filtered_states, filtered_covariances)
169182

170183
def smoother_step(self, *args):
171184
a, P, a_smooth, P_smooth, T, R, Q = self.unpack_args(args)

tests/statespace/test_statespace.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from numpy.testing import assert_allclose
1212

1313
from pymc_extras.statespace.core.statespace import FILTER_FACTORY, PyMCStateSpace
14+
from pymc_extras.statespace.filters.kalman_filter import StandardFilter
15+
from pymc_extras.statespace.filters.kalman_smoother import KalmanSmoother
1416
from pymc_extras.statespace.models import structural as st
1517
from pymc_extras.statespace.models.utilities import make_default_coords
1618
from pymc_extras.statespace.utils.constants import (
@@ -878,3 +880,47 @@ def test_insert_batched_rvs(ss_mod, batch_size):
878880
ss_mod._insert_random_variables()
879881
matrices = ss_mod.unpack_statespace()
880882
assert matrices[4].type.shape == (*batch_size, 2, 2)
883+
884+
885+
@pytest.mark.parametrize("batch_size", [(10,), (10, 3, 5)])
886+
def test_insert_batched_rvs_in_kf(ss_mod, batch_size):
887+
data = pt.as_tensor(np.random.normal(size=(*batch_size, 7, 1)).astype(floatX))
888+
data.name = "data"
889+
kf = StandardFilter()
890+
891+
with pm.Model():
892+
rho = pm.Normal("rho", shape=batch_size)
893+
zeta = pm.Normal("zeta", shape=batch_size)
894+
ss_mod._insert_random_variables()
895+
896+
matrices = x0, P0, c, d, T, Z, R, H, Q = ss_mod.unpack_statespace()
897+
outputs = kf.build_graph(data, *matrices)
898+
899+
logp = outputs.pop(-1)
900+
states, covs = outputs[:3], outputs[3:]
901+
filtered_states, predicted_states, observed_states = states
902+
filtered_covariances, predicted_covariances, observed_covariances = covs
903+
904+
assert logp.type.shape == (*batch_size, 7)
905+
assert filtered_states.type.shape == (*batch_size, 7, 2)
906+
assert predicted_states.type.shape == (*batch_size, 7, 2)
907+
assert observed_states.type.shape == (*batch_size, 7, 1)
908+
assert filtered_covariances.type.shape == (*batch_size, 7, 2, 2)
909+
assert predicted_covariances.type.shape == (*batch_size, 7, 2, 2)
910+
assert observed_covariances.type.shape == (*batch_size, 7, 1, 1)
911+
912+
ks = KalmanSmoother()
913+
smoothed_states, smoothed_covariances = ks.build_graph(
914+
T, R, Q, filtered_states, filtered_covariances
915+
)
916+
assert smoothed_states.type.shape == (
917+
*batch_size,
918+
None,
919+
2,
920+
) # TODO: why do we lose the time dimension here?
921+
assert smoothed_covariances.type.shape == (
922+
*batch_size,
923+
None,
924+
2,
925+
2,
926+
) # TODO: why do we lose the time dimension here?

0 commit comments

Comments
 (0)