Skip to content

Commit 9b9418d

Browse files
committed
Vectorize statespace matrix builder
1 parent 51d5f5a commit 9b9418d

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
VECTOR_VALUED,
4747
)
4848
from pymc_extras.statespace.utils.data_tools import register_data_with_pymc
49+
from pytensor.graph.replace import vectorize_graph
4950

5051
_log = logging.getLogger("pymc.experimental.statespace")
5152

@@ -734,7 +735,7 @@ def _insert_random_variables(self):
734735
matrices = list(self._unpack_statespace_with_placeholders())
735736

736737
replacement_dict = {var: pymc_model[name] for name, var in self._name_to_variable.items()}
737-
self.subbed_ssm = graph_replace(matrices, replace=replacement_dict, strict=True)
738+
self.subbed_ssm = vectorize_graph(matrices, replace=replacement_dict)
738739

739740
def _insert_data_variables(self):
740741
"""

tests/statespace/core/test_statespace.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -733,8 +733,7 @@ def test_invalid_scenarios():
733733
# Giving a list, tuple, or Series when a matrix of data is expected should always raise
734734
with pytest.raises(
735735
ValueError,
736-
match="Scenario data for variable 'a' has the wrong number of columns. "
737-
"Expected 2, got 1",
736+
match="Scenario data for variable 'a' has the wrong number of columns. Expected 2, got 1",
738737
):
739738
for data_type in [list, tuple, pd.Series]:
740739
ss_mod._validate_scenario_data(data_type(np.zeros(10)))
@@ -743,15 +742,14 @@ def test_invalid_scenarios():
743742
# Providing irrevelant data raises
744743
with pytest.raises(
745744
ValueError,
746-
match="Scenario data provided for variable 'jk lol', which is not an exogenous " "variable",
745+
match="Scenario data provided for variable 'jk lol', which is not an exogenous variable",
747746
):
748747
ss_mod._validate_scenario_data({"jk lol": np.zeros(10)})
749748

750749
# Incorrect 2nd dimension of a non-dataframe
751750
with pytest.raises(
752751
ValueError,
753-
match="Scenario data for variable 'a' has the wrong number of columns. Expected "
754-
"2, got 1",
752+
match="Scenario data for variable 'a' has the wrong number of columns. Expected 2, got 1",
755753
):
756754
scenario = np.zeros(10).tolist()
757755
ss_mod._validate_scenario_data(scenario)
@@ -1017,3 +1015,13 @@ def test_foreacast_valid_index(exog_pymc_mod, exog_ss_mod, exog_data):
10171015

10181016
assert forecasts.forecast_latent.shape[2] == n_periods
10191017
assert forecasts.forecast_observed.shape[2] == n_periods
1018+
1019+
1020+
@pytest.mark.parametrize("batch_size", [(10,), (10, 3, 5)])
1021+
def test_insert_batched_rvs(ss_mod, batch_size):
1022+
with pm.Model():
1023+
rho = pm.Normal("rho", shape=batch_size)
1024+
zeta = pm.Normal("zeta", shape=batch_size)
1025+
ss_mod._insert_random_variables()
1026+
matrices = ss_mod.unpack_statespace()
1027+
assert matrices[4].type.shape == (*batch_size, 2, 2)

0 commit comments

Comments
 (0)