@@ -85,21 +85,26 @@ class CycleComponent(Component):
85
85
86
86
# Build the structural model
87
87
grw = st.LevelTrendComponent(order=1, innovations_order=1)
88
- cycle = st.CycleComponent('business_cycle', estimate_cycle_length=True, dampen=False)
88
+ cycle = st.CycleComponent(
89
+ "business_cycle", cycle_length=12, estimate_cycle_length=False, innovations=True, dampen=True
90
+ )
89
91
ss_mod = (grw + cycle).build()
90
92
91
93
# Estimate with PyMC
92
94
with pm.Model(coords=ss_mod.coords) as model:
93
95
P0 = pm.Deterministic('P0', pt.eye(ss_mod.k_states), dims=ss_mod.param_dims['P0'])
94
- intitial_trend = pm.Normal('initial_trend', dims=ss_mod.param_dims['initial_trend'])
95
- sigma_trend = pm.HalfNormal('sigma_trend', dims=ss_mod.param_dims['sigma_trend'])
96
96
97
- cycle_strength = pm.Normal("business_cycle", dims=ss_mod.param_dims["business_cycle"])
98
- cycle_length = pm.Uniform('business_cycle_length', lower=6, upper=12)
99
- sigma_cycle = pm.HalfNormal('sigma_business_cycle', sigma=1)
97
+ initial_level_trend = pm.Normal('initial_level_trend', dims=ss_mod.param_dims['initial_level_trend'])
98
+ sigma_level_trend = pm.HalfNormal('sigma_level_trend', dims=ss_mod.param_dims['sigma_level_trend'])
99
+
100
+ business_cycle = pm.Normal("business_cycle", dims=ss_mod.param_dims["business_cycle"])
101
+ dampening = pm.Beta("dampening_factor_business_cycle", 2, 2)
102
+ sigma_cycle = pm.HalfNormal("sigma_business_cycle", sigma=1)
100
103
101
104
ss_mod.build_statespace_graph(data)
102
- idata = pm.sample()
105
+ idata = pm.sample(
106
+ nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "JAX", "gradient_backend": "JAX"}
107
+ )
103
108
104
109
**Multivariate Example:**
105
110
Model cycles for multiple economic indicators with variable-specific innovation variances:
@@ -115,26 +120,25 @@ class CycleComponent(Component):
115
120
dampen=True,
116
121
observed_state_names=['gdp', 'unemployment', 'inflation']
117
122
)
118
-
119
- # Build the model
120
123
ss_mod = cycle.build()
121
124
122
- # In PyMC model:
123
125
with pm.Model(coords=ss_mod.coords) as model:
124
126
P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states), dims=ss_mod.param_dims["P0"])
125
127
# Initial states: shape (3, 2) for 3 variables, 2 states each
126
- cycle_init = pm.Normal('business_cycle', dims=ss_mod.param_dims["business_cycle"])
128
+ business_cycle = pm.Normal('business_cycle', dims=ss_mod.param_dims["business_cycle"])
127
129
128
130
# Dampening factor: scalar (shared across variables)
129
- dampening = pm.Beta("business_cycle_dampening_factor ", 2, 2)
131
+ dampening = pm.Beta("dampening_factor_business_cycle ", 2, 2)
130
132
131
133
# Innovation variances: shape (3,) for variable-specific variances
132
134
sigma_cycle = pm.HalfNormal(
133
135
"sigma_business_cycle", dims=ss_mod.param_dims["sigma_business_cycle"]
134
136
)
135
137
136
138
ss_mod.build_statespace_graph(data)
137
- idata = pm.sample()
139
+ idata = pm.sample(
140
+ nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "JAX", "gradient_backend": "JAX"}
141
+ )
138
142
139
143
References
140
144
----------
0 commit comments