Skip to content

Commit 843b3a9

Browse files
committed
Vectorized full state space model
1 parent 9b9418d commit 843b3a9

File tree

5 files changed

+79
-6
lines changed

5 files changed

+79
-6
lines changed

pymc_extras/statespace/core/representation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ class PytensorRepresentation:
6060
6161
.. math::
6262
\begin{align}
63-
x_t &= A_t x_{t-1} + c_t + R_t \varepsilon_t \tag{1} \\
63+
x_t &= T_t x_{t-1} + c_t + R_t \varepsilon_t \tag{1} \\
6464
y_t &= Z_t x_t + d_t + \eta_t \tag{2} \\
6565
\end{align}
6666
6767
Where :math:`\{x_t\}_{t=0}^T` is a trajectory of hidden states, and :math:`\{y_t\}_{t=0}^T` is a trajectory of
68-
observable states. Equation 1 is known as the "state transition equation", while describes how the system evolves
68+
observable states. Equation 1 is known as the "state transition equation", which describes how the system evolves
6969
over time. Equation 2 is the "observation equation", and maps the latent state processes to observed data.
7070
The system is Gaussian when the innovations, :math:`\varepsilon_t`, and the measurement errors, :math:`\eta_t`,
7171
are normally distributed. The definition is completed by specification of these distributions, as

pymc_extras/statespace/core/statespace.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from pymc.model.transform.optimization import freeze_dims_and_data
1616
from pymc.util import RandomState
1717
from pytensor import Variable, graph_replace
18+
from pytensor.compile import get_mode
19+
from pytensor.graph.replace import vectorize_graph
1820
from rich.box import SIMPLE_HEAD
1921
from rich.console import Console
2022
from rich.table import Table
@@ -37,6 +39,7 @@
3739
FILTER_OUTPUT_DIMS,
3840
FILTER_OUTPUT_TYPES,
3941
JITTER_DEFAULT,
42+
LONG_MATRIX_NAMES,
4043
MATRIX_DIMS,
4144
MATRIX_NAMES,
4245
OBS_STATE_DIM,
@@ -46,7 +49,6 @@
4649
VECTOR_VALUED,
4750
)
4851
from pymc_extras.statespace.utils.data_tools import register_data_with_pymc
49-
from pytensor.graph.replace import vectorize_graph
5052

5153
_log = logging.getLogger("pymc.experimental.statespace")
5254

@@ -61,7 +63,7 @@
6163
def _validate_filter_arg(filter_arg):
6264
if filter_arg.lower() not in FILTER_OUTPUT_TYPES:
6365
raise ValueError(
64-
f'filter_output should be one of {", ".join(FILTER_OUTPUT_TYPES)}, received {filter_arg}'
66+
f"filter_output should be one of {', '.join(FILTER_OUTPUT_TYPES)}, received {filter_arg}"
6567
)
6668

6769

@@ -736,6 +738,8 @@ def _insert_random_variables(self):
736738

737739
replacement_dict = {var: pymc_model[name] for name, var in self._name_to_variable.items()}
738740
self.subbed_ssm = vectorize_graph(matrices, replace=replacement_dict)
741+
for name, matrix in zip(LONG_MATRIX_NAMES, self.subbed_ssm):
742+
matrix.name = name
739743

740744
def _insert_data_variables(self):
741745
"""

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
split_vars_into_seq_and_nonseq,
1717
stabilize,
1818
)
19-
from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL, ALL_KF_OUTPUT_NAMES
19+
from pymc_extras.statespace.utils.constants import ALL_KF_OUTPUT_NAMES, JITTER_DEFAULT, MISSING_FILL
2020

2121
MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))
2222
PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"]
@@ -75,14 +75,23 @@ def _make_gufunc_signature(self, inputs):
7575
"data": (time, obs),
7676
"a0": (states,),
7777
"x0": (states,),
78+
"initial_state": (states,),
7879
"P0": (states, states),
80+
"initial_state_cov": (states, states),
7981
"c": (states,),
82+
"state_intercept": (states,),
8083
"d": (obs,),
84+
"obs_intercept": (obs,),
8185
"T": (states, states),
86+
"transition": (states, states),
8287
"Z": (obs, states),
88+
"design": (obs, states),
8389
"R": (states, exog),
90+
"selection": (states, exog),
8491
"H": (obs, obs),
92+
"obs_cov": (obs, obs),
8593
"Q": (exog, exog),
94+
"state_cov": (exog, exog),
8695
"filtered_states": (time, states),
8796
"filtered_covariances": (time, states, states),
8897
"predicted_states": (time, states),
@@ -306,6 +315,7 @@ def build_graph(
306315
cov_jitter=cov_jitter,
307316
)
308317
filter_outputs = pt.vectorize(fn, signature=signature)(data, a0, P0, c, d, T, Z, R, H, Q)
318+
# filter_outputs = fn(data, a0, P0, c, d, T, Z, R, H, Q)
309319
for output, name in zip(filter_outputs, ALL_KF_OUTPUT_NAMES):
310320
output.name = name
311321

pymc_extras/statespace/filters/kalman_smoother.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from functools import partial
2+
13
import pytensor
24
import pytensor.tensor as pt
3-
from functools import partial
5+
46
from pytensor.compile import get_mode
57
from pytensor.tensor.nlinalg import matrix_dot
8+
69
from pymc_extras.statespace.filters.utilities import (
710
quad_form_sym,
811
split_vars_into_seq_and_nonseq,
@@ -73,14 +76,23 @@ def _make_gufunc_signature(self, inputs):
7376
"data": (time, obs),
7477
"a0": (states,),
7578
"x0": (states,),
79+
"initial_state": (states,),
7680
"P0": (states, states),
81+
"initial_state_cov": (states, states),
7782
"c": (states,),
83+
"state_intercept": (states,),
7884
"d": (obs,),
85+
"obs_intercept": (obs,),
7986
"T": (states, states),
87+
"transition": (states, states),
8088
"Z": (obs, states),
89+
"design": (obs, states),
8190
"R": (states, exog),
91+
"selection": (states, exog),
8292
"H": (obs, obs),
93+
"obs_cov": (obs, obs),
8394
"Q": (exog, exog),
95+
"state_cov": (exog, exog),
8496
"filtered_states": (time, states),
8597
"filtered_covariances": (time, states, states),
8698
"predicted_states": (time, states),
@@ -163,6 +175,7 @@ def build_graph(
163175
cov_jitter=cov_jitter,
164176
)
165177
return pt.vectorize(fn, signature=signature)(T, R, Q, filtered_states, filtered_covariances)
178+
# return fn(T, R, Q, filtered_states, filtered_covariances)
166179

167180
def smoother_step(self, *args):
168181
a, P, a_smooth, P_smooth, T, R, Q = self.unpack_args(args)

tests/statespace/core/test_statespace.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from pytensor.graph.basic import graph_inputs
1515

1616
from pymc_extras.statespace.core.statespace import FILTER_FACTORY, PyMCStateSpace
17+
from pymc_extras.statespace.filters.kalman_filter import StandardFilter
18+
from pymc_extras.statespace.filters.kalman_smoother import KalmanSmoother
1719
from pymc_extras.statespace.models import structural as st
1820
from pymc_extras.statespace.models.utilities import make_default_coords
1921
from pymc_extras.statespace.utils.constants import (
@@ -1025,3 +1027,47 @@ def test_insert_batched_rvs(ss_mod, batch_size):
10251027
ss_mod._insert_random_variables()
10261028
matrices = ss_mod.unpack_statespace()
10271029
assert matrices[4].type.shape == (*batch_size, 2, 2)
1030+
1031+
1032+
@pytest.mark.parametrize("batch_size", [(10,), (10, 3, 5)])
1033+
def test_insert_batched_rvs_in_kf(ss_mod, batch_size):
1034+
data = pt.as_tensor(np.random.normal(size=(*batch_size, 7, 1)).astype(floatX))
1035+
data.name = "data"
1036+
kf = StandardFilter()
1037+
1038+
with pm.Model():
1039+
rho = pm.Normal("rho", shape=batch_size)
1040+
zeta = pm.Normal("zeta", shape=batch_size)
1041+
ss_mod._insert_random_variables()
1042+
1043+
matrices = x0, P0, c, d, T, Z, R, H, Q = ss_mod.unpack_statespace()
1044+
outputs = kf.build_graph(data, *matrices)
1045+
1046+
logp = outputs.pop(-1)
1047+
states, covs = outputs[:3], outputs[3:]
1048+
filtered_states, predicted_states, observed_states = states
1049+
filtered_covariances, predicted_covariances, observed_covariances = covs
1050+
1051+
assert logp.type.shape == (*batch_size, 7)
1052+
assert filtered_states.type.shape == (*batch_size, 7, 2)
1053+
assert predicted_states.type.shape == (*batch_size, 7, 2)
1054+
assert observed_states.type.shape == (*batch_size, 7, 1)
1055+
assert filtered_covariances.type.shape == (*batch_size, 7, 2, 2)
1056+
assert predicted_covariances.type.shape == (*batch_size, 7, 2, 2)
1057+
assert observed_covariances.type.shape == (*batch_size, 7, 1, 1)
1058+
1059+
ks = KalmanSmoother()
1060+
smoothed_states, smoothed_covariances = ks.build_graph(
1061+
T, R, Q, filtered_states, filtered_covariances
1062+
)
1063+
assert smoothed_states.type.shape == (
1064+
*batch_size,
1065+
None,
1066+
2,
1067+
) # TODO: why do we lose the time dimension here?
1068+
assert smoothed_covariances.type.shape == (
1069+
*batch_size,
1070+
None,
1071+
2,
1072+
2,
1073+
) # TODO: why do we lose the time dimension here?

0 commit comments

Comments
 (0)