diff --git a/pymc_extras/statespace/models/structural.py b/pymc_extras/statespace/models/structural.py index bc61eab98..40d1dedff 100644 --- a/pymc_extras/statespace/models/structural.py +++ b/pymc_extras/statespace/models/structural.py @@ -1071,6 +1071,12 @@ class TimeSeasonality(Component): If None, states will be numbered ``[State_0, ..., State_s]`` + remove_first_state: bool, default True + If True, the first state will be removed from the model. This is done because there are only n-1 degrees of + freedom in the seasonal component, and one state is not identified. If False, the first state will be + included in the model, but it will not be identified -- you will need to handle this in the priors (e.g. with + ZeroSumNormal). + Notes ----- A seasonal effect is any pattern that repeats every fixed interval. Although there are many possible ways to @@ -1163,7 +1169,7 @@ def __init__( innovations: bool = True, name: str | None = None, state_names: list | None = None, - pop_state: bool = True, + remove_first_state: bool = True, ): if name is None: name = f"Seasonal[s={season_length}]" @@ -1176,14 +1182,15 @@ def __init__( ) state_names = state_names.copy() self.innovations = innovations - self.pop_state = pop_state + self.remove_first_state = remove_first_state - if self.pop_state: + if self.remove_first_state: # In traditional models, the first state isn't identified, so we can help out the user by automatically # discarding it. # TODO: Can this be stashed and reconstructed automatically somehow? state_names.pop(0) - k_states = season_length - 1 + + k_states = season_length - int(self.remove_first_state) super().__init__( name=name, @@ -1218,8 +1225,16 @@ def populate_component_properties(self): self.shock_names = [f"{self.name}"] def make_symbolic_graph(self) -> None: - T = np.eye(self.k_states, k=-1) - T[0, :] = -1 + if self.remove_first_state: + # In this case, parameters are normalized to sum to zero, so the current state is the negative sum of + # all previous states. + T = np.eye(self.k_states, k=-1) + T[0, :] = -1 + else: + # In this case we assume the user to be responsible for ensuring the states sum to zero, so T is just a + # circulant matrix that cycles between the states. + T = np.eye(self.k_states, k=1) + T[-1, 0] = 1 self.ssm["transition", :, :] = T self.ssm["design", 0, 0] = 1 diff --git a/tests/statespace/test_structural.py b/tests/statespace/test_structural.py index 858efadfb..c398c723e 100644 --- a/tests/statespace/test_structural.py +++ b/tests/statespace/test_structural.py @@ -2,6 +2,7 @@ import warnings from collections import defaultdict +from copyreg import remove_extension from typing import Optional import numpy as np @@ -592,13 +593,18 @@ def test_autoregressive_model(order, rng): @pytest.mark.parametrize("s", [10, 25, 50]) @pytest.mark.parametrize("innovations", [True, False]) -def test_time_seasonality(s, innovations, rng): +@pytest.mark.parametrize("remove_first_state", [True, False]) +def test_time_seasonality(s, innovations, remove_first_state, rng): def random_word(rng): return "".join(rng.choice(list("abcdefghijklmnopqrstuvwxyz")) for _ in range(5)) state_names = [random_word(rng) for _ in range(s)] mod = st.TimeSeasonality( - season_length=s, innovations=innovations, name="season", state_names=state_names + season_length=s, + innovations=innovations, + name="season", + state_names=state_names, + remove_first_state=remove_first_state, ) x0 = np.zeros(mod.k_states, dtype=floatX) x0[0] = 1 @@ -615,7 +621,8 @@ def random_word(rng): # Check coords mod.build(verbose=False) _assert_basic_coords_correct(mod) - assert mod.coords["season_state"] == state_names[1:] + test_slice = slice(1, None) if remove_first_state else slice(None) + assert mod.coords["season_state"] == state_names[test_slice] def get_shift_factor(s):