Skip to content

Commit fbc61a1

Browse files
Delay dropping data names from states/coords until .build
1 parent 6debd23 commit fbc61a1

File tree

3 files changed

+31
-11
lines changed

3 files changed

+31
-11
lines changed

pymc_extras/statespace/models/structural/components/level_trend.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,7 @@ def populate_component_properties(self):
168168
self.param_names = [f"{self.name}_initial"]
169169
base_names = [name for name, mask in zip(name_slice, self._order_mask) if mask]
170170
self.state_names = [
171-
f"{name}[{obs_name}]" if k_endog > 1 else name
172-
for obs_name in self.observed_state_names
173-
for name in base_names
171+
f"{name}[{obs_name}]" for obs_name in self.observed_state_names for name in base_names
174172
]
175173
self.param_dims = {f"{self.name}_initial": (f"{self.name}_state",)}
176174
self.coords = {f"{self.name}_state": base_names}
@@ -193,7 +191,7 @@ def populate_component_properties(self):
193191
name for name, mask in zip(name_slice, self.innovations_order) if mask
194192
]
195193
self.shock_names = [
196-
f"{name}[{obs_name}]" if k_endog > 1 else name
194+
f"{name}[{obs_name}]"
197195
for obs_name in self.observed_state_names
198196
for name in shock_base_names
199197
]

pymc_extras/statespace/models/structural/core.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,19 @@ def __init__(
7676
param_names, param_dims, param_info = self._add_inital_state_cov_to_properties(
7777
param_names, param_dims, param_info, k_states
7878
)
79-
self._state_names = state_names.copy()
80-
self._data_names = data_names.copy()
81-
self._shock_names = shock_names.copy()
82-
self._param_names = param_names.copy()
83-
self._param_dims = param_dims.copy()
79+
80+
self._state_names = self._strip_data_names_if_unambiguous(state_names, k_endog)
81+
self._data_names = self._strip_data_names_if_unambiguous(data_names, k_endog)
82+
self._shock_names = self._strip_data_names_if_unambiguous(shock_names, k_endog)
83+
self._param_names = self._strip_data_names_if_unambiguous(param_names, k_endog)
84+
self._param_dims = param_dims
8485

8586
default_coords = make_default_coords(self)
8687
coords.update(default_coords)
8788

88-
self._coords = coords
89+
self._coords = {
90+
k: self._strip_data_names_if_unambiguous(v, k_endog) for k, v in coords.items()
91+
}
8992
self._param_info = param_info.copy()
9093
self._data_info = data_info.copy()
9194
self.measurement_error = measurement_error
@@ -122,6 +125,25 @@ def __init__(
122125
P0 = self.make_and_register_variable("P0", shape=(self.k_states, self.k_states))
123126
self.ssm["initial_state_cov"] = P0
124127

128+
def _strip_data_names_if_unambiguous(self, names: list[str], k_endog: int):
129+
"""
130+
State names from components should always be of the form name[data_name], in the case that the component is
131+
associated with multiple observed states. Not doing so leads to ambiguity -- we might have two level states,
132+
but which goes to which observed component? So we set `level[data_1]` and `level[data_2]`.
133+
134+
In cases where there is only one observed state (when k_endog == 1), we can strip the data part and just use
135+
the state name. This is a bit cleaner.
136+
"""
137+
if k_endog == 1:
138+
[data_name] = self.observed_states
139+
return [
140+
name.replace(f"[{data_name}]", "") if isinstance(name, str) else name
141+
for name in names
142+
]
143+
144+
else:
145+
return names
146+
125147
@staticmethod
126148
def _add_inital_state_cov_to_properties(param_names, param_dims, param_info, k_states):
127149
param_names += ["P0"]

tests/statespace/models/structural/components/test_level_trend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def test_add_level_trend_with_different_observed():
119119
assert mod.coords["grw_state"] == ["level"]
120120

121121
assert mod.state_names == ["level[data_1]", "trend[data_1]", "level[data_2]"]
122-
assert mod.shock_names == ["trend_shock[data_1]", "level_shock[data_2]"]
122+
assert mod.shock_names == ["trend[data_1]", "level[data_2]"]
123123

124124
Z, T, R = pytensor.function(
125125
[], [mod.ssm["design"], mod.ssm["transition"], mod.ssm["selection"]], mode="FAST_COMPILE"

0 commit comments

Comments
 (0)