Skip to content

Commit 9a472f0

Browse files
authored
Merge branch 'main' into shared-multivariate
2 parents 1341c52 + e375978 commit 9a472f0

File tree

4 files changed

+463
-170
lines changed

4 files changed

+463
-170
lines changed

notebooks/Structural Timeseries Modeling.ipynb

Lines changed: 445 additions & 155 deletions
Large diffs are not rendered by default.

pymc_extras/statespace/models/structural/components/cycle.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def make_symbolic_graph(self) -> None:
214214
self.ssm["selection", :, :] = pt.as_tensor_variable(selection_matrix)
215215

216216
init_state = self.make_and_register_variable(
217-
f"{self.name}",
217+
f"params_{self.name}",
218218
shape=(k_endog_effective, 2) if k_endog_effective > 1 else (self.k_states,),
219219
)
220220
self.ssm["initial_state", :] = init_state.ravel()
@@ -264,26 +264,26 @@ def populate_component_properties(self):
264264
for name in base_names
265265
]
266266

267-
self.param_names = [f"{self.name}"]
267+
self.param_names = [f"params_{self.name}"]
268268

269269
if k_endog_effective == 1:
270-
self.param_dims = {self.name: (f"state_{self.name}",)}
270+
self.param_dims = {f"params_{self.name}": (f"state_{self.name}",)}
271271
self.coords = {f"state_{self.name}": base_names}
272272
self.param_info = {
273-
f"{self.name}": {
273+
f"params_{self.name}": {
274274
"shape": (2,),
275275
"constraints": None,
276276
"dims": (f"state_{self.name}",),
277277
}
278278
}
279279
else:
280-
self.param_dims = {self.name: (f"endog_{self.name}", f"state_{self.name}")}
280+
self.param_dims = {f"params_{self.name}": (f"endog_{self.name}", f"state_{self.name}")}
281281
self.coords = {
282282
f"state_{self.name}": [f"Cos_{self.name}", f"Sin_{self.name}"],
283283
f"endog_{self.name}": self.observed_state_names,
284284
}
285285
self.param_info = {
286-
f"{self.name}": {
286+
f"params_{self.name}": {
287287
"shape": (k_endog_effective, 2),
288288
"constraints": None,
289289
"dims": (f"endog_{self.name}", f"state_{self.name}"),
@@ -295,7 +295,7 @@ def populate_component_properties(self):
295295
self.param_info[f"length_{self.name}"] = {
296296
"shape": () if k_endog_effective == 1 else (k_endog_effective,),
297297
"constraints": "Positive, non-zero",
298-
"dims": None if k_endog_effective == 1 else f"endog_{self.name}",
298+
"dims": None if k_endog_effective == 1 else (f"endog_{self.name}",),
299299
}
300300

301301
if self.dampen:

tests/statespace/models/structural/components/test_cycle.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_cycle_component_deterministic(rng):
2222
cycle = st.CycleComponent(
2323
name="cycle", cycle_length=12, estimate_cycle_length=False, innovations=False
2424
)
25-
params = {"cycle": np.array([1.0, 1.0], dtype=config.floatX)}
25+
params = {"params_cycle": np.array([1.0, 1.0], dtype=config.floatX)}
2626
x, y = simulate_from_numpy_model(cycle, rng, params, steps=12 * 12)
2727

2828
assert_pattern_repeats(y, 12, atol=ATOL, rtol=RTOL)
@@ -32,7 +32,10 @@ def test_cycle_component_with_dampening(rng):
3232
cycle = st.CycleComponent(
3333
name="cycle", cycle_length=12, estimate_cycle_length=False, innovations=False, dampen=True
3434
)
35-
params = {"cycle": np.array([10.0, 10.0], dtype=config.floatX), "dampening_factor_cycle": 0.75}
35+
params = {
36+
"params_cycle": np.array([10.0, 10.0], dtype=config.floatX),
37+
"dampening_factor_cycle": 0.75,
38+
}
3639
x, y = simulate_from_numpy_model(cycle, rng, params, steps=100)
3740

3841
# check that cycle dampens to zero over time
@@ -44,7 +47,7 @@ def test_cycle_component_with_innovations_and_cycle_length(rng):
4447
name="cycle", estimate_cycle_length=True, innovations=True, dampen=True
4548
)
4649
params = {
47-
"cycle": np.array([1.0, 1.0], dtype=config.floatX),
50+
"params_cycle": np.array([1.0, 1.0], dtype=config.floatX),
4851
"length_cycle": 12.0,
4952
"dampening_factor_cycle": 0.95,
5053
"sigma_cycle": 1.0,
@@ -64,7 +67,7 @@ def test_cycle_multivariate_deterministic(rng):
6467
innovations=False,
6568
observed_state_names=["data_1", "data_2", "data_3"],
6669
)
67-
params = {"cycle": np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], dtype=config.floatX)}
70+
params = {"params_cycle": np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], dtype=config.floatX)}
6871
x, y = simulate_from_numpy_model(cycle, rng, params, steps=12 * 12)
6972

7073
# Check that each variable has a cyclical pattern with the expected period
@@ -139,7 +142,7 @@ def test_cycle_multivariate_with_dampening(rng):
139142
observed_state_names=["data_1", "data_2", "data_3"],
140143
)
141144
params = {
142-
"cycle": np.array([[10.0, 10.0], [20.0, 20.0], [30.0, 30.0]], dtype=config.floatX),
145+
"params_cycle": np.array([[10.0, 10.0], [20.0, 20.0], [30.0, 30.0]], dtype=config.floatX),
143146
"dampening_factor_cycle": 0.75,
144147
}
145148
x, y = simulate_from_numpy_model(cycle, rng, params, steps=100)
@@ -167,7 +170,7 @@ def test_cycle_multivariate_with_innovations_and_cycle_length(rng):
167170
observed_state_names=["data_1", "data_2", "data_3"],
168171
)
169172
params = {
170-
"cycle": np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], dtype=config.floatX),
173+
"params_cycle": np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], dtype=config.floatX),
171174
"length_cycle": 12.0,
172175
"dampening_factor_cycle": 0.95,
173176
"sigma_cycle": np.array([0.5, 1.0, 1.5]), # different innov variances per var

tests/statespace/models/structural/test_against_statsmodels.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,8 +370,8 @@ def create_structural_model_and_equivalent_statsmodel(
370370
params["length_cycle"] = cycle_length
371371

372372
init_cycle = rng.normal(size=(2,)).astype(floatX)
373-
params["cycle"] = init_cycle
374-
expected_param_dims["cycle"] += ("state_cycle",)
373+
params["params_cycle"] = init_cycle
374+
expected_param_dims["params_cycle"] += ("state_cycle",)
375375

376376
state_names = ["Cos_cycle", "Sin_cycle"]
377377
expected_coords["state_cycle"] += state_names

0 commit comments

Comments
 (0)