Skip to content

Commit 49ff101

Browse files
committed
Fix params state names in FrequencySeasonality
1 parent 7026256 commit 49ff101

File tree

2 files changed

+321
-55
lines changed

2 files changed

+321
-55
lines changed

pymc_extras/statespace/models/structural/components/seasonality.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -327,13 +327,14 @@ def populate_component_properties(self):
327327

328328
if self.innovations:
329329
self.param_names += [f"sigma_{self.name}"]
330+
self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]
330331
self.param_info[f"sigma_{self.name}"] = {
331332
"shape": () if k_endog == 1 else (k_endog,),
332333
"constraints": "Positive",
333334
"dims": None if k_endog == 1 else (f"endog_{self.name}",),
334335
}
335-
self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)
336-
self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]
336+
if k_endog > 1:
337+
self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)
337338

338339
def make_symbolic_graph(self) -> None:
339340
k_states = self.k_states // self.k_endog
@@ -507,7 +508,7 @@ def make_symbolic_graph(self) -> None:
507508
self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)])
508509

509510
init_state = self.make_and_register_variable(
510-
f"{self.name}", shape=(n_coefs,) if k_endog == 1 else (k_endog, n_coefs)
511+
f"params_{self.name}", shape=(n_coefs,) if k_endog == 1 else (k_endog, n_coefs)
511512
)
512513

513514
init_state_idx = np.concatenate(
@@ -536,19 +537,30 @@ def make_symbolic_graph(self) -> None:
536537
def populate_component_properties(self):
537538
k_endog = self.k_endog
538539
n_coefs = self.n_coefs
539-
k_states = self.k_states // k_endog
540540

541541
self.state_names = [
542542
f"{f}_{self.name}_{i}[{obs_state_name}]"
543543
for obs_state_name in self.observed_state_names
544544
for i in range(self.n)
545545
for f in ["Cos", "Sin"]
546546
]
547-
self.param_names = [f"{self.name}"]
547+
# determine which state names correspond to parameters
548+
# all endog variables use same state structure, so we just need
549+
# the first n_coefs state names (which may be less than total if saturated)
550+
param_state_names = [f"{f}_{self.name}_{i}" for i in range(self.n) for f in ["Cos", "Sin"]][
551+
:n_coefs
552+
]
553+
554+
self.param_names = [f"params_{self.name}"]
555+
556+
self.param_dims = {
557+
f"params_{self.name}": (f"state_{self.name}",)
558+
if k_endog == 1
559+
else (f"endog_{self.name}", f"state_{self.name}")
560+
}
548561

549-
self.param_dims = {self.name: (f"state_{self.name}",)}
550562
self.param_info = {
551-
f"{self.name}": {
563+
f"params_{self.name}": {
552564
"shape": (n_coefs,) if k_endog == 1 else (k_endog, n_coefs),
553565
"constraints": None,
554566
"dims": (f"state_{self.name}",)
@@ -557,23 +569,22 @@ def populate_component_properties(self):
557569
}
558570
}
559571

560-
# Regardless of whether the fourier basis are saturated, there will always be one symbolic state per basis.
561-
# That's why the self.states is just a simple loop over everything. But when saturated, one of those states
562-
# doesn't have an associated **parameter**, so the coords need to be adjusted to reflect this.
563-
init_state_idx = np.concatenate(
564-
[
565-
np.arange(k_states * i, (i + 1) * k_states, dtype=int)[:n_coefs]
566-
for i in range(k_endog)
567-
],
568-
axis=0,
572+
self.coords = (
573+
{f"state_{self.name}": param_state_names}
574+
if k_endog == 1
575+
else {
576+
f"endog_{self.name}": self.observed_state_names,
577+
f"state_{self.name}": param_state_names,
578+
}
569579
)
570-
self.coords = {f"state_{self.name}": [self.state_names[i] for i in init_state_idx]}
571580

572581
if self.innovations:
573-
self.shock_names = self.state_names.copy()
574582
self.param_names += [f"sigma_{self.name}"]
583+
self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]
575584
self.param_info[f"sigma_{self.name}"] = {
576-
"shape": () if k_endog == 1 else (k_endog, n_coefs),
585+
"shape": () if k_endog == 1 else (k_endog,),
577586
"constraints": "Positive",
578587
"dims": None if k_endog == 1 else (f"endog_{self.name}",),
579588
}
589+
if k_endog > 1:
590+
self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)

0 commit comments

Comments
 (0)