diff --git a/pymc_extras/statespace/models/structural/components/seasonality.py b/pymc_extras/statespace/models/structural/components/seasonality.py index 341bd6ae5..05e7d41a1 100644 --- a/pymc_extras/statespace/models/structural/components/seasonality.py +++ b/pymc_extras/statespace/models/structural/components/seasonality.py @@ -212,7 +212,7 @@ class TimeSeasonality(Component): sigma_level_trend = pm.HalfNormal( "sigma_level_trend", sigma=1e-6, dims=ss_mod.param_dims["sigma_level_trend"] ) - coefs_annual = pm.Normal("coefs_annual", sigma=1e-2, dims=ss_mod.param_dims["coefs_annual"]) + params_annual = pm.Normal("params_annual", sigma=1e-2, dims=ss_mod.param_dims["params_annual"]) ss_mod.build_statespace_graph(data) idata = pm.sample( @@ -298,10 +298,10 @@ def populate_component_properties(self): for endog_name in self.observed_state_names for state_name in self.provided_state_names ] - self.param_names = [f"coefs_{self.name}"] + self.param_names = [f"params_{self.name}"] self.param_info = { - f"coefs_{self.name}": { + f"params_{self.name}": { "shape": (k_states,) if k_endog == 1 else (k_endog, k_states), "constraints": None, "dims": (f"state_{self.name}",) @@ -311,7 +311,7 @@ def populate_component_properties(self): } self.param_dims = { - f"coefs_{self.name}": (f"state_{self.name}",) + f"params_{self.name}": (f"state_{self.name}",) if k_endog == 1 else (f"endog_{self.name}", f"state_{self.name}") } @@ -327,12 +327,14 @@ def populate_component_properties(self): if self.innovations: self.param_names += [f"sigma_{self.name}"] + self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names] self.param_info[f"sigma_{self.name}"] = { - "shape": (), + "shape": () if k_endog == 1 else (k_endog,), "constraints": "Positive", - "dims": None, + "dims": None if k_endog == 1 else (f"endog_{self.name}",), } - self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names] + if k_endog > 1: + self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",) def make_symbolic_graph(self) -> None: k_states = self.k_states // self.k_endog @@ -377,7 +379,7 @@ def make_symbolic_graph(self) -> None: self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)]) initial_states = self.make_and_register_variable( - f"coefs_{self.name}", + f"params_{self.name}", shape=(k_unique_states,) if k_endog == 1 else (k_endog, k_unique_states), ) if k_endog == 1: @@ -506,7 +508,7 @@ def make_symbolic_graph(self) -> None: self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)]) init_state = self.make_and_register_variable( - f"{self.name}", shape=(n_coefs,) if k_endog == 1 else (k_endog, n_coefs) + f"params_{self.name}", shape=(n_coefs,) if k_endog == 1 else (k_endog, n_coefs) ) init_state_idx = np.concatenate( @@ -535,19 +537,30 @@ def make_symbolic_graph(self) -> None: def populate_component_properties(self): k_endog = self.k_endog n_coefs = self.n_coefs - k_states = self.k_states // k_endog self.state_names = [ - f"{f}_{self.name}_{i}[{obs_state_name}]" + f"{f}_{i}_{self.name}[{obs_state_name}]" for obs_state_name in self.observed_state_names for i in range(self.n) for f in ["Cos", "Sin"] ] - self.param_names = [f"{self.name}"] + # determine which state names correspond to parameters + # all endog variables use same state structure, so we just need + # the first n_coefs state names (which may be less than total if saturated) + param_state_names = [f"{f}_{i}_{self.name}" for i in range(self.n) for f in ["Cos", "Sin"]][ + :n_coefs + ] + + self.param_names = [f"params_{self.name}"] + + self.param_dims = { + f"params_{self.name}": (f"state_{self.name}",) + if k_endog == 1 + else (f"endog_{self.name}", f"state_{self.name}") + } - self.param_dims = {self.name: (f"state_{self.name}",)} self.param_info = { - f"{self.name}": { + f"params_{self.name}": { "shape": (n_coefs,) if k_endog == 1 else (k_endog, n_coefs), "constraints": None, "dims": (f"state_{self.name}",) @@ -556,23 +569,22 @@ def populate_component_properties(self): } } - # Regardless of whether the fourier basis are saturated, there will always be one symbolic state per basis. - # That's why the self.states is just a simple loop over everything. But when saturated, one of those states - # doesn't have an associated **parameter**, so the coords need to be adjusted to reflect this. - init_state_idx = np.concatenate( - [ - np.arange(k_states * i, (i + 1) * k_states, dtype=int)[:n_coefs] - for i in range(k_endog) - ], - axis=0, + self.coords = ( + {f"state_{self.name}": param_state_names} + if k_endog == 1 + else { + f"endog_{self.name}": self.observed_state_names, + f"state_{self.name}": param_state_names, + } ) - self.coords = {f"state_{self.name}": [self.state_names[i] for i in init_state_idx]} if self.innovations: - self.shock_names = self.state_names.copy() self.param_names += [f"sigma_{self.name}"] + self.shock_names = self.state_names.copy() self.param_info[f"sigma_{self.name}"] = { - "shape": () if k_endog == 1 else (k_endog, n_coefs), + "shape": () if k_endog == 1 else (k_endog,), "constraints": "Positive", "dims": None if k_endog == 1 else (f"endog_{self.name}",), } + if k_endog > 1: + self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",) diff --git a/pymc_extras/statespace/models/structural/core.py b/pymc_extras/statespace/models/structural/core.py index 2a0717f0b..95bc957cd 100644 --- a/pymc_extras/statespace/models/structural/core.py +++ b/pymc_extras/statespace/models/structural/core.py @@ -120,7 +120,7 @@ class StructuralTimeSeries(PyMCStateSpace): initial_trend = pm.Normal('initial_trend', sigma=10, dims=ss_mod.param_dims['initial_trend']) sigma_trend = pm.HalfNormal('sigma_trend', sigma=1, dims=ss_mod.param_dims['sigma_trend']) - seasonal_coefs = pm.Normal('seasonal_coefs', sigma=1, dims=ss_mod.param_dims['seasonal_coefs']) + seasonal_coefs = pm.Normal('params_seasonal', sigma=1, dims=ss_mod.param_dims['params_seasonal']) sigma_seasonal = pm.HalfNormal('sigma_seasonal', sigma=1) sigma_obs = pm.Exponential('sigma_obs', 1, dims=ss_mod.param_dims['sigma_obs']) diff --git a/tests/statespace/models/structural/components/test_seasonality.py b/tests/statespace/models/structural/components/test_seasonality.py index 162b25868..d439e0bb2 100644 --- a/tests/statespace/models/structural/components/test_seasonality.py +++ b/tests/statespace/models/structural/components/test_seasonality.py @@ -6,6 +6,7 @@ from pytensor.graph.basic import explicit_graph_inputs from pymc_extras.statespace.models import structural as st +from pymc_extras.statespace.models.structural.components.seasonality import FrequencySeasonality from tests.statespace.models.structural.conftest import _assert_basic_coords_correct from tests.statespace.test_utilities import assert_pattern_repeats, simulate_from_numpy_model @@ -38,7 +39,7 @@ def random_word(rng): x0 = np.zeros(mod.k_states // mod.duration, dtype=config.floatX) x0[0] = 1 - params = {"coefs_season": x0} + params = {"params_season": x0} if innovations: params["sigma_season"] = 0.0 @@ -84,7 +85,7 @@ def test_time_seasonality_multiple_observed(rng, d, remove_first_state): x0[0, 0] = 1 x0[1, 0] = 2.0 - params = {"coefs_season": x0, "sigma_season": np.array([0.0, 0.0], dtype=config.floatX)} + params = {"params_season": x0, "sigma_season": np.array([0.0, 0.0], dtype=config.floatX)} x, y = simulate_from_numpy_model(mod, rng, params, steps=123 * d) assert_pattern_repeats(y[:, 0], s * d, atol=ATOL, rtol=RTOL) @@ -169,8 +170,8 @@ def test_add_two_time_seasonality_different_observed(rng, d1, d2): mod = (mod1 + mod2).build(verbose=False) params = { - "coefs_season1": np.array([1.0, 0.0, 0.0], dtype=config.floatX), - "coefs_season2": np.array([3.0, 0.0, 0.0, 0.0], dtype=config.floatX), + "params_season1": np.array([1.0, 0.0, 0.0], dtype=config.floatX), + "params_season2": np.array([3.0, 0.0, 0.0, 0.0], dtype=config.floatX), "sigma_season1": np.array(0.0, dtype=config.floatX), "sigma_season2": np.array(0.0, dtype=config.floatX), "initial_state_cov": np.eye(mod.k_states, dtype=config.floatX), @@ -205,8 +206,8 @@ def test_add_two_time_seasonality_different_observed(rng, d1, d2): ) x0, T = fn( - coefs_season1=np.array([1.0, 0.0, 0.0], dtype=config.floatX), - coefs_season2=np.array([3.0, 0.0, 0.0, 1.2], dtype=config.floatX), + params_season1=np.array([1.0, 0.0, 0.0], dtype=config.floatX), + params_season2=np.array([3.0, 0.0, 0.0, 1.2], dtype=config.floatX), ) np.testing.assert_allclose( @@ -242,22 +243,26 @@ def get_shift_factor(s): @pytest.mark.parametrize("s", [5, 10, 25, 25.2]) def test_frequency_seasonality(n, s, rng): mod = st.FrequencySeasonality(season_length=s, n=n, name="season") + assert mod.param_info["sigma_season"]["shape"] == () # scalar for univariate + assert mod.param_info["sigma_season"]["dims"] is None + assert len(mod.coords["state_season"]) == mod.n_coefs + x0 = rng.normal(size=mod.n_coefs).astype(config.floatX) - params = {"season": x0, "sigma_season": 0.0} + params = {"params_season": x0, "sigma_season": 0.0} k = get_shift_factor(s) T = int(s * k) x, y = simulate_from_numpy_model(mod, rng, params, steps=2 * T) assert_pattern_repeats(y, T, atol=ATOL, rtol=RTOL) - # Check coords + # check coords mod = mod.build(verbose=False) _assert_basic_coords_correct(mod) if n is None: n = int(s // 2) - states = [f"{f}_season_{i}" for i in range(n) for f in ["Cos", "Sin"]] + states = [f"{f}_{i}_season" for i in range(n) for f in ["Cos", "Sin"]] - # Remove the last state when the model is completely saturated + # remove last state when model is completely saturated if s / n == 2.0: states.pop() assert mod.coords["state_season"] == states @@ -273,47 +278,47 @@ def test_frequency_seasonality_multiple_observed(rng): innovations=True, observed_state_names=observed_state_names, ) + assert mod.param_info["params_season"]["shape"] == (mod.k_endog, mod.n_coefs) + assert mod.param_info["params_season"]["dims"] == ("endog_season", "state_season") + assert mod.param_dims["sigma_season"] == ("endog_season",) + expected_state_names = [ - "Cos_season_0[data_1]", - "Sin_season_0[data_1]", - "Cos_season_1[data_1]", - "Sin_season_1[data_1]", - "Cos_season_0[data_2]", - "Sin_season_0[data_2]", - "Cos_season_1[data_2]", - "Sin_season_1[data_2]", + "Cos_0_season[data_1]", + "Sin_0_season[data_1]", + "Cos_1_season[data_1]", + "Sin_1_season[data_1]", + "Cos_0_season[data_2]", + "Sin_0_season[data_2]", + "Cos_1_season[data_2]", + "Sin_1_season[data_2]", ] assert mod.state_names == expected_state_names assert mod.shock_names == [ - "Cos_season_0[data_1]", - "Sin_season_0[data_1]", - "Cos_season_1[data_1]", - "Sin_season_1[data_1]", - "Cos_season_0[data_2]", - "Sin_season_0[data_2]", - "Cos_season_1[data_2]", - "Sin_season_1[data_2]", + "Cos_0_season[data_1]", + "Sin_0_season[data_1]", + "Cos_1_season[data_1]", + "Sin_1_season[data_1]", + "Cos_0_season[data_2]", + "Sin_0_season[data_2]", + "Cos_1_season[data_2]", + "Sin_1_season[data_2]", ] - # Simulate x0 = np.zeros((2, 3), dtype=config.floatX) x0[0, 0] = 1.0 x0[1, 0] = 2.0 - params = {"season": x0, "sigma_season": np.zeros(2, dtype=config.floatX)} + params = {"params_season": x0, "sigma_season": np.zeros(2, dtype=config.floatX)} x, y = simulate_from_numpy_model(mod, rng, params, steps=12) - # Check periodicity for each observed series + # check periodicity for each observed series assert_pattern_repeats(y[:, 0], 4, atol=ATOL, rtol=RTOL) assert_pattern_repeats(y[:, 1], 4, atol=ATOL, rtol=RTOL) mod = mod.build(verbose=False) assert list(mod.coords["state_season"]) == [ - "Cos_season_0[data_1]", - "Sin_season_0[data_1]", - "Cos_season_1[data_1]", - "Cos_season_0[data_2]", - "Sin_season_0[data_2]", - "Cos_season_1[data_2]", + "Cos_0_season", + "Sin_0_season", + "Cos_1_season", ] x0_sym, *_, T_sym, Z_sym, R_sym, _, Q_sym = mod._unpack_statespace_with_placeholders() @@ -386,8 +391,8 @@ def test_add_two_frequency_seasonality_different_observed(rng): mod = (mod1 + mod2).build(verbose=False) params = { - "freq1": np.array([1.0, 0.0, 0.0], dtype=config.floatX), - "freq2": np.array([3.0, 0.0], dtype=config.floatX), + "params_freq1": np.array([1.0, 0.0, 0.0], dtype=config.floatX), + "params_freq2": np.array([3.0, 0.0], dtype=config.floatX), "sigma_freq1": np.array(0.0, dtype=config.floatX), "sigma_freq2": np.array(0.0, dtype=config.floatX), "initial_state_cov": np.eye(mod.k_states, dtype=config.floatX), @@ -399,21 +404,21 @@ def test_add_two_frequency_seasonality_different_observed(rng): assert_pattern_repeats(y[:, 1], 6, atol=ATOL, rtol=RTOL) assert mod.state_names == [ - "Cos_freq1_0[data_1]", - "Sin_freq1_0[data_1]", - "Cos_freq1_1[data_1]", - "Sin_freq1_1[data_1]", - "Cos_freq2_0[data_2]", - "Sin_freq2_0[data_2]", + "Cos_0_freq1[data_1]", + "Sin_0_freq1[data_1]", + "Cos_1_freq1[data_1]", + "Sin_1_freq1[data_1]", + "Cos_0_freq2[data_2]", + "Sin_0_freq2[data_2]", ] assert mod.shock_names == [ - "Cos_freq1_0[data_1]", - "Sin_freq1_0[data_1]", - "Cos_freq1_1[data_1]", - "Sin_freq1_1[data_1]", - "Cos_freq2_0[data_2]", - "Sin_freq2_0[data_2]", + "Cos_0_freq1[data_1]", + "Sin_0_freq1[data_1]", + "Cos_1_freq1[data_1]", + "Sin_1_freq1[data_1]", + "Cos_0_freq2[data_2]", + "Sin_0_freq2[data_2]", ] x0, *_, T = mod._unpack_statespace_with_placeholders()[:5] @@ -425,8 +430,8 @@ def test_add_two_frequency_seasonality_different_observed(rng): ) x0_v, T_v = fn( - freq1=np.array([1.0, 0.0, 1.2], dtype=config.floatX), - freq2=np.array([3.0, 0.0], dtype=config.floatX), + params_freq1=np.array([1.0, 0.0, 1.2], dtype=config.floatX), + params_freq2=np.array([3.0, 0.0], dtype=config.floatX), ) # Make sure the extra 0 in from the first component (the saturated state) is there! @@ -452,3 +457,89 @@ def test_add_two_frequency_seasonality_different_observed(rng): expected_T[4:6, 4:6] = freq2_T np.testing.assert_allclose(expected_T, T_v, atol=ATOL, rtol=RTOL) + + +@pytest.mark.parametrize( + "test_case", + [ + { + "name": "single_endog_non_saturated", + "season_length": 12, + "n": 2, + "observed_state_names": ["data1"], + "expected_shape": (4,), + }, + { + "name": "single_endog_saturated", + "season_length": 12, + "n": 6, + "observed_state_names": ["data1"], + "expected_shape": (11,), + }, + { + "name": "multiple_endog_non_saturated", + "season_length": 12, + "n": 2, + "observed_state_names": ["data1", "data2"], + "expected_shape": (2, 4), + }, + { + "name": "multiple_endog_saturated", + "season_length": 12, + "n": 6, + "observed_state_names": ["data1", "data2"], + "expected_shape": (2, 11), + }, + { + "name": "small_n", + "season_length": 12, + "n": 1, + "observed_state_names": ["data1"], + "expected_shape": (2,), + }, + { + "name": "many_endog", + "season_length": 12, + "n": 2, + "observed_state_names": ["data1", "data2", "data3", "data4"], + "expected_shape": (4, 4), + }, + ], + ids=lambda x: x["name"], +) +def test_frequency_seasonality_coordinates(test_case): + model_name = f"season_{test_case['name'].split('_')[0]}" + + season = FrequencySeasonality( + season_length=test_case["season_length"], + n=test_case["n"], + name=model_name, + observed_state_names=test_case["observed_state_names"], + ) + season.populate_component_properties() + + # assert parameter shape + assert season.param_info[f"params_{model_name}"]["shape"] == test_case["expected_shape"] + + # generate expected state names based on actual model name + expected_state_names = [ + f"{f}_{i}_{model_name}" for i in range(test_case["n"]) for f in ["Cos", "Sin"] + ][: test_case["expected_shape"][-1]] + + # assert coordinate structure + if len(test_case["observed_state_names"]) == 1: + assert len(season.coords[f"state_{model_name}"]) == test_case["expected_shape"][0] + assert season.coords[f"state_{model_name}"] == expected_state_names + else: + assert len(season.coords[f"endog_{model_name}"]) == test_case["expected_shape"][0] + assert len(season.coords[f"state_{model_name}"]) == test_case["expected_shape"][1] + assert season.coords[f"state_{model_name}"] == expected_state_names + + # Check coords match the expected shape + param_shape = season.param_info[f"params_{model_name}"]["shape"] + state_coords = season.coords[f"state_{model_name}"] + endog_coords = season.coords.get(f"endog_{model_name}") + + assert len(state_coords) == param_shape[-1] + if endog_coords: + assert len(endog_coords) == param_shape[0] diff --git a/tests/statespace/models/structural/test_against_statsmodels.py b/tests/statespace/models/structural/test_against_statsmodels.py index 833d540bc..31b4cbae6 100644 --- a/tests/statespace/models/structural/test_against_statsmodels.py +++ b/tests/statespace/models/structural/test_against_statsmodels.py @@ -104,12 +104,12 @@ def _assert_keys_match(test_dict, expected_dict): expected_keys = list(expected_dict.keys()) param_keys = list(test_dict.keys()) key_diff = set(expected_keys) - set(param_keys) - assert len(key_diff) == 0, f'{", ".join(key_diff)} were not found in the test_dict keys.' + assert len(key_diff) == 0, f"{', '.join(key_diff)} were not found in the test_dict keys." key_diff = set(param_keys) - set(expected_keys) assert ( len(key_diff) == 0 - ), f'{", ".join(key_diff)} were keys of the tests_dict not in expected_dict.' + ), f"{', '.join(key_diff)} were keys of the tests_dict not in expected_dict." def _assert_param_dims_correct(param_dims, expected_dims): @@ -296,8 +296,8 @@ def create_structural_model_and_equivalent_statsmodel( if seasonal is not None: state_names = [f"seasonal_{i}" for i in range(seasonal)][1:] seasonal_coefs = rng.normal(size=(seasonal - 1,)).astype(floatX) - params["coefs_seasonal"] = seasonal_coefs - expected_param_dims["coefs_seasonal"] += ("state_seasonal",) + params["params_seasonal"] = seasonal_coefs + expected_param_dims["params_seasonal"] += ("state_seasonal",) expected_coords["state_seasonal"] += tuple(state_names) expected_coords[ALL_STATE_DIM] += state_names @@ -331,12 +331,12 @@ def create_structural_model_and_equivalent_statsmodel( s = d["period"] last_state_not_identified = (s / n) == 2.0 n_states = 2 * n - int(last_state_not_identified) - state_names = [f"{f}_seasonal_{s}_{i}" for i in range(n) for f in ["Cos", "Sin"]] + state_names = [f"{f}_{i}_seasonal_{s}" for i in range(n) for f in ["Cos", "Sin"]] seasonal_params = rng.normal(size=n_states).astype(floatX) - params[f"seasonal_{s}"] = seasonal_params - expected_param_dims[f"seasonal_{s}"] += (f"state_seasonal_{s}",) + params[f"params_seasonal_{s}"] = seasonal_params + expected_param_dims[f"params_seasonal_{s}"] += (f"state_seasonal_{s}",) expected_coords[ALL_STATE_DIM] += state_names expected_coords[ALL_STATE_AUX_DIM] += state_names expected_coords[f"state_seasonal_{s}"] += ( @@ -404,7 +404,7 @@ def create_structural_model_and_equivalent_statsmodel( components.append(comp) if autoregressive is not None: - ar_names = [f"L{i+1}" for i in range(autoregressive)] + ar_names = [f"L{i + 1}" for i in range(autoregressive)] params_ar = rng.normal(size=(autoregressive,)).astype(floatX) if autoregressive == 1: params_ar = params_ar.item() @@ -421,8 +421,8 @@ def create_structural_model_and_equivalent_statsmodel( sm_params["sigma2.ar"] = sigma2 for i, rho in enumerate(params_ar): - sm_init[f"ar.L{i+1}"] = 0 - sm_params[f"ar.L{i+1}"] = rho + sm_init[f"ar.L{i + 1}"] = 0 + sm_params[f"ar.L{i + 1}"] = rho comp = st.AutoregressiveComponent(name="ar", order=autoregressive) components.append(comp) @@ -439,7 +439,7 @@ def create_structural_model_and_equivalent_statsmodel( for i, beta in enumerate(betas): sm_params[f"beta.x{i + 1}"] = beta - sm_init[f"beta.x{i+1}"] = beta + sm_init[f"beta.x{i + 1}"] = beta comp = st.RegressionComponent(name="exog", state_names=names) components.append(comp) diff --git a/tests/statespace/models/structural/test_core.py b/tests/statespace/models/structural/test_core.py index bd9dcb032..a11d4e987 100644 --- a/tests/statespace/models/structural/test_core.py +++ b/tests/statespace/models/structural/test_core.py @@ -28,7 +28,7 @@ def test_add_components(): "sigma_level_trend": np.ones(2, dtype=floatX), } se_params = { - "coefs_seasonal": np.ones(11, dtype=floatX), + "params_seasonal": np.ones(11, dtype=floatX), "sigma_seasonal": 1.0, } all_params = ll_params.copy() @@ -97,7 +97,7 @@ def test_extract_components_from_idata(rng): beta_exog = pm.Normal("beta_exog", dims=["state_exog"]) initial_trend = pm.Normal("initial_level_trend", dims=["state_level_trend"]) sigma_trend = pm.Exponential("sigma_level_trend", 1, dims=["shock_level_trend"]) - seasonal_coefs = pm.Normal("seasonal", dims=["state_seasonal"]) + seasonal_coefs = pm.Normal("params_seasonal", dims=["state_seasonal"]) sigma_obs = pm.Exponential("sigma_obs", 1) mod.build_statespace_graph(y) @@ -144,7 +144,7 @@ def test_extract_multiple_observed(rng): sigma_auto_regressive = pm.Normal("sigma_auto_regressive", dims=["endog_auto_regressive"]) initial_trend = pm.Normal("initial_trend", dims=["endog_trend", "state_trend"]) sigma_trend = pm.Exponential("sigma_trend", 1, dims=["endog_trend", "shock_trend"]) - seasonal_coefs = pm.Normal("seasonal", dims=["state_seasonal"]) + seasonal_coefs = pm.Normal("params_seasonal", dims=["state_seasonal"]) sigma_obs = pm.Exponential("sigma_obs", 1, dims=["endog_obs"]) mod.build_statespace_graph(y)