Skip to content

Commit 9c14472

Browse files
committed
Improve cycle and seasonal docstrings
1 parent 124e1c3 commit 9c14472

File tree

2 files changed

+32
-19
lines changed

2 files changed

+32
-19
lines changed

pymc_extras/statespace/models/structural/components/cycle.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -85,21 +85,26 @@ class CycleComponent(Component):
8585
8686
# Build the structural model
8787
grw = st.LevelTrendComponent(order=1, innovations_order=1)
88-
cycle = st.CycleComponent('business_cycle', estimate_cycle_length=True, dampen=False)
88+
cycle = st.CycleComponent(
89+
"business_cycle", cycle_length=12, estimate_cycle_length=False, innovations=True, dampen=True
90+
)
8991
ss_mod = (grw + cycle).build()
9092
9193
# Estimate with PyMC
9294
with pm.Model(coords=ss_mod.coords) as model:
9395
P0 = pm.Deterministic('P0', pt.eye(ss_mod.k_states), dims=ss_mod.param_dims['P0'])
94-
intitial_trend = pm.Normal('initial_trend', dims=ss_mod.param_dims['initial_trend'])
95-
sigma_trend = pm.HalfNormal('sigma_trend', dims=ss_mod.param_dims['sigma_trend'])
9696
97-
cycle_strength = pm.Normal("business_cycle", dims=ss_mod.param_dims["business_cycle"])
98-
cycle_length = pm.Uniform('business_cycle_length', lower=6, upper=12)
99-
sigma_cycle = pm.HalfNormal('sigma_business_cycle', sigma=1)
97+
initial_level_trend = pm.Normal('initial_level_trend', dims=ss_mod.param_dims['initial_level_trend'])
98+
sigma_level_trend = pm.HalfNormal('sigma_level_trend', dims=ss_mod.param_dims['sigma_level_trend'])
99+
100+
business_cycle = pm.Normal("business_cycle", dims=ss_mod.param_dims["business_cycle"])
101+
dampening = pm.Beta("dampening_factor_business_cycle", 2, 2)
102+
sigma_cycle = pm.HalfNormal("sigma_business_cycle", sigma=1)
100103
101104
ss_mod.build_statespace_graph(data)
102-
idata = pm.sample()
105+
idata = pm.sample(
106+
nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "JAX", "gradient_backend": "JAX"}
107+
)
103108
104109
**Multivariate Example:**
105110
Model cycles for multiple economic indicators with variable-specific innovation variances:
@@ -115,26 +120,25 @@ class CycleComponent(Component):
115120
dampen=True,
116121
observed_state_names=['gdp', 'unemployment', 'inflation']
117122
)
118-
119-
# Build the model
120123
ss_mod = cycle.build()
121124
122-
# In PyMC model:
123125
with pm.Model(coords=ss_mod.coords) as model:
124126
P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states), dims=ss_mod.param_dims["P0"])
125127
# Initial states: shape (3, 2) for 3 variables, 2 states each
126-
cycle_init = pm.Normal('business_cycle', dims=ss_mod.param_dims["business_cycle"])
128+
business_cycle = pm.Normal('business_cycle', dims=ss_mod.param_dims["business_cycle"])
127129
128130
# Dampening factor: scalar (shared across variables)
129-
dampening = pm.Beta("business_cycle_dampening_factor", 2, 2)
131+
dampening = pm.Beta("dampening_factor_business_cycle", 2, 2)
130132
131133
# Innovation variances: shape (3,) for variable-specific variances
132134
sigma_cycle = pm.HalfNormal(
133135
"sigma_business_cycle", dims=ss_mod.param_dims["sigma_business_cycle"]
134136
)
135137
136138
ss_mod.build_statespace_graph(data)
137-
idata = pm.sample()
139+
idata = pm.sample(
140+
nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "JAX", "gradient_backend": "JAX"}
141+
)
138142
139143
References
140144
----------

pymc_extras/statespace/models/structural/components/seasonality.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,17 +106,26 @@ class TimeSeasonality(Component):
106106
107107
# Build the structural model
108108
grw = st.LevelTrendComponent(order=1, innovations_order=1)
109-
annual_season = st.TimeSeasonality(season_length=12, name='annual', state_names=state_names, innovations=False)
109+
annual_season = st.TimeSeasonality(
110+
season_length=12, name="annual", state_names=state_names, innovations=False
111+
)
110112
ss_mod = (grw + annual_season).build()
111113
112-
# Estimate with PyMC
113114
with pm.Model(coords=ss_mod.coords) as model:
114115
P0 = pm.Deterministic('P0', pt.eye(ss_mod.k_states) * 10, dims=ss_mod.param_dims['P0'])
115-
intitial_trend = pm.Deterministic('initial_trend', pt.zeros(1), dims=ss_mod.param_dims['initial_trend'])
116-
annual_coefs = pm.Normal('annual_coefs', sigma=1e-2, dims=ss_mod.param_dims['annual_coefs'])
117-
trend_sigmas = pm.HalfNormal('trend_sigmas', sigma=1e-6, dims=ss_mod.param_dims['trend_sigmas'])
116+
117+
initial_level_trend = pm.Deterministic(
118+
"initial_level_trend", pt.zeros(1), dims=ss_mod.param_dims["initial_level_trend"]
119+
)
120+
sigma_level_trend = pm.HalfNormal(
121+
"sigma_level_trend", sigma=1e-6, dims=ss_mod.param_dims["sigma_level_trend"]
122+
)
123+
coefs_annual = pm.Normal("coefs_annual", sigma=1e-2, dims=ss_mod.param_dims["coefs_annual"])
124+
118125
ss_mod.build_statespace_graph(data)
119-
idata = pm.sample(nuts_sampler='numpyro')
126+
idata = pm.sample(
127+
nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "JAX", "gradient_backend": "JAX"}
128+
)
120129
121130
References
122131
----------

0 commit comments

Comments
 (0)