Skip to content

Commit a91f959

Browse files
committed
Vectorize statespace matrix builder
1 parent c333b69 commit a91f959

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
VECTOR_VALUED,
4747
)
4848
from pymc_extras.statespace.utils.data_tools import register_data_with_pymc
49+
from pytensor.graph.replace import vectorize_graph
4950

5051
_log = logging.getLogger("pymc.experimental.statespace")
5152

@@ -726,7 +727,7 @@ def _insert_random_variables(self):
726727
matrices = list(self._unpack_statespace_with_placeholders())
727728

728729
replacement_dict = {var: pymc_model[name] for name, var in self._name_to_variable.items()}
729-
self.subbed_ssm = graph_replace(matrices, replace=replacement_dict, strict=True)
730+
self.subbed_ssm = vectorize_graph(matrices, replace=replacement_dict)
730731

731732
def _insert_data_variables(self):
732733
"""

tests/statespace/test_statespace.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -710,8 +710,7 @@ def test_invalid_scenarios():
710710
# Giving a list, tuple, or Series when a matrix of data is expected should always raise
711711
with pytest.raises(
712712
ValueError,
713-
match="Scenario data for variable 'a' has the wrong number of columns. "
714-
"Expected 2, got 1",
713+
match="Scenario data for variable 'a' has the wrong number of columns. Expected 2, got 1",
715714
):
716715
for data_type in [list, tuple, pd.Series]:
717716
ss_mod._validate_scenario_data(data_type(np.zeros(10)))
@@ -720,15 +719,14 @@ def test_invalid_scenarios():
720719
# Providing irrevelant data raises
721720
with pytest.raises(
722721
ValueError,
723-
match="Scenario data provided for variable 'jk lol', which is not an exogenous " "variable",
722+
match="Scenario data provided for variable 'jk lol', which is not an exogenous variable",
724723
):
725724
ss_mod._validate_scenario_data({"jk lol": np.zeros(10)})
726725

727726
# Incorrect 2nd dimension of a non-dataframe
728727
with pytest.raises(
729728
ValueError,
730-
match="Scenario data for variable 'a' has the wrong number of columns. Expected "
731-
"2, got 1",
729+
match="Scenario data for variable 'a' has the wrong number of columns. Expected 2, got 1",
732730
):
733731
scenario = np.zeros(10).tolist()
734732
ss_mod._validate_scenario_data(scenario)
@@ -870,3 +868,13 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
870868
regression_effect_expected = (betas * scenario_xr).sum(dim=["state"])
871869

872870
assert_allclose(regression_effect, regression_effect_expected)
871+
872+
873+
@pytest.mark.parametrize("batch_size", [(10,), (10, 3, 5)])
874+
def test_insert_batched_rvs(ss_mod, batch_size):
875+
with pm.Model():
876+
rho = pm.Normal("rho", shape=batch_size)
877+
zeta = pm.Normal("zeta", shape=batch_size)
878+
ss_mod._insert_random_variables()
879+
matrices = ss_mod.unpack_statespace()
880+
assert matrices[4].type.shape == (*batch_size, 2, 2)

0 commit comments

Comments
 (0)