Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions pymc_extras/statespace/models/structural.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,12 @@ class TimeSeasonality(Component):

If None, states will be numbered ``[State_0, ..., State_s]``

remove_first_state: bool, default True
If True, the first state will be removed from the model. This is done because there are only n-1 degrees of
freedom in the seasonal component, and one state is not identified. If False, the first state will be
included in the model, but it will not be identified -- you will need to handle this in the priors (e.g. with
ZeroSumNormal).

Notes
-----
A seasonal effect is any pattern that repeats every fixed interval. Although there are many possible ways to
Expand Down Expand Up @@ -1163,7 +1169,7 @@ def __init__(
innovations: bool = True,
name: str | None = None,
state_names: list | None = None,
pop_state: bool = True,
remove_first_state: bool = True,
):
if name is None:
name = f"Seasonal[s={season_length}]"
Expand All @@ -1176,14 +1182,15 @@ def __init__(
)
state_names = state_names.copy()
self.innovations = innovations
self.pop_state = pop_state
self.remove_first_state = remove_first_state

if self.pop_state:
if self.remove_first_state:
# In traditional models, the first state isn't identified, so we can help out the user by automatically
# discarding it.
# TODO: Can this be stashed and reconstructed automatically somehow?
state_names.pop(0)
k_states = season_length - 1

k_states = season_length - int(self.remove_first_state)

super().__init__(
name=name,
Expand Down Expand Up @@ -1218,8 +1225,16 @@ def populate_component_properties(self):
self.shock_names = [f"{self.name}"]

def make_symbolic_graph(self) -> None:
T = np.eye(self.k_states, k=-1)
T[0, :] = -1
if self.remove_first_state:
# In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
# all previous states.
T = np.eye(self.k_states, k=-1)
T[0, :] = -1
else:
# In this case we assume the user to be responsible for ensuring the states sum to zero, so T is just a
# circulant matrix that cycles between the states.
T = np.eye(self.k_states, k=1)
T[-1, 0] = 1

self.ssm["transition", :, :] = T
self.ssm["design", 0, 0] = 1
Expand Down
13 changes: 10 additions & 3 deletions tests/statespace/test_structural.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings

from collections import defaultdict
from copyreg import remove_extension
from typing import Optional

import numpy as np
Expand Down Expand Up @@ -592,13 +593,18 @@ def test_autoregressive_model(order, rng):

@pytest.mark.parametrize("s", [10, 25, 50])
@pytest.mark.parametrize("innovations", [True, False])
def test_time_seasonality(s, innovations, rng):
@pytest.mark.parametrize("remove_first_state", [True, False])
def test_time_seasonality(s, innovations, remove_first_state, rng):
def random_word(rng):
return "".join(rng.choice(list("abcdefghijklmnopqrstuvwxyz")) for _ in range(5))

state_names = [random_word(rng) for _ in range(s)]
mod = st.TimeSeasonality(
season_length=s, innovations=innovations, name="season", state_names=state_names
season_length=s,
innovations=innovations,
name="season",
state_names=state_names,
remove_first_state=remove_first_state,
)
x0 = np.zeros(mod.k_states, dtype=floatX)
x0[0] = 1
Expand All @@ -615,7 +621,8 @@ def random_word(rng):
# Check coords
mod.build(verbose=False)
_assert_basic_coords_correct(mod)
assert mod.coords["season_state"] == state_names[1:]
test_slice = slice(1, None) if remove_first_state else slice(None)
assert mod.coords["season_state"] == state_names[test_slice]


def get_shift_factor(s):
Expand Down
Loading