Skip to content

Commit 4152d4a

Browse files
committed
Improve idata display of structural components
1 parent 24930b5 commit 4152d4a

File tree

4 files changed

+404
-6
lines changed

4 files changed

+404
-6
lines changed

pymc_extras/statespace/models/structural/core.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
join_tensors_by_dim_labels,
1919
make_default_coords,
2020
)
21+
from pymc_extras.statespace.utils.component_parsing import restructure_components_idata
2122
from pymc_extras.statespace.utils.constants import (
2223
ALL_STATE_AUX_DIM,
2324
ALL_STATE_DIM,
@@ -208,7 +209,7 @@ def __init__(
208209
self._component_info = component_info.copy()
209210

210211
self._name_to_variable = name_to_variable.copy()
211-
self._name_to_data = name_to_data.copy()
212+
self._name_to_data = name_to_data.copy() if name_to_data is not None else {}
212213

213214
self._exog_names = exog_names.copy()
214215
self._needs_exog_data = len(exog_names) > 0
@@ -350,20 +351,27 @@ def _get_subcomponent_names(self):
350351
result.extend([f"{name}[{comp_name}]" for comp_name in comp_names])
351352
return result
352353

353-
def extract_components_from_idata(self, idata: xr.Dataset) -> xr.Dataset:
354+
def extract_components_from_idata(
355+
self, idata: xr.Dataset, restructure: bool = False
356+
) -> xr.Dataset:
354357
r"""
355358
Extract interpretable hidden states from an InferenceData returned by a PyMCStateSpace sampling method
356359
357360
Parameters
358361
----------
359-
idata: Dataset
362+
idata : Dataset
360363
A Dataset object, returned by a PyMCStateSpace sampling method
364+
restructure : bool, default False
365+
Whether to restructure the state coordinates as a multi-index for easier component selection.
366+
When True, enables selections like `idata.sel(component='level')` and `idata.sel(observed='gdp')`.
367+
Particularly useful for multivariate models with multiple observed states.
361368
362369
Returns
363370
-------
364-
idata: Dataset
371+
idata : Dataset
365372
A Dataset object with hidden states transformed to represent only the "interpretable" subcomponents
366-
of the structural model.
373+
of the structural model. If `restructure=True`, the state coordinate will be a multi-index with
374+
levels ['component', 'observed'] for easier selection.
367375
368376
Notes
369377
-----
@@ -383,9 +391,12 @@ def extract_components_from_idata(self, idata: xr.Dataset) -> xr.Dataset:
383391
- :math:`\varepsilon_t` is the measurement error at time t
384392
385393
In state space form, some or all of these components are represented as linear combinations of other
386-
subcomponents, making interpretation of the outputs of the outputs difficult. The purpose of this function is
394+
subcomponents, making interpretation of the outputs difficult. The purpose of this function is
387395
to take the expanded statespace representation and return a "reduced form" of only the components shown in
388396
equation (1).
397+
398+
When `restructure=True`, the returned dataset allows for easy component selection, especially for
399+
multivariate models with multiple observed states.
389400
"""
390401

391402
def _extract_and_transform_variable(idata, new_state_names):
@@ -423,6 +434,17 @@ def _extract_and_transform_variable(idata, new_state_names):
423434
for name in latent_names
424435
}
425436
)
437+
438+
if restructure:
439+
try:
440+
idata_new = restructure_components_idata(idata_new)
441+
except Exception as e:
442+
_log.warning(
443+
f"Failed to restructure components with multi-index: {e}. "
444+
"Returning dataset with original string-based state names. "
445+
"You can call restructure_components_idata() manually if needed."
446+
)
447+
426448
return idata_new
427449

428450

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from .component_parsing import (
2+
create_component_multiindex,
3+
parse_component_state_name,
4+
restructure_components_idata,
5+
)
6+
7+
__all__ = [
8+
"create_component_multiindex",
9+
"parse_component_state_name",
10+
"restructure_components_idata",
11+
]
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
"""
2+
Parsing utilities for component state names in structural time series models.
3+
4+
This module provides functionality to parse complex state names like 'trend[level[observed_state]]'
5+
into structured multi-index coordinates that enable easy component and state selection.
6+
7+
NB: This is still a work in progress, and probably needs to be expanded to cover more complex cases.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
import re
13+
14+
from collections.abc import Sequence
15+
16+
import pandas as pd
17+
import xarray as xr
18+
19+
20+
def parse_component_state_name(state_name: str) -> tuple[str, str]:
    """
    Split a structural-component state name into ``(component, observed)``.

    Three shapes are recognised, tried in this order:

    1. ``outer[inner[observed]]`` -> ``(inner, observed)``. The outer name is
       dropped and only the inner state name is kept.
    2. ``component[observed]`` -> ``(component, observed)``.
    3. Anything else -> ``(state_name, "default")``.

    Parameters
    ----------
    state_name : str
        The state name to parse, e.g. ``'trend[level[observed_state]]'`` or
        ``'ar[observed_state]'``.

    Returns
    -------
    tuple[str, str]
        The component label followed by the observed-state label.

    Examples
    --------
    >>> parse_component_state_name('trend[level[chirac2]]')
    ('level', 'chirac2')
    >>> parse_component_state_name('ar[macron]')
    ('ar', 'macron')
    >>> parse_component_state_name('level')
    ('level', 'default')
    """
    # Two bracket levels: the first capture group (outer name) is discarded.
    two_level = re.match(r"^([^[]+)\[([^[]+)\[([^]]+)\]\]$", state_name)
    if two_level is not None:
        _outer, inner, observed = two_level.groups()
        return inner, observed

    # One bracket level: the name before the bracket is the component.
    one_level = re.match(r"^([^[]+)\[([^]]+)\]$", state_name)
    if one_level is not None:
        component, observed = one_level.groups()
        return component, observed

    # Unrecognised shape: keep the full string and use a placeholder label.
    return state_name, "default"
67+
68+
69+
def create_component_multiindex(
    state_names: Sequence[str], coord_name: str = "state"
) -> xr.Coordinates:
    """
    Build multi-index xarray coordinates from component state names.

    Every entry of ``state_names`` is passed through
    ``parse_component_state_name``. The resulting pairs become the rows of a
    two-level ``pandas.MultiIndex`` with levels ``'component'`` and
    ``'observed'``, which is then wrapped as xarray coordinates on the
    ``coord_name`` dimension.

    Parameters
    ----------
    state_names : Sequence[str]
        State names to parse, one multi-index row per name, in the given order.
    coord_name : str, default "state"
        Name of the dimension that will carry the multi-index.

    Returns
    -------
    xr.Coordinates
        Coordinates built with ``xr.Coordinates.from_pandas_multiindex``.

    Notes
    -----
    For example, ``['trend[level[y]]', 'trend[trend[y]]', 'ar[y]']`` is parsed
    into the rows ``('level', 'y')``, ``('trend', 'y')`` and ``('ar', 'y')``.
    """
    parsed_pairs = list(map(parse_component_state_name, state_names))
    multi_index = pd.MultiIndex.from_tuples(
        parsed_pairs, names=["component", "observed"]
    )
    return xr.Coordinates.from_pandas_multiindex(multi_index, dim=coord_name)
100+
101+
102+
def restructure_components_idata(idata: xr.Dataset) -> xr.Dataset:
    """
    Replace the ``state`` coordinate of a dataset with a multi-index.

    The values of the ``state`` coordinate are handed to
    ``create_component_multiindex`` and the resulting coordinates are assigned
    back onto the dataset.

    Parameters
    ----------
    idata : xr.Dataset
        Dataset whose ``state`` coordinate holds component state names.

    Returns
    -------
    xr.Dataset
        The result of ``idata.assign_coords`` with the multi-index coordinates.

    Raises
    ------
    ValueError
        If ``idata`` has no coordinate named ``state``.

    Examples
    --------
    >>> # After calling extract_components_from_idata from core.py
    >>> restructured = restructure_components_idata(components_idata)
    >>> level_data = restructured.sel(component='level')
    >>> gdp_data = restructured.sel(observed='gdp')
    >>> level_gdp = restructured.sel(component='level', observed='gdp')
    """
    # The coordinate name is fixed here; custom names are not supported.
    state_coord_name = "state"

    if state_coord_name in idata.coords:
        raw_names = idata.coords[state_coord_name].values
        return idata.assign_coords(
            create_component_multiindex(raw_names, state_coord_name)
        )

    raise ValueError(f"Coordinate '{state_coord_name}' not found in dataset")

0 commit comments

Comments
 (0)