Skip to content

Commit 7581f04

Browse files
Allow multiple observed in FrequencySeasonality component
1 parent a8564b7 commit 7581f04

File tree

2 files changed

+253
-21
lines changed

2 files changed

+253
-21
lines changed

pymc_extras/statespace/models/structural/components/seasonality.py

Lines changed: 61 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,10 @@ def __init__(
301301
if observed_state_names is None:
302302
observed_state_names = ["data"]
303303

304+
k_endog = len(observed_state_names)
305+
304306
if n is None:
305-
n = int(season_length // 2)
307+
n = int(season_length / 2)
306308
if name is None:
307309
name = f"Frequency[s={season_length}, n={n}]"
308310

@@ -319,58 +321,97 @@ def __init__(
319321

320322
obs_state_idx = np.zeros(k_states)
321323
obs_state_idx[slice(0, k_states, 2)] = 1
324+
obs_state_idx = np.tile(obs_state_idx, k_endog)
322325

323326
super().__init__(
324327
name=name,
325-
k_endog=1,
326-
k_states=k_states,
327-
k_posdef=k_states * int(self.innovations),
328+
k_endog=k_endog,
329+
k_states=k_states * k_endog,
330+
k_posdef=k_states * int(self.innovations) * k_endog,
328331
observed_state_names=observed_state_names,
329332
measurement_error=False,
330333
combine_hidden_states=True,
331334
obs_state_idxs=obs_state_idx,
332335
)
333336

334337
def make_symbolic_graph(self) -> None:
335-
self.ssm["design", 0, slice(0, self.k_states, 2)] = 1
338+
k_endog = self.k_endog
339+
k_states = self.k_states // k_endog
340+
k_posdef = self.k_posdef // k_endog
341+
n_coefs = self.n_coefs
336342

337-
init_state = self.make_and_register_variable(f"{self.name}", shape=(self.n_coefs,))
343+
Z = pt.zeros((1, k_states))[0, slice(0, k_states, 2)].set(1.0)
338344

339-
init_state_idx = np.arange(self.n_coefs, dtype=int)
340-
self.ssm["initial_state", init_state_idx] = init_state
345+
self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)])
346+
347+
init_state = self.make_and_register_variable(
348+
f"{self.name}", shape=(n_coefs,) if k_endog == 1 else (k_endog, n_coefs)
349+
)
350+
351+
init_state_idx = np.concatenate(
352+
[
353+
np.arange(k_states * i, (i + 1) * k_states, dtype=int)[:n_coefs]
354+
for i in range(k_endog)
355+
],
356+
axis=0,
357+
)
358+
359+
self.ssm["initial_state", init_state_idx] = init_state.ravel()
341360

342361
T_mats = [_frequency_transition_block(self.season_length, j + 1) for j in range(self.n)]
343362
T = pt.linalg.block_diag(*T_mats)
344-
self.ssm["transition", :, :] = T
363+
self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog)])
345364

346365
if self.innovations:
347-
sigma_season = self.make_and_register_variable(f"sigma_{self.name}", shape=())
348-
self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_season**2
349-
self.ssm["selection", :, :] = np.eye(self.k_states)
366+
sigma_season = self.make_and_register_variable(
367+
f"sigma_{self.name}", shape=() if k_endog == 1 else (k_endog,)
368+
)
369+
self.ssm["selection", :, :] = pt.eye(self.k_states)
370+
self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * pt.repeat(
371+
sigma_season**2, k_posdef
372+
)
350373

351374
def populate_component_properties(self):
352-
self.state_names = [f"{self.name}_{f}_{i}" for i in range(self.n) for f in ["Cos", "Sin"]]
375+
k_endog = self.k_endog
376+
n_coefs = self.n_coefs
377+
k_states = self.k_states // k_endog
378+
379+
self.state_names = [
380+
f"{self.name}_{f}_{i}[{obs_state_name}]"
381+
for obs_state_name in self.observed_state_names
382+
for i in range(self.n)
383+
for f in ["Cos", "Sin"]
384+
]
353385
self.param_names = [f"{self.name}"]
354386

355387
self.param_dims = {self.name: (f"{self.name}_state",)}
356388
self.param_info = {
357389
f"{self.name}": {
358-
"shape": (self.k_states - int(self.last_state_not_identified),),
390+
"shape": (n_coefs,) if k_endog == 1 else (k_endog, n_coefs),
359391
"constraints": None,
360-
"dims": (f"{self.name}_state",),
392+
"dims": (f"{self.name}_state",)
393+
if k_endog == 1
394+
else (f"{self.name}_endog", f"{self.name}_state"),
361395
}
362396
}
363397

364-
init_state_idx = np.arange(self.k_states, dtype=int)
365-
if self.last_state_not_identified:
366-
init_state_idx = init_state_idx[:-1]
398+
# Regardless of whether the Fourier basis is saturated, there will always be one symbolic state per basis function.
399+
# That's why self.state_names is just a simple loop over everything. But when saturated, one of those states
400+
# doesn't have an associated **parameter**, so the coords need to be adjusted to reflect this.
401+
init_state_idx = np.concatenate(
402+
[
403+
np.arange(k_states * i, (i + 1) * k_states, dtype=int)[:n_coefs]
404+
for i in range(k_endog)
405+
],
406+
axis=0,
407+
)
367408
self.coords = {f"{self.name}_state": [self.state_names[i] for i in init_state_idx]}
368409

369410
if self.innovations:
370411
self.shock_names = self.state_names.copy()
371412
self.param_names += [f"sigma_{self.name}"]
372413
self.param_info[f"sigma_{self.name}"] = {
373-
"shape": (),
414+
"shape": () if k_endog == 1 else (k_endog, n_coefs),
374415
"constraints": "Positive",
375-
"dims": None,
416+
"dims": None if k_endog == 1 else (f"{self.name}_endog",),
376417
}

tests/statespace/models/structural/components/test_seasonality.py

Lines changed: 192 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def test_frequency_seasonality(n, s, rng):
236236
assert_pattern_repeats(y, T, atol=ATOL, rtol=RTOL)
237237

238238
# Check coords
239-
mod.build(verbose=False)
239+
mod = mod.build(verbose=False)
240240
_assert_basic_coords_correct(mod)
241241
if n is None:
242242
n = int(s // 2)
@@ -246,3 +246,194 @@ def test_frequency_seasonality(n, s, rng):
246246
if s / n == 2.0:
247247
states.pop()
248248
assert mod.coords["season_state"] == states
249+
250+
251+
def test_frequency_seasonality_multiple_observed(rng):
252+
observed_state_names = ["data_1", "data_2"]
253+
season_length = 4
254+
mod = st.FrequencySeasonality(
255+
season_length=season_length,
256+
n=None,
257+
name="season",
258+
innovations=True,
259+
observed_state_names=observed_state_names,
260+
)
261+
expected_state_names = [
262+
"season_Cos_0[data_1]",
263+
"season_Sin_0[data_1]",
264+
"season_Cos_1[data_1]",
265+
"season_Sin_1[data_1]",
266+
"season_Cos_0[data_2]",
267+
"season_Sin_0[data_2]",
268+
"season_Cos_1[data_2]",
269+
"season_Sin_1[data_2]",
270+
]
271+
assert mod.state_names == expected_state_names
272+
assert mod.shock_names == [
273+
"season_Cos_0[data_1]",
274+
"season_Sin_0[data_1]",
275+
"season_Cos_1[data_1]",
276+
"season_Sin_1[data_1]",
277+
"season_Cos_0[data_2]",
278+
"season_Sin_0[data_2]",
279+
"season_Cos_1[data_2]",
280+
"season_Sin_1[data_2]",
281+
]
282+
283+
# Simulate
284+
x0 = np.zeros((2, 3), dtype=config.floatX)
285+
x0[0, 0] = 1.0
286+
x0[1, 0] = 2.0
287+
params = {"season": x0, "sigma_season": np.zeros(2, dtype=config.floatX)}
288+
x, y = simulate_from_numpy_model(mod, rng, params, steps=12)
289+
290+
# Check periodicity for each observed series
291+
assert_pattern_repeats(y[:, 0], 4, atol=ATOL, rtol=RTOL)
292+
assert_pattern_repeats(y[:, 1], 4, atol=ATOL, rtol=RTOL)
293+
294+
mod = mod.build(verbose=False)
295+
assert list(mod.coords["season_state"]) == [
296+
"season_Cos_0[data_1]",
297+
"season_Sin_0[data_1]",
298+
"season_Cos_1[data_1]",
299+
"season_Cos_0[data_2]",
300+
"season_Sin_0[data_2]",
301+
"season_Cos_1[data_2]",
302+
]
303+
304+
x0_sym, *_, T_sym, Z_sym, R_sym, _, Q_sym = mod._unpack_statespace_with_placeholders()
305+
input_vars = explicit_graph_inputs([x0_sym, T_sym, Z_sym, R_sym, Q_sym])
306+
fn = pytensor.function(
307+
inputs=list(input_vars),
308+
outputs=[x0_sym, T_sym, Z_sym, R_sym, Q_sym],
309+
mode="FAST_COMPILE",
310+
)
311+
params["sigma_season"] = np.array([0.1, 0.8], dtype=config.floatX)
312+
x0_v, T_v, Z_v, R_v, Q_v = fn(**params)
313+
314+
# x0 should be raveled into a single vector, with data_1 states first, then data_2 states
315+
np.testing.assert_allclose(
316+
x0_v, np.array([1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0]), atol=ATOL, rtol=RTOL
317+
)
318+
319+
# T_v shape: (8, 8) (k_endog * k_states)
320+
# The transition matrix is block diagonal, each block is:
321+
# For n=2, season_length=4:
322+
# lambda_1 = 2*pi*1/4 = pi/2, cos(pi/2)=0, sin(pi/2)=1
323+
# lambda_2 = 2*pi*2/4 = pi, cos(pi)=-1, sin(pi)=0
324+
# Block 1 (Cos_0, Sin_0):
325+
# [[cos(pi/2), sin(pi/2)],
326+
# [-sin(pi/2), cos(pi/2)]] = [[0, 1], [-1, 0]]
327+
# Block 2 (Cos_1, Sin_1):
328+
# [[-1, 0], [0, -1]]
329+
expected_T_block1 = np.array([[0.0, 1.0], [-1.0, 0.0]])
330+
expected_T_block2 = np.array([[-1.0, 0.0], [0.0, -1.0]])
331+
expected_T = np.zeros((8, 8))
332+
# data_1
333+
expected_T[0:2, 0:2] = expected_T_block1
334+
expected_T[2:4, 2:4] = expected_T_block2
335+
# data_2
336+
expected_T[4:6, 4:6] = expected_T_block1
337+
expected_T[6:8, 6:8] = expected_T_block2
338+
np.testing.assert_allclose(T_v, expected_T, atol=ATOL, rtol=RTOL)
339+
340+
# Only the cosine state of each harmonic (every second state) is observed, separately for each observed series
341+
expected_Z = np.zeros((2, 8))
342+
expected_Z[0, 0] = 1.0
343+
expected_Z[0, 2] = 1.0
344+
expected_Z[1, 4] = 1.0
345+
expected_Z[1, 6] = 1.0
346+
np.testing.assert_allclose(Z_v, expected_Z, atol=ATOL, rtol=RTOL)
347+
348+
np.testing.assert_allclose(R_v, np.eye(8), atol=ATOL, rtol=RTOL)
349+
350+
Q_diag = np.diag(Q_v)
351+
expected_Q_diag = np.r_[np.full(4, 0.1**2), np.full(4, 0.8**2)]
352+
np.testing.assert_allclose(Q_diag, expected_Q_diag, atol=ATOL, rtol=RTOL)
353+
354+
355+
def test_add_two_frequency_seasonality_different_observed(rng):
356+
mod1 = st.FrequencySeasonality(
357+
season_length=4,
358+
n=2, # saturated
359+
name="freq1",
360+
innovations=True,
361+
observed_state_names=["data_1"],
362+
)
363+
mod2 = st.FrequencySeasonality(
364+
season_length=6,
365+
n=1, # unsaturated
366+
name="freq2",
367+
innovations=True,
368+
observed_state_names=["data_2"],
369+
)
370+
371+
mod = (mod1 + mod2).build(verbose=False)
372+
373+
params = {
374+
"freq1": np.array([1.0, 0.0, 0.0], dtype=config.floatX),
375+
"freq2": np.array([3.0, 0.0], dtype=config.floatX),
376+
"sigma_freq1": np.array(0.0, dtype=config.floatX),
377+
"sigma_freq2": np.array(0.0, dtype=config.floatX),
378+
"initial_state_cov": np.eye(mod.k_states, dtype=config.floatX),
379+
}
380+
381+
x, y = simulate_from_numpy_model(mod, rng, params, steps=4 * 6 * 3)
382+
383+
assert_pattern_repeats(y[:, 0], 4, atol=ATOL, rtol=RTOL)
384+
assert_pattern_repeats(y[:, 1], 6, atol=ATOL, rtol=RTOL)
385+
386+
assert mod.state_names == [
387+
"freq1_Cos_0[data_1]",
388+
"freq1_Sin_0[data_1]",
389+
"freq1_Cos_1[data_1]",
390+
"freq1_Sin_1[data_1]",
391+
"freq2_Cos_0[data_2]",
392+
"freq2_Sin_0[data_2]",
393+
]
394+
395+
assert mod.shock_names == [
396+
"freq1_Cos_0[data_1]",
397+
"freq1_Sin_0[data_1]",
398+
"freq1_Cos_1[data_1]",
399+
"freq1_Sin_1[data_1]",
400+
"freq2_Cos_0[data_2]",
401+
"freq2_Sin_0[data_2]",
402+
]
403+
404+
x0, *_, T = mod._unpack_statespace_with_placeholders()[:5]
405+
input_vars = explicit_graph_inputs([x0, T])
406+
fn = pytensor.function(
407+
inputs=list(input_vars),
408+
outputs=[x0, T],
409+
mode="FAST_COMPILE",
410+
)
411+
412+
x0_v, T_v = fn(
413+
freq1=np.array([1.0, 0.0, 1.2], dtype=config.floatX),
414+
freq2=np.array([3.0, 0.0], dtype=config.floatX),
415+
)
416+
417+
# Make sure the extra 0 from the first component (the saturated state) is there!
418+
np.testing.assert_allclose(np.array([1.0, 0.0, 1.2, 0.0, 3.0, 0.0]), x0_v, atol=ATOL, rtol=RTOL)
419+
420+
# Transition matrix is block diagonal: 4x4 for freq1, 2x2 for freq2
421+
# freq1: season_length=4, n=2, lambdas = 2*pi*1/4, 2*pi*2/4
422+
lam1 = 2 * np.pi * 1 / 4
423+
lam2 = 2 * np.pi * 2 / 4
424+
freq1_T1 = np.array([[np.cos(lam1), np.sin(lam1)], [-np.sin(lam1), np.cos(lam1)]])
425+
freq1_T2 = np.array([[np.cos(lam2), np.sin(lam2)], [-np.sin(lam2), np.cos(lam2)]])
426+
freq1_T = np.zeros((4, 4))
427+
428+
# freq2: season_length=6, n=1, lambda = 2*pi*1/6
429+
lam3 = 2 * np.pi * 1 / 6
430+
freq2_T = np.array([[np.cos(lam3), np.sin(lam3)], [-np.sin(lam3), np.cos(lam3)]])
431+
432+
freq1_T[0:2, 0:2] = freq1_T1
433+
freq1_T[2:4, 2:4] = freq1_T2
434+
435+
expected_T = np.zeros((6, 6))
436+
expected_T[0:4, 0:4] = freq1_T
437+
expected_T[4:6, 4:6] = freq2_T
438+
439+
np.testing.assert_allclose(expected_T, T_v, atol=ATOL, rtol=RTOL)

0 commit comments

Comments
 (0)