Skip to content

Commit f5054a3

Browse files
Add shared_state argument to TimeSeasonality
1 parent dfcfbe9 commit f5054a3

File tree

2 files changed

+152
-25
lines changed

2 files changed

+152
-25
lines changed

pymc_extras/statespace/models/structural/components/seasonality.py

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ class TimeSeasonality(Component):
4444
observed_state_names: list[str] | None, default None
4545
List of strings for observed state labels. If None, defaults to ["data"].
4646
47+
share_states: bool, default False
48+
Whether latent states are shared across the observed states. If True, there will be only one set of latent
49+
states, which are observed by all observed states. If False, each observed state has its own set of
50+
latent states. This argument has no effect if `k_endog` is 1.
51+
4752
Notes
4853
-----
4954
A seasonal effect is any pattern that repeats at fixed intervals. There are several ways to model such effects;
@@ -235,6 +240,7 @@ def __init__(
235240
state_names: list | None = None,
236241
remove_first_state: bool = True,
237242
observed_state_names: list[str] | None = None,
243+
share_states: bool = False,
238244
):
239245
if observed_state_names is None:
240246
observed_state_names = ["data"]
@@ -261,6 +267,7 @@ def __init__(
261267
)
262268
state_names = state_names.copy()
263269

270+
self.share_states = share_states
264271
self.innovations = innovations
265272
self.duration = duration
266273
self.remove_first_state = remove_first_state
@@ -281,44 +288,54 @@ def __init__(
281288
super().__init__(
282289
name=name,
283290
k_endog=k_endog,
284-
k_states=k_states * k_endog,
285-
k_posdef=k_posdef * k_endog,
291+
k_states=k_states if share_states else k_states * k_endog,
292+
k_posdef=k_posdef if share_states else k_posdef * k_endog,
286293
observed_state_names=observed_state_names,
287294
measurement_error=False,
288295
combine_hidden_states=True,
289-
obs_state_idxs=np.tile(np.array([1.0] + [0.0] * (k_states - 1)), k_endog),
296+
obs_state_idxs=np.tile(
297+
np.array([1.0] + [0.0] * (k_states - 1)), 1 if share_states else k_endog
298+
),
290299
)
291300

292301
def populate_component_properties(self):
293-
k_states = self.k_states // self.k_endog
294302
k_endog = self.k_endog
303+
k_endog_effective = 1 if self.share_states else k_endog
304+
305+
k_states = self.k_states // k_endog_effective
306+
307+
if self.share_states:
308+
self.state_names = [
309+
f"{state_name}[{self.name}_shared]" for state_name in self.provided_state_names
310+
]
311+
else:
312+
self.state_names = [
313+
f"{state_name}[{endog_name}]"
314+
for endog_name in self.observed_state_names
315+
for state_name in self.provided_state_names
316+
]
295317

296-
self.state_names = [
297-
f"{state_name}[{endog_name}]"
298-
for endog_name in self.observed_state_names
299-
for state_name in self.provided_state_names
300-
]
301318
self.param_names = [f"params_{self.name}"]
302319

303320
self.param_info = {
304321
f"params_{self.name}": {
305322
"shape": (k_states,) if k_endog == 1 else (k_endog, k_states),
306323
"constraints": None,
307324
"dims": (f"state_{self.name}",)
308-
if k_endog == 1
325+
if k_endog_effective == 1
309326
else (f"endog_{self.name}", f"state_{self.name}"),
310327
}
311328
}
312329

313330
self.param_dims = {
314331
f"params_{self.name}": (f"state_{self.name}",)
315-
if k_endog == 1
332+
if k_endog_effective == 1
316333
else (f"endog_{self.name}", f"state_{self.name}")
317334
}
318335

319336
self.coords = (
320337
{f"state_{self.name}": self.provided_state_names}
321-
if k_endog == 1
338+
if k_endog_effective == 1
322339
else {
323340
f"endog_{self.name}": self.observed_state_names,
324341
f"state_{self.name}": self.provided_state_names,
@@ -327,21 +344,27 @@ def populate_component_properties(self):
327344

328345
if self.innovations:
329346
self.param_names += [f"sigma_{self.name}"]
330-
self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]
331347
self.param_info[f"sigma_{self.name}"] = {
332-
"shape": () if k_endog == 1 else (k_endog,),
348+
"shape": () if k_endog_effective == 1 else (k_endog,),
333349
"constraints": "Positive",
334-
"dims": None if k_endog == 1 else (f"endog_{self.name}",),
350+
"dims": None if k_endog_effective == 1 else (f"endog_{self.name}",),
335351
}
352+
if self.share_states:
353+
self.shock_names = [f"{self.name}[shared]"]
354+
else:
355+
self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]
356+
336357
if k_endog > 1:
337358
self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)
338359

339360
def make_symbolic_graph(self) -> None:
340-
k_states = self.k_states // self.k_endog
361+
k_endog = self.k_endog
362+
k_endog_effective = 1 if self.share_states else k_endog
363+
k_states = self.k_states // k_endog_effective
341364
duration = self.duration
365+
342366
k_unique_states = k_states // duration
343-
k_posdef = self.k_posdef // self.k_endog
344-
k_endog = self.k_endog
367+
k_posdef = self.k_posdef // k_endog_effective
345368

346369
if self.remove_first_state:
347370
# In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
@@ -373,16 +396,18 @@ def make_symbolic_graph(self) -> None:
373396
T = pt.eye(k_states, k=1)
374397
T = pt.set_subtensor(T[-1, 0], 1)
375398

376-
self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog)])
399+
self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog_effective)])
377400

378401
Z = pt.zeros((1, k_states))[0, 0].set(1)
379-
self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)])
402+
self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog_effective)])
380403

381404
initial_states = self.make_and_register_variable(
382405
f"params_{self.name}",
383-
shape=(k_unique_states,) if k_endog == 1 else (k_endog, k_unique_states),
406+
shape=(k_unique_states,)
407+
if k_endog_effective == 1
408+
else (k_endog_effective, k_unique_states),
384409
)
385-
if k_endog == 1:
410+
if k_endog_effective == 1:
386411
self.ssm["initial_state", :] = pt.extra_ops.repeat(initial_states, duration, axis=0)
387412
else:
388413
self.ssm["initial_state", :] = pt.extra_ops.repeat(
@@ -391,11 +416,11 @@ def make_symbolic_graph(self) -> None:
391416

392417
if self.innovations:
393418
R = pt.zeros((k_states, k_posdef))[0, 0].set(1.0)
394-
self.ssm["selection", :, :] = pt.join(0, *[R for _ in range(k_endog)])
419+
self.ssm["selection", :, :] = pt.join(0, *[R for _ in range(k_endog_effective)])
395420
season_sigma = self.make_and_register_variable(
396-
f"sigma_{self.name}", shape=() if k_endog == 1 else (k_endog,)
421+
f"sigma_{self.name}", shape=() if k_endog_effective == 1 else (k_endog_effective,)
397422
)
398-
cov_idx = ("state_cov", *np.diag_indices(k_posdef * k_endog))
423+
cov_idx = ("state_cov", *np.diag_indices(k_posdef * k_endog_effective))
399424
self.ssm[cov_idx] = season_sigma**2
400425

401426

tests/statespace/models/structural/components/test_seasonality.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,108 @@ def test_time_seasonality_multiple_observed(rng, d, remove_first_state):
147147
np.testing.assert_allclose(matrix, expected)
148148

149149

150+
def test_time_seasonality_shared_states():
151+
mod = st.TimeSeasonality(
152+
season_length=3,
153+
duration=1,
154+
innovations=True,
155+
name="season",
156+
state_names=["season_1", "season_2", "season_3"],
157+
observed_state_names=["data_1", "data_2"],
158+
remove_first_state=False,
159+
share_states=True,
160+
)
161+
162+
assert mod.k_endog == 2
163+
assert mod.k_states == 3
164+
assert mod.k_posdef == 1
165+
166+
assert mod.coords["state_season"] == ["season_1", "season_2", "season_3"]
167+
168+
assert mod.state_names == [
169+
"season_1[season_shared]",
170+
"season_2[season_shared]",
171+
"season_3[season_shared]",
172+
]
173+
assert mod.shock_names == ["season[shared]"]
174+
175+
Z, T, R = pytensor.function(
176+
[], [mod.ssm["design"], mod.ssm["transition"], mod.ssm["selection"]], mode="FAST_COMPILE"
177+
)()
178+
179+
np.testing.assert_allclose(np.array([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]), Z)
180+
181+
np.testing.assert_allclose(np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]), T)
182+
183+
np.testing.assert_allclose(np.array([[1.0], [0.0], [0.0]]), R)
184+
185+
186+
def test_add_mixed_shared_not_shared_time_seasonality():
187+
shared_season = st.TimeSeasonality(
188+
season_length=3,
189+
duration=1,
190+
innovations=True,
191+
name="shared",
192+
state_names=["season_1", "season_2", "season_3"],
193+
observed_state_names=["data_1", "data_2"],
194+
remove_first_state=False,
195+
share_states=True,
196+
)
197+
individual_season = st.TimeSeasonality(
198+
season_length=3,
199+
duration=1,
200+
innovations=False,
201+
name="individual",
202+
state_names=["season_1", "season_2", "season_3"],
203+
observed_state_names=["data_1", "data_2"],
204+
remove_first_state=True,
205+
share_states=False,
206+
)
207+
mod = (shared_season + individual_season).build(verbose=False)
208+
209+
assert mod.k_endog == 2
210+
assert mod.k_states == 7
211+
assert mod.k_posdef == 1
212+
213+
assert mod.coords["state_shared"] == ["season_1", "season_2", "season_3"]
214+
assert mod.coords["state_individual"] == ["season_2", "season_3"]
215+
216+
assert mod.state_names == [
217+
"season_1[shared_shared]",
218+
"season_2[shared_shared]",
219+
"season_3[shared_shared]",
220+
"season_2[data_1]",
221+
"season_3[data_1]",
222+
"season_2[data_2]",
223+
"season_3[data_2]",
224+
]
225+
226+
Z, T, R = pytensor.function(
227+
[], [mod.ssm["design"], mod.ssm["transition"], mod.ssm["selection"]], mode="FAST_COMPILE"
228+
)()
229+
230+
np.testing.assert_allclose(
231+
np.array([[1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]]), Z
232+
)
233+
234+
np.testing.assert_allclose(
235+
np.array(
236+
[
237+
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
238+
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
239+
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
240+
[0.0, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0],
241+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
242+
[0.0, 0.0, 0.0, 0.0, 0.0, -1.0, -1.0],
243+
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
244+
]
245+
),
246+
T,
247+
)
248+
249+
np.testing.assert_allclose(np.array([[1.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]), R)
250+
251+
150252
@pytest.mark.parametrize("d1, d2", [(1, 1), (1, 3), (3, 1), (3, 3)])
151253
def test_add_two_time_seasonality_different_observed(rng, d1, d2):
152254
mod1 = st.TimeSeasonality(

0 commit comments

Comments
 (0)