Skip to content

Commit 3c5124d

Browse files
LevelTrend state/shock names depend on component name
1 parent 0c4590e commit 3c5124d

File tree

3 files changed

+126
-43
lines changed

3 files changed

+126
-43
lines changed

pymc_extras/statespace/models/structural/components/level_trend.py

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
2-
3-
from scipy import linalg
2+
import pytensor.tensor as pt
43

54
from pymc_extras.statespace.models.structural.core import Component
65
from pymc_extras.statespace.models.structural.utils import order_to_mask
@@ -114,7 +113,7 @@ def __init__(
114113
self,
115114
order: int | list[int] = 2,
116115
innovations_order: int | list[int] | None = None,
117-
name: str = "LevelTrend",
116+
name: str = "level_trend",
118117
observed_state_names: list[str] | None = None,
119118
):
120119
if innovations_order is None:
@@ -166,35 +165,46 @@ def populate_component_properties(self):
166165
k_posdef = self.k_posdef // k_endog
167166

168167
name_slice = POSITION_DERIVATIVE_NAMES[:k_states]
169-
self.param_names = ["initial_trend"]
168+
self.param_names = [f"{self.name}_initial"]
170169
base_names = [name for name, mask in zip(name_slice, self._order_mask) if mask]
171170
self.state_names = [
172-
f"{name}[{obs_name}]" for obs_name in self.observed_state_names for name in base_names
171+
f"{name}[{obs_name}]" if k_endog > 1 else name
172+
for obs_name in self.observed_state_names
173+
for name in base_names
173174
]
174-
self.param_dims = {"initial_trend": ("trend_state",)}
175-
self.coords = {"trend_state": base_names}
175+
self.param_dims = {f"{self.name}_initial": (f"{self.name}_state",)}
176+
self.coords = {f"{self.name}_state": base_names}
176177

177178
if k_endog > 1:
178-
self.param_dims["trend_state"] = (
179-
"trend_endog",
180-
"trend_state",
179+
self.param_dims[f"{self.name}_state"] = (
180+
f"{self.name}_endog",
181+
f"{self.name}_state",
181182
)
182-
self.param_dims = {"initial_trend": ("trend_endog", "trend_state")}
183-
self.coords["trend_endog"] = self.observed_state_names
183+
self.param_dims = {f"{self.name}_initial": (f"{self.name}_endog", f"{self.name}_state")}
184+
self.coords[f"{self.name}_endog"] = self.observed_state_names
184185

185186
shape = (k_endog, k_states) if k_endog > 1 else (k_states,)
186-
self.param_info = {"initial_trend": {"shape": shape, "constraints": None}}
187+
self.param_info = {f"{self.name}_initial": {"shape": shape, "constraints": None}}
187188

188189
if self.k_posdef > 0:
189-
self.param_names += ["sigma_trend"]
190-
self.shock_names = [
190+
self.param_names += [f"{self.name}_sigma"]
191+
192+
shock_base_names = [
191193
name for name, mask in zip(name_slice, self.innovations_order) if mask
192194
]
193-
self.param_dims["sigma_trend"] = (
194-
("trend_shock",) if k_endog == 1 else ("trend_endog", "trend_shock")
195+
self.shock_names = [
196+
f"{name}[{obs_name}]" if k_endog > 1 else name
197+
for obs_name in self.observed_state_names
198+
for name in shock_base_names
199+
]
200+
201+
self.param_dims[f"{self.name}_sigma"] = (
202+
(f"{self.name}_shock",)
203+
if k_endog == 1
204+
else (f"{self.name}_endog", f"{self.name}_shock")
195205
)
196-
self.coords["trend_shock"] = self.shock_names
197-
self.param_info["sigma_trend"] = {
206+
self.coords[f"{self.name}_shock"] = self.shock_names
207+
self.param_info[f"{self.name}_sigma"] = {
198208
"shape": (k_posdef,) if k_endog == 1 else (k_endog, k_posdef),
199209
"constraints": "Positive",
200210
}
@@ -208,28 +218,34 @@ def make_symbolic_graph(self) -> None:
208218
k_posdef = self.k_posdef // k_endog
209219

210220
initial_trend = self.make_and_register_variable(
211-
"initial_trend",
221+
f"{self.name}_initial",
212222
shape=(k_states,) if k_endog == 1 else (k_endog, k_states),
213223
)
214224
self.ssm["initial_state", :] = initial_trend.ravel()
215225

216-
triu_idx = np.triu_indices(k_states)
217-
T = np.zeros((k_states, k_states))
218-
T[triu_idx[0], triu_idx[1]] = 1
226+
triu_idx = pt.triu_indices(k_states)
227+
T = pt.zeros((k_states, k_states))[triu_idx[0], triu_idx[1]].set(1)
219228

220-
self.ssm["transition"] = linalg.block_diag(*[T for _ in range(k_endog)])
229+
self.ssm["transition", :, :] = pt.specify_shape(
230+
pt.linalg.block_diag(*[T for _ in range(k_endog)]), (self.k_states, self.k_states)
231+
)
221232

222233
R = np.eye(k_states)
223234
R = R[:, self.innovations_order]
224235

225-
self.ssm["selection", :, :] = linalg.block_diag(*[R for _ in range(k_endog)])
236+
self.ssm["selection", :, :] = pt.specify_shape(
237+
pt.linalg.block_diag(*[R for _ in range(k_endog)]), (self.k_states, self.k_posdef)
238+
)
226239

227240
Z = np.array([1.0] + [0.0] * (k_states - 1)).reshape((1, -1))
228-
self.ssm["design"] = linalg.block_diag(*[Z for _ in range(k_endog)])
241+
242+
self.ssm["design", :, :] = pt.specify_shape(
243+
pt.linalg.block_diag(*[Z for _ in range(k_endog)]), (self.k_endog, self.k_states)
244+
)
229245

230246
if k_posdef > 0:
231247
sigma_trend = self.make_and_register_variable(
232-
"sigma_trend",
248+
f"{self.name}_sigma",
233249
shape=(k_posdef,) if k_endog == 1 else (k_endog, k_posdef),
234250
)
235251
diag_idx = np.diag_indices(k_posdef * k_endog)

tests/statespace/models/structural/components/test_level_trend.py

Lines changed: 75 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytensor
23

34
from numpy.testing import assert_allclose
45
from pytensor import config
@@ -13,15 +14,15 @@
1314

1415
def test_level_trend_model(rng):
1516
mod = st.LevelTrendComponent(order=2, innovations_order=0)
16-
params = {"initial_trend": [0.0, 1.0]}
17+
params = {"level_trend_initial": [0.0, 1.0]}
1718
x, y = simulate_from_numpy_model(mod, rng, params)
1819

1920
assert_allclose(np.diff(y), 1, atol=ATOL, rtol=RTOL)
2021

2122
# Check coords
2223
mod = mod.build(verbose=False)
2324
_assert_basic_coords_correct(mod)
24-
assert mod.coords["trend_state"] == ["level", "trend"]
25+
assert mod.coords["level_trend_state"] == ["level", "trend"]
2526

2627

2728
def test_level_trend_multiple_observed_construction():
@@ -33,12 +34,22 @@ def test_level_trend_multiple_observed_construction():
3334
assert mod.k_states == 6
3435
assert mod.k_posdef == 3
3536

36-
assert mod.coords["trend_state"] == ["level", "trend"]
37-
assert mod.coords["trend_endog"] == ["data_1", "data_2", "data_3"]
37+
assert mod.coords["level_trend_state"] == ["level", "trend"]
38+
assert mod.coords["level_trend_endog"] == ["data_1", "data_2", "data_3"]
3839

39-
Z = mod.ssm["design"].eval()
40-
T = mod.ssm["transition"].eval()
41-
R = mod.ssm["selection"].eval()
40+
assert mod.state_names == [
41+
"level[data_1]",
42+
"trend[data_1]",
43+
"level[data_2]",
44+
"trend[data_2]",
45+
"level[data_3]",
46+
"trend[data_3]",
47+
]
48+
assert mod.shock_names == ["level_shock[data_1]", "level_shock[data_2]", "level_shock[data_3]"]
49+
50+
Z, T, R = pytensor.function(
51+
[], [mod.ssm["design"], mod.ssm["transition"], mod.ssm["selection"]], mode="FAST_COMPILE"
52+
)()
4253

4354
np.testing.assert_allclose(
4455
Z,
@@ -84,8 +95,64 @@ def test_level_trend_multiple_observed(rng):
8495
mod = st.LevelTrendComponent(
8596
order=2, innovations_order=0, observed_state_names=["data_1", "data_2", "data_3"]
8697
)
87-
params = {"initial_trend": np.array([[0.0, 1.0], [0.0, 2.0], [0.0, 3.0]])}
98+
params = {"level_trend_initial": np.array([[0.0, 1.0], [0.0, 2.0], [0.0, 3.0]])}
8899

89100
x, y = simulate_from_numpy_model(mod, rng, params)
90101
assert (np.diff(y, axis=0) == np.array([[1.0, 2.0, 3.0]])).all().all()
91102
assert (np.diff(x, axis=0) == np.array([[1.0, 0.0, 2.0, 0.0, 3.0, 0.0]])).all().all()
103+
104+
105+
def test_add_level_trend_with_different_observed():
106+
mod_1 = st.LevelTrendComponent(
107+
name="ll", order=2, innovations_order=[0, 1], observed_state_names=["data_1"]
108+
)
109+
mod_2 = st.LevelTrendComponent(
110+
name="grw", order=1, innovations_order=[1], observed_state_names=["data_2"]
111+
)
112+
113+
mod = (mod_1 + mod_2).build(verbose=False)
114+
assert mod.k_endog == 2
115+
assert mod.k_states == 3
116+
assert mod.k_posdef == 2
117+
118+
assert mod.coords["ll_state"] == ["level", "trend"]
119+
assert mod.coords["grw_state"] == ["level"]
120+
121+
assert mod.state_names == ["level[data_1]", "trend[data_1]", "level[data_2]"]
122+
assert mod.shock_names == ["trend_shock[data_1]", "level_shock[data_2]"]
123+
124+
Z, T, R = pytensor.function(
125+
[], [mod.ssm["design"], mod.ssm["transition"], mod.ssm["selection"]], mode="FAST_COMPILE"
126+
)()
127+
128+
np.testing.assert_allclose(
129+
Z,
130+
np.array(
131+
[
132+
[1.0, 0.0, 0.0],
133+
[0.0, 0.0, 1.0],
134+
]
135+
),
136+
)
137+
138+
np.testing.assert_allclose(
139+
T,
140+
np.array(
141+
[
142+
[1.0, 1.0, 0.0],
143+
[0.0, 1.0, 0.0],
144+
[0.0, 0.0, 1.0],
145+
]
146+
),
147+
)
148+
149+
np.testing.assert_allclose(
150+
R,
151+
np.array(
152+
[
153+
[0.0, 0.0],
154+
[1.0, 0.0],
155+
[0.0, 1.0],
156+
]
157+
),
158+
)

tests/statespace/models/structural/test_against_statsmodels.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def create_structural_model_and_equivalent_statsmodel(
220220

221221
if level:
222222
level_trend_order[0] = 1
223-
expected_coords["trend_state"] += [
223+
expected_coords["level_state"] += [
224224
"level",
225225
]
226226
expected_coords[ALL_STATE_DIM] += [
@@ -231,7 +231,7 @@ def create_structural_model_and_equivalent_statsmodel(
231231
]
232232
if stochastic_level:
233233
level_trend_innov_order[0] = 1
234-
expected_coords["trend_shock"] += ["level"]
234+
expected_coords["level_shock"] += ["level"]
235235
expected_coords[SHOCK_DIM] += [
236236
"level",
237237
]
@@ -241,7 +241,7 @@ def create_structural_model_and_equivalent_statsmodel(
241241

242242
if trend:
243243
level_trend_order[1] = 1
244-
expected_coords["trend_state"] += [
244+
expected_coords["level_state"] += [
245245
"trend",
246246
]
247247
expected_coords[ALL_STATE_DIM] += [
@@ -253,12 +253,12 @@ def create_structural_model_and_equivalent_statsmodel(
253253

254254
if stochastic_trend:
255255
level_trend_innov_order[1] = 1
256-
expected_coords["trend_shock"] += ["trend"]
256+
expected_coords["level_shock"] += ["trend"]
257257
expected_coords[SHOCK_DIM] += ["trend"]
258258
expected_coords[SHOCK_AUX_DIM] += ["trend"]
259259

260260
if level or trend:
261-
expected_param_dims["initial_trend"] += ("trend_state",)
261+
expected_param_dims["level_initial"] += ("level_state",)
262262
level_value = np.where(
263263
level_trend_order,
264264
rng.normal(
@@ -272,13 +272,13 @@ def create_structural_model_and_equivalent_statsmodel(
272272
max_order = np.flatnonzero(level_value)[-1].item() + 1
273273
level_trend_order = level_trend_order[:max_order]
274274

275-
params["initial_trend"] = level_value[:max_order]
275+
params["level_initial"] = level_value[:max_order]
276276
sm_init["level"] = level_value[0]
277277
sm_init["trend"] = level_value[1]
278278

279279
if sum(level_trend_innov_order) > 0:
280-
expected_param_dims["sigma_trend"] += ("trend_shock",)
281-
params["sigma_trend"] = np.sqrt(sigma_level_value2)
280+
expected_param_dims["level_sigma"] += ("level_shock",)
281+
params["level_sigma"] = np.sqrt(sigma_level_value2)
282282

283283
sigma_level_value = sigma_level_value2.tolist()
284284
if stochastic_level:

0 commit comments

Comments
 (0)