Skip to content

Commit 08085c7

Browse files
Merge pull request #5 from Dekermanjian/multivariate-structural
naming schema adherence
2 parents 503eec5 + 77c27f4 commit 08085c7

File tree

11 files changed

+138
-148
lines changed

11 files changed

+138
-148
lines changed

notebooks/multivariate_ssm.ipynb

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1608,17 +1608,7 @@
16081608
},
16091609
{
16101610
"cell_type": "code",
1611-
"execution_count": 32,
1612-
"id": "fcc14c10",
1613-
"metadata": {},
1614-
"outputs": [],
1615-
"source": [
1616-
"state_names = pd.date_range(\"1900-01-01\", \"1900-12-31\", freq=\"MS\").month_name().tolist()"
1617-
]
1618-
},
1619-
{
1620-
"cell_type": "code",
1621-
"execution_count": 4,
1611+
"execution_count": 3,
16221612
"id": "79f703dd",
16231613
"metadata": {},
16241614
"outputs": [],
@@ -1629,7 +1619,7 @@
16291619
},
16301620
{
16311621
"cell_type": "code",
1632-
"execution_count": 5,
1622+
"execution_count": 4,
16331623
"id": "36ce0c20",
16341624
"metadata": {},
16351625
"outputs": [],
@@ -1656,7 +1646,7 @@
16561646
},
16571647
{
16581648
"cell_type": "code",
1659-
"execution_count": 6,
1649+
"execution_count": 5,
16601650
"id": "baf82cfd",
16611651
"metadata": {},
16621652
"outputs": [
@@ -1748,7 +1738,7 @@
17481738
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"430\" y=\"-153.11\" font-family=\"Times,serif\" font-size=\"14.00\">KalmanFilter</text>\n",
17491739
"</g>\n",
17501740
"<!-- P0&#45;&gt;obs -->\n",
1751-
"<g id=\"edge3\" class=\"edge\">\n",
1741+
"<g id=\"edge4\" class=\"edge\">\n",
17521742
"<title>P0&#45;&gt;obs</title>\n",
17531743
"<path fill=\"none\" stroke=\"black\" d=\"M128.49,-275.49C138.4,-260.15 152.65,-242.11 170,-231.32 226.08,-196.43 300.68,-183.15 355.59,-178.23\"/>\n",
17541744
"<polygon fill=\"black\" stroke=\"black\" points=\"355.81,-181.72 365.49,-177.42 355.24,-174.75 355.81,-181.72\"/>\n",
@@ -1776,7 +1766,7 @@
17761766
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"514\" y=\"-282.93\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
17771767
"</g>\n",
17781768
"<!-- annual_coefs&#45;&gt;obs -->\n",
1779-
"<g id=\"edge5\" class=\"edge\">\n",
1769+
"<g id=\"edge3\" class=\"edge\">\n",
17801770
"<title>annual_coefs&#45;&gt;obs</title>\n",
17811771
"<path fill=\"none\" stroke=\"black\" d=\"M489.93,-266.85C480.7,-252.81 470.06,-236.61 460.32,-221.79\"/>\n",
17821772
"<polygon fill=\"black\" stroke=\"black\" points=\"463.32,-219.99 454.91,-213.56 457.47,-223.84 463.32,-219.99\"/>\n",
@@ -1790,7 +1780,7 @@
17901780
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"761\" y=\"-282.93\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n",
17911781
"</g>\n",
17921782
"<!-- sigma_level_trend&#45;&gt;obs -->\n",
1793-
"<g id=\"edge4\" class=\"edge\">\n",
1783+
"<g id=\"edge5\" class=\"edge\">\n",
17941784
"<title>sigma_level_trend&#45;&gt;obs</title>\n",
17951785
"<path fill=\"none\" stroke=\"black\" d=\"M722.75,-268.18C706.58,-254.9 686.89,-240.75 667,-231.32 615.03,-206.68 551.33,-192.51 503.07,-184.64\"/>\n",
17961786
"<polygon fill=\"black\" stroke=\"black\" points=\"503.69,-181.19 493.27,-183.1 502.61,-188.11 503.69,-181.19\"/>\n",
@@ -1813,10 +1803,10 @@
18131803
"</svg>\n"
18141804
],
18151805
"text/plain": [
1816-
"<graphviz.graphs.Digraph at 0x323e64440>"
1806+
"<graphviz.graphs.Digraph at 0x32028ef90>"
18171807
]
18181808
},
1819-
"execution_count": 6,
1809+
"execution_count": 5,
18201810
"metadata": {},
18211811
"output_type": "execute_result"
18221812
}

pymc_extras/statespace/models/structural/components/cycle.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,12 @@ def make_symbolic_graph(self) -> None:
205205
self.ssm["initial_state", :] = init_state.ravel()
206206

207207
if self.estimate_cycle_length:
208-
lamb = self.make_and_register_variable(f"{self.name}_length", shape=())
208+
lamb = self.make_and_register_variable(f"length_{self.name}", shape=())
209209
else:
210210
lamb = self.cycle_length
211211

212212
if self.dampen:
213-
rho = self.make_and_register_variable(f"{self.name}_dampening_factor", shape=())
213+
rho = self.make_and_register_variable(f"dampening_factor_{self.name}", shape=())
214214
else:
215215
rho = 1
216216

@@ -236,51 +236,51 @@ def make_symbolic_graph(self) -> None:
236236

237237
def populate_component_properties(self):
238238
self.state_names = [
239-
f"{self.name}_{f}[{var_name}]" if self.k_endog > 1 else f"{self.name}_{f}"
239+
f"{f}_{self.name}[{var_name}]" if self.k_endog > 1 else f"{f}_{self.name}"
240240
for var_name in self.observed_state_names
241241
for f in ["Cos", "Sin"]
242242
]
243243

244244
self.param_names = [f"{self.name}"]
245245

246246
if self.k_endog == 1:
247-
self.param_dims = {self.name: (f"{self.name}_state",)}
248-
self.coords = {f"{self.name}_state": self.state_names}
247+
self.param_dims = {self.name: (f"state_{self.name}",)}
248+
self.coords = {f"state_{self.name}": self.state_names}
249249
self.param_info = {
250250
f"{self.name}": {
251251
"shape": (2,),
252252
"constraints": None,
253-
"dims": (f"{self.name}_state",),
253+
"dims": (f"state_{self.name}",),
254254
}
255255
}
256256
else:
257-
self.param_dims = {self.name: (f"{self.name}_endog", f"{self.name}_state")}
257+
self.param_dims = {self.name: (f"endog_{self.name}", f"state_{self.name}")}
258258
self.coords = {
259-
f"{self.name}_state": [f"{self.name}_Cos", f"{self.name}_Sin"],
260-
f"{self.name}_endog": self.observed_state_names,
259+
f"state_{self.name}": [f"Cos_{self.name}", f"Sin_{self.name}"],
260+
f"endog_{self.name}": self.observed_state_names,
261261
}
262262
self.param_info = {
263263
f"{self.name}": {
264264
"shape": (self.k_endog, 2),
265265
"constraints": None,
266-
"dims": (f"{self.name}_endog", f"{self.name}_state"),
266+
"dims": (f"endog_{self.name}", f"state_{self.name}"),
267267
}
268268
}
269269

270270
if self.estimate_cycle_length:
271-
self.param_names += [f"{self.name}_length"]
272-
self.param_info[f"{self.name}_length"] = {
271+
self.param_names += [f"length_{self.name}"]
272+
self.param_info[f"length_{self.name}"] = {
273273
"shape": () if self.k_endog == 1 else (self.k_endog,),
274274
"constraints": "Positive, non-zero",
275-
"dims": None if self.k_endog == 1 else f"{self.name}_endog",
275+
"dims": None if self.k_endog == 1 else f"endog_{self.name}",
276276
}
277277

278278
if self.dampen:
279-
self.param_names += [f"{self.name}_dampening_factor"]
280-
self.param_info[f"{self.name}_dampening_factor"] = {
279+
self.param_names += [f"dampening_factor_{self.name}"]
280+
self.param_info[f"dampening_factor_{self.name}"] = {
281281
"shape": () if self.k_endog == 1 else (self.k_endog,),
282282
"constraints": "0 < x ≤ 1",
283-
"dims": None if self.k_endog == 1 else f"{self.name}_endog",
283+
"dims": None if self.k_endog == 1 else (f"endog_{self.name}",),
284284
}
285285

286286
if self.innovations:
@@ -292,10 +292,10 @@ def populate_component_properties(self):
292292
"dims": None,
293293
}
294294
else:
295-
self.param_dims[f"sigma_{self.name}"] = (f"{self.name}_endog",)
295+
self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)
296296
self.param_info[f"sigma_{self.name}"] = {
297297
"shape": (self.k_endog,),
298298
"constraints": "Positive",
299-
"dims": (f"{self.name}_endog",),
299+
"dims": (f"endog_{self.name}",),
300300
}
301301
self.shock_names = self.state_names.copy()

pymc_extras/statespace/models/structural/components/level_trend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,11 @@ def populate_component_properties(self):
197197
]
198198

199199
self.param_dims[f"sigma_{self.name}"] = (
200-
(f"{self.name}_shock",)
200+
(f"shock_{self.name}",)
201201
if k_endog == 1
202-
else (f"endog_{self.name}", f"{self.name}_shock")
202+
else (f"endog_{self.name}", f"shock_{self.name}")
203203
)
204-
self.coords[f"{self.name}_shock"] = base_shock_names
204+
self.coords[f"shock_{self.name}"] = base_shock_names
205205
self.param_info[f"sigma_{self.name}"] = {
206206
"shape": (k_posdef,) if k_endog == 1 else (k_endog, k_posdef),
207207
"constraints": "Positive",

pymc_extras/statespace/models/structural/components/seasonality.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -180,30 +180,30 @@ def populate_component_properties(self):
180180
for endog_name in self.observed_state_names
181181
for state_name in self.provided_state_names
182182
]
183-
self.param_names = [f"{self.name}_coefs"]
183+
self.param_names = [f"coefs_{self.name}"]
184184

185185
self.param_info = {
186-
f"{self.name}_coefs": {
186+
f"coefs_{self.name}": {
187187
"shape": (k_states,) if k_endog == 1 else (k_endog, k_states),
188188
"constraints": None,
189-
"dims": (f"{self.name}_state",)
189+
"dims": (f"state_{self.name}",)
190190
if k_endog == 1
191-
else (f"{self.name}_endog", f"{self.name}_state"),
191+
else (f"endog_{self.name}", f"state_{self.name}"),
192192
}
193193
}
194194

195195
self.param_dims = {
196-
f"{self.name}_coefs": (f"{self.name}_state",)
196+
f"coefs_{self.name}": (f"state_{self.name}",)
197197
if k_endog == 1
198-
else (f"{self.name}_endog", f"{self.name}_state")
198+
else (f"endog_{self.name}", f"state_{self.name}")
199199
}
200200

201201
self.coords = (
202-
{f"{self.name}_state": self.provided_state_names}
202+
{f"state_{self.name}": self.provided_state_names}
203203
if k_endog == 1
204204
else {
205-
f"{self.name}_endog": self.observed_state_names,
206-
f"{self.name}_state": self.provided_state_names,
205+
f"endog_{self.name}": self.observed_state_names,
206+
f"state_{self.name}": self.provided_state_names,
207207
}
208208
)
209209

@@ -238,7 +238,7 @@ def make_symbolic_graph(self) -> None:
238238
self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)])
239239

240240
initial_states = self.make_and_register_variable(
241-
f"{self.name}_coefs", shape=(k_states,) if k_endog == 1 else (k_endog, k_states)
241+
f"coefs_{self.name}", shape=(k_states,) if k_endog == 1 else (k_endog, k_states)
242242
)
243243
self.ssm["initial_state", :] = initial_states.ravel()
244244

@@ -390,21 +390,21 @@ def populate_component_properties(self):
390390
k_states = self.k_states // k_endog
391391

392392
self.state_names = [
393-
f"{self.name}_{f}_{i}[{obs_state_name}]"
393+
f"{f}_{self.name}_{i}[{obs_state_name}]"
394394
for obs_state_name in self.observed_state_names
395395
for i in range(self.n)
396396
for f in ["Cos", "Sin"]
397397
]
398398
self.param_names = [f"{self.name}"]
399399

400-
self.param_dims = {self.name: (f"{self.name}_state",)}
400+
self.param_dims = {self.name: (f"state_{self.name}",)}
401401
self.param_info = {
402402
f"{self.name}": {
403403
"shape": (n_coefs,) if k_endog == 1 else (k_endog, n_coefs),
404404
"constraints": None,
405-
"dims": (f"{self.name}_state",)
405+
"dims": (f"state_{self.name}",)
406406
if k_endog == 1
407-
else (f"{self.name}_endog", f"{self.name}_state"),
407+
else (f"endog_{self.name}", f"state_{self.name}"),
408408
}
409409
}
410410

@@ -418,13 +418,13 @@ def populate_component_properties(self):
418418
],
419419
axis=0,
420420
)
421-
self.coords = {f"{self.name}_state": [self.state_names[i] for i in init_state_idx]}
421+
self.coords = {f"state_{self.name}": [self.state_names[i] for i in init_state_idx]}
422422

423423
if self.innovations:
424424
self.shock_names = self.state_names.copy()
425425
self.param_names += [f"sigma_{self.name}"]
426426
self.param_info[f"sigma_{self.name}"] = {
427427
"shape": () if k_endog == 1 else (k_endog, n_coefs),
428428
"constraints": "Positive",
429-
"dims": None if k_endog == 1 else (f"{self.name}_endog",),
429+
"dims": None if k_endog == 1 else (f"endog_{self.name}",),
430430
}

tests/statespace/core/test_statespace.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def pymc_mod_no_exog(ss_mod_no_exog, rng):
275275
P0 = pm.Deterministic(
276276
"P0", pt.eye(ss_mod_no_exog.k_states) * P0_sigma, dims=["state", "state_aux"]
277277
)
278-
sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
278+
sigma_trend = pm.Exponential("sigma_trend", 1, dims=["shock_trend"])
279279
ss_mod_no_exog.build_statespace_graph(y)
280280

281281
return m
@@ -291,7 +291,7 @@ def pymc_mod_no_exog_mv(ss_mod_no_exog_mv, rng):
291291
P0 = pm.Deterministic(
292292
"P0", pt.eye(ss_mod_no_exog_mv.k_states) * P0_sigma, dims=["state", "state_aux"]
293293
)
294-
trend_sigma = pm.Exponential("sigma_trend", 1, dims=["endog_trend", "trend_shock"])
294+
trend_sigma = pm.Exponential("sigma_trend", 1, dims=["endog_trend", "shock_trend"])
295295
ss_mod_no_exog_mv.build_statespace_graph(y)
296296

297297
return m
@@ -311,7 +311,7 @@ def pymc_mod_no_exog_mv_dt(ss_mod_no_exog_mv, rng):
311311
P0 = pm.Deterministic(
312312
"P0", pt.eye(ss_mod_no_exog_mv.k_states) * P0_sigma, dims=["state", "state_aux"]
313313
)
314-
trend_sigma = pm.Exponential("sigma_trend", 1, dims=["endog_trend", "trend_shock"])
314+
trend_sigma = pm.Exponential("sigma_trend", 1, dims=["endog_trend", "shock_trend"])
315315
ss_mod_no_exog_mv.build_statespace_graph(y)
316316

317317
return m
@@ -331,7 +331,7 @@ def pymc_mod_no_exog_dt(ss_mod_no_exog_dt, rng):
331331
P0 = pm.Deterministic(
332332
"P0", pt.eye(ss_mod_no_exog_dt.k_states) * P0_sigma, dims=["state", "state_aux"]
333333
)
334-
sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
334+
sigma_trend = pm.Exponential("sigma_trend", 1, dims=["shock_trend"])
335335
ss_mod_no_exog_dt.build_statespace_graph(y)
336336

337337
return m

tests/statespace/core/test_statespace_JAX.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def exog_pymc_mod(exog_ss_mod, rng):
6262
)
6363
beta_exog = pm.Normal("beta_exog", dims=["exog_state"])
6464

65-
sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
65+
sigma_trend = pm.Exponential("sigma_trend", 1, dims=["shock_trend"])
6666
exog_ss_mod.build_statespace_graph(y)
6767

6868
return m

0 commit comments

Comments
 (0)