Skip to content

Commit 414017e

Browse files
Fix seasonality dims, coord and params (#559)
* Fix dims for multivariate time seasonality * Rename coefs to params * Fix params state names in FrequencySeasonality * Fix coefs_ to params_ renaming in broader tests * Tweak state names and shock names * Update tests * Simplify tests --------- Co-authored-by: jessegrabowski <[email protected]>
1 parent 2d16ad0 commit 414017e

File tree

5 files changed

+195
-92
lines changed

5 files changed

+195
-92
lines changed

pymc_extras/statespace/models/structural/components/seasonality.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ class TimeSeasonality(Component):
212212
sigma_level_trend = pm.HalfNormal(
213213
"sigma_level_trend", sigma=1e-6, dims=ss_mod.param_dims["sigma_level_trend"]
214214
)
215-
coefs_annual = pm.Normal("coefs_annual", sigma=1e-2, dims=ss_mod.param_dims["coefs_annual"])
215+
params_annual = pm.Normal("params_annual", sigma=1e-2, dims=ss_mod.param_dims["params_annual"])
216216
217217
ss_mod.build_statespace_graph(data)
218218
idata = pm.sample(
@@ -298,10 +298,10 @@ def populate_component_properties(self):
298298
for endog_name in self.observed_state_names
299299
for state_name in self.provided_state_names
300300
]
301-
self.param_names = [f"coefs_{self.name}"]
301+
self.param_names = [f"params_{self.name}"]
302302

303303
self.param_info = {
304-
f"coefs_{self.name}": {
304+
f"params_{self.name}": {
305305
"shape": (k_states,) if k_endog == 1 else (k_endog, k_states),
306306
"constraints": None,
307307
"dims": (f"state_{self.name}",)
@@ -311,7 +311,7 @@ def populate_component_properties(self):
311311
}
312312

313313
self.param_dims = {
314-
f"coefs_{self.name}": (f"state_{self.name}",)
314+
f"params_{self.name}": (f"state_{self.name}",)
315315
if k_endog == 1
316316
else (f"endog_{self.name}", f"state_{self.name}")
317317
}
@@ -327,12 +327,14 @@ def populate_component_properties(self):
327327

328328
if self.innovations:
329329
self.param_names += [f"sigma_{self.name}"]
330+
self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]
330331
self.param_info[f"sigma_{self.name}"] = {
331-
"shape": (),
332+
"shape": () if k_endog == 1 else (k_endog,),
332333
"constraints": "Positive",
333-
"dims": None,
334+
"dims": None if k_endog == 1 else (f"endog_{self.name}",),
334335
}
335-
self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]
336+
if k_endog > 1:
337+
self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)
336338

337339
def make_symbolic_graph(self) -> None:
338340
k_states = self.k_states // self.k_endog
@@ -377,7 +379,7 @@ def make_symbolic_graph(self) -> None:
377379
self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)])
378380

379381
initial_states = self.make_and_register_variable(
380-
f"coefs_{self.name}",
382+
f"params_{self.name}",
381383
shape=(k_unique_states,) if k_endog == 1 else (k_endog, k_unique_states),
382384
)
383385
if k_endog == 1:
@@ -506,7 +508,7 @@ def make_symbolic_graph(self) -> None:
506508
self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)])
507509

508510
init_state = self.make_and_register_variable(
509-
f"{self.name}", shape=(n_coefs,) if k_endog == 1 else (k_endog, n_coefs)
511+
f"params_{self.name}", shape=(n_coefs,) if k_endog == 1 else (k_endog, n_coefs)
510512
)
511513

512514
init_state_idx = np.concatenate(
@@ -535,19 +537,30 @@ def make_symbolic_graph(self) -> None:
535537
def populate_component_properties(self):
536538
k_endog = self.k_endog
537539
n_coefs = self.n_coefs
538-
k_states = self.k_states // k_endog
539540

540541
self.state_names = [
541-
f"{f}_{self.name}_{i}[{obs_state_name}]"
542+
f"{f}_{i}_{self.name}[{obs_state_name}]"
542543
for obs_state_name in self.observed_state_names
543544
for i in range(self.n)
544545
for f in ["Cos", "Sin"]
545546
]
546-
self.param_names = [f"{self.name}"]
547+
# determine which state names correspond to parameters
548+
# all endog variables use same state structure, so we just need
549+
# the first n_coefs state names (which may be less than total if saturated)
550+
param_state_names = [f"{f}_{i}_{self.name}" for i in range(self.n) for f in ["Cos", "Sin"]][
551+
:n_coefs
552+
]
553+
554+
self.param_names = [f"params_{self.name}"]
555+
556+
self.param_dims = {
557+
f"params_{self.name}": (f"state_{self.name}",)
558+
if k_endog == 1
559+
else (f"endog_{self.name}", f"state_{self.name}")
560+
}
547561

548-
self.param_dims = {self.name: (f"state_{self.name}",)}
549562
self.param_info = {
550-
f"{self.name}": {
563+
f"params_{self.name}": {
551564
"shape": (n_coefs,) if k_endog == 1 else (k_endog, n_coefs),
552565
"constraints": None,
553566
"dims": (f"state_{self.name}",)
@@ -556,23 +569,22 @@ def populate_component_properties(self):
556569
}
557570
}
558571

559-
# Regardless of whether the fourier basis are saturated, there will always be one symbolic state per basis.
560-
# That's why the self.states is just a simple loop over everything. But when saturated, one of those states
561-
# doesn't have an associated **parameter**, so the coords need to be adjusted to reflect this.
562-
init_state_idx = np.concatenate(
563-
[
564-
np.arange(k_states * i, (i + 1) * k_states, dtype=int)[:n_coefs]
565-
for i in range(k_endog)
566-
],
567-
axis=0,
572+
self.coords = (
573+
{f"state_{self.name}": param_state_names}
574+
if k_endog == 1
575+
else {
576+
f"endog_{self.name}": self.observed_state_names,
577+
f"state_{self.name}": param_state_names,
578+
}
568579
)
569-
self.coords = {f"state_{self.name}": [self.state_names[i] for i in init_state_idx]}
570580

571581
if self.innovations:
572-
self.shock_names = self.state_names.copy()
573582
self.param_names += [f"sigma_{self.name}"]
583+
self.shock_names = self.state_names.copy()
574584
self.param_info[f"sigma_{self.name}"] = {
575-
"shape": () if k_endog == 1 else (k_endog, n_coefs),
585+
"shape": () if k_endog == 1 else (k_endog,),
576586
"constraints": "Positive",
577587
"dims": None if k_endog == 1 else (f"endog_{self.name}",),
578588
}
589+
if k_endog > 1:
590+
self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)

pymc_extras/statespace/models/structural/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ class StructuralTimeSeries(PyMCStateSpace):
120120
initial_trend = pm.Normal('initial_trend', sigma=10, dims=ss_mod.param_dims['initial_trend'])
121121
sigma_trend = pm.HalfNormal('sigma_trend', sigma=1, dims=ss_mod.param_dims['sigma_trend'])
122122
123-
seasonal_coefs = pm.Normal('seasonal_coefs', sigma=1, dims=ss_mod.param_dims['seasonal_coefs'])
123+
seasonal_coefs = pm.Normal('params_seasonal', sigma=1, dims=ss_mod.param_dims['params_seasonal'])
124124
sigma_seasonal = pm.HalfNormal('sigma_seasonal', sigma=1)
125125
126126
sigma_obs = pm.Exponential('sigma_obs', 1, dims=ss_mod.param_dims['sigma_obs'])

0 commit comments

Comments
 (0)