Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 38 additions & 26 deletions pymc_extras/statespace/models/structural/components/seasonality.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ class TimeSeasonality(Component):
sigma_level_trend = pm.HalfNormal(
"sigma_level_trend", sigma=1e-6, dims=ss_mod.param_dims["sigma_level_trend"]
)
coefs_annual = pm.Normal("coefs_annual", sigma=1e-2, dims=ss_mod.param_dims["coefs_annual"])
params_annual = pm.Normal("params_annual", sigma=1e-2, dims=ss_mod.param_dims["params_annual"])

ss_mod.build_statespace_graph(data)
idata = pm.sample(
Expand Down Expand Up @@ -298,10 +298,10 @@ def populate_component_properties(self):
for endog_name in self.observed_state_names
for state_name in self.provided_state_names
]
self.param_names = [f"coefs_{self.name}"]
self.param_names = [f"params_{self.name}"]

self.param_info = {
f"coefs_{self.name}": {
f"params_{self.name}": {
"shape": (k_states,) if k_endog == 1 else (k_endog, k_states),
"constraints": None,
"dims": (f"state_{self.name}",)
Expand All @@ -311,7 +311,7 @@ def populate_component_properties(self):
}

self.param_dims = {
f"coefs_{self.name}": (f"state_{self.name}",)
f"params_{self.name}": (f"state_{self.name}",)
if k_endog == 1
else (f"endog_{self.name}", f"state_{self.name}")
}
Expand All @@ -327,12 +327,14 @@ def populate_component_properties(self):

if self.innovations:
self.param_names += [f"sigma_{self.name}"]
self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]
self.param_info[f"sigma_{self.name}"] = {
"shape": (),
"shape": () if k_endog == 1 else (k_endog,),
"constraints": "Positive",
"dims": None,
"dims": None if k_endog == 1 else (f"endog_{self.name}",),
}
self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]
if k_endog > 1:
self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)

def make_symbolic_graph(self) -> None:
k_states = self.k_states // self.k_endog
Expand Down Expand Up @@ -377,7 +379,7 @@ def make_symbolic_graph(self) -> None:
self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)])

initial_states = self.make_and_register_variable(
f"coefs_{self.name}",
f"params_{self.name}",
shape=(k_unique_states,) if k_endog == 1 else (k_endog, k_unique_states),
)
if k_endog == 1:
Expand Down Expand Up @@ -506,7 +508,7 @@ def make_symbolic_graph(self) -> None:
self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)])

init_state = self.make_and_register_variable(
f"{self.name}", shape=(n_coefs,) if k_endog == 1 else (k_endog, n_coefs)
f"params_{self.name}", shape=(n_coefs,) if k_endog == 1 else (k_endog, n_coefs)
)

init_state_idx = np.concatenate(
Expand Down Expand Up @@ -535,19 +537,30 @@ def make_symbolic_graph(self) -> None:
def populate_component_properties(self):
k_endog = self.k_endog
n_coefs = self.n_coefs
k_states = self.k_states // k_endog

self.state_names = [
f"{f}_{self.name}_{i}[{obs_state_name}]"
f"{f}_{i}_{self.name}[{obs_state_name}]"
for obs_state_name in self.observed_state_names
for i in range(self.n)
for f in ["Cos", "Sin"]
]
self.param_names = [f"{self.name}"]
# determine which state names correspond to parameters
# all endog variables use same state structure, so we just need
# the first n_coefs state names (which may be less than total if saturated)
param_state_names = [f"{f}_{i}_{self.name}" for i in range(self.n) for f in ["Cos", "Sin"]][
:n_coefs
]

self.param_names = [f"params_{self.name}"]

self.param_dims = {
f"params_{self.name}": (f"state_{self.name}",)
if k_endog == 1
else (f"endog_{self.name}", f"state_{self.name}")
}

self.param_dims = {self.name: (f"state_{self.name}",)}
self.param_info = {
f"{self.name}": {
f"params_{self.name}": {
"shape": (n_coefs,) if k_endog == 1 else (k_endog, n_coefs),
"constraints": None,
"dims": (f"state_{self.name}",)
Expand All @@ -556,23 +569,22 @@ def populate_component_properties(self):
}
}

# Regardless of whether the fourier basis are saturated, there will always be one symbolic state per basis.
# That's why the self.states is just a simple loop over everything. But when saturated, one of those states
# doesn't have an associated **parameter**, so the coords need to be adjusted to reflect this.
init_state_idx = np.concatenate(
[
np.arange(k_states * i, (i + 1) * k_states, dtype=int)[:n_coefs]
for i in range(k_endog)
],
axis=0,
self.coords = (
{f"state_{self.name}": param_state_names}
if k_endog == 1
else {
f"endog_{self.name}": self.observed_state_names,
f"state_{self.name}": param_state_names,
}
)
self.coords = {f"state_{self.name}": [self.state_names[i] for i in init_state_idx]}

if self.innovations:
self.shock_names = self.state_names.copy()
self.param_names += [f"sigma_{self.name}"]
self.shock_names = self.state_names.copy()
self.param_info[f"sigma_{self.name}"] = {
"shape": () if k_endog == 1 else (k_endog, n_coefs),
"shape": () if k_endog == 1 else (k_endog,),
"constraints": "Positive",
"dims": None if k_endog == 1 else (f"endog_{self.name}",),
}
if k_endog > 1:
self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)
2 changes: 1 addition & 1 deletion pymc_extras/statespace/models/structural/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class StructuralTimeSeries(PyMCStateSpace):
initial_trend = pm.Normal('initial_trend', sigma=10, dims=ss_mod.param_dims['initial_trend'])
sigma_trend = pm.HalfNormal('sigma_trend', sigma=1, dims=ss_mod.param_dims['sigma_trend'])

seasonal_coefs = pm.Normal('seasonal_coefs', sigma=1, dims=ss_mod.param_dims['seasonal_coefs'])
seasonal_coefs = pm.Normal('params_seasonal', sigma=1, dims=ss_mod.param_dims['params_seasonal'])
sigma_seasonal = pm.HalfNormal('sigma_seasonal', sigma=1)

sigma_obs = pm.Exponential('sigma_obs', 1, dims=ss_mod.param_dims['sigma_obs'])
Expand Down
Loading
Loading