diff --git a/pymc_extras/statespace/models/structural/components/seasonality.py b/pymc_extras/statespace/models/structural/components/seasonality.py
index 05e7d41a..e0d9b4ff 100644
--- a/pymc_extras/statespace/models/structural/components/seasonality.py
+++ b/pymc_extras/statespace/models/structural/components/seasonality.py
@@ -235,6 +235,7 @@ def __init__(
         state_names: list | None = None,
         remove_first_state: bool = True,
         observed_state_names: list[str] | None = None,
+        share_states: bool = False,
     ):
         if observed_state_names is None:
             observed_state_names = ["data"]
@@ -261,6 +262,7 @@ def __init__(
         )
 
         state_names = state_names.copy()
+        self.share_states = share_states
         self.innovations = innovations
         self.duration = duration
         self.remove_first_state = remove_first_state
@@ -281,44 +283,53 @@ def __init__(
         super().__init__(
             name=name,
             k_endog=k_endog,
-            k_states=k_states * k_endog,
-            k_posdef=k_posdef * k_endog,
+            k_states=k_states if share_states else k_states * k_endog,
+            k_posdef=k_posdef if share_states else k_posdef * k_endog,
             observed_state_names=observed_state_names,
             measurement_error=False,
             combine_hidden_states=True,
-            obs_state_idxs=np.tile(np.array([1.0] + [0.0] * (k_states - 1)), k_endog),
+            obs_state_idxs=np.tile(
+                np.array([1.0] + [0.0] * (k_states - 1)), 1 if share_states else k_endog
+            ),
+            share_states=share_states,
         )
 
     def populate_component_properties(self):
-        k_states = self.k_states // self.k_endog
         k_endog = self.k_endog
+        k_endog_effective = 1 if self.share_states else k_endog
+        k_states = self.k_states // k_endog_effective
 
-        self.state_names = [
-            f"{state_name}[{endog_name}]"
-            for endog_name in self.observed_state_names
-            for state_name in self.provided_state_names
-        ]
+        if self.share_states:
+            self.state_names = [
+                f"{state_name}[{self.name}_shared]" for state_name in self.provided_state_names
+            ]
+        else:
+            self.state_names = [
+                f"{state_name}[{endog_name}]"
+                for endog_name in self.observed_state_names
+                for state_name in self.provided_state_names
+            ]
 
         self.param_names = [f"params_{self.name}"]
 
         self.param_info = {
             f"params_{self.name}": {
-                "shape": (k_states,) if k_endog == 1 else (k_endog, k_states),
+                "shape": (k_states,) if k_endog_effective == 1 else (k_endog_effective, k_states),
                 "constraints": None,
                 "dims": (f"state_{self.name}",)
-                if k_endog == 1
+                if k_endog_effective == 1
                 else (f"endog_{self.name}", f"state_{self.name}"),
             }
         }
         self.param_dims = {
             f"params_{self.name}": (f"state_{self.name}",)
-            if k_endog == 1
+            if k_endog_effective == 1
            else (f"endog_{self.name}", f"state_{self.name}")
         }
         self.coords = (
             {f"state_{self.name}": self.provided_state_names}
-            if k_endog == 1
+            if k_endog_effective == 1
             else {
                 f"endog_{self.name}": self.observed_state_names,
                 f"state_{self.name}": self.provided_state_names,
@@ -327,21 +338,26 @@ def populate_component_properties(self):
 
         if self.innovations:
             self.param_names += [f"sigma_{self.name}"]
-            self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]
             self.param_info[f"sigma_{self.name}"] = {
-                "shape": () if k_endog == 1 else (k_endog,),
+                "shape": () if k_endog_effective == 1 else (k_endog_effective,),
                 "constraints": "Positive",
-                "dims": None if k_endog == 1 else (f"endog_{self.name}",),
+                "dims": None if k_endog_effective == 1 else (f"endog_{self.name}",),
             }
-            if k_endog > 1:
+            if k_endog_effective > 1:
                 self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)
 
+            if self.share_states:
+                self.shock_names = [f"{self.name}[shared]"]
+            else:
+                self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]
+
     def make_symbolic_graph(self) -> None:
-        k_states = self.k_states // self.k_endog
+        k_endog = self.k_endog
+        k_endog_effective = 1 if self.share_states else k_endog
+        k_states = self.k_states // k_endog_effective
         duration = self.duration
         k_unique_states = k_states // duration
-        k_posdef = self.k_posdef // self.k_endog
-        k_endog = self.k_endog
+        k_posdef = self.k_posdef // k_endog_effective
 
         if self.remove_first_state:
             # In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
@@ -373,16 +389,18 @@ def make_symbolic_graph(self) -> None:
             T = pt.eye(k_states, k=1)
             T = pt.set_subtensor(T[-1, 0], 1)
 
-        self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog)])
+        self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog_effective)])
 
         Z = pt.zeros((1, k_states))[0, 0].set(1)
-        self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)])
+        self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog_effective)])
 
         initial_states = self.make_and_register_variable(
             f"params_{self.name}",
-            shape=(k_unique_states,) if k_endog == 1 else (k_endog, k_unique_states),
+            shape=(k_unique_states,)
+            if k_endog_effective == 1
+            else (k_endog_effective, k_unique_states),
         )
-        if k_endog == 1:
+        if k_endog_effective == 1:
             self.ssm["initial_state", :] = pt.extra_ops.repeat(initial_states, duration, axis=0)
         else:
             self.ssm["initial_state", :] = pt.extra_ops.repeat(
@@ -391,11 +409,11 @@ def make_symbolic_graph(self) -> None:
 
         if self.innovations:
             R = pt.zeros((k_states, k_posdef))[0, 0].set(1.0)
-            self.ssm["selection", :, :] = pt.join(0, *[R for _ in range(k_endog)])
+            self.ssm["selection", :, :] = pt.join(0, *[R for _ in range(k_endog_effective)])
             season_sigma = self.make_and_register_variable(
-                f"sigma_{self.name}", shape=() if k_endog == 1 else (k_endog,)
+                f"sigma_{self.name}", shape=() if k_endog_effective == 1 else (k_endog_effective,)
             )
-            cov_idx = ("state_cov", *np.diag_indices(k_posdef * k_endog))
+            cov_idx = ("state_cov", *np.diag_indices(k_posdef * k_endog_effective))
             self.ssm[cov_idx] = season_sigma**2
diff --git a/pymc_extras/statespace/models/structural/core.py b/pymc_extras/statespace/models/structural/core.py
index 95bc957c..4c82c5d6 100644
--- a/pymc_extras/statespace/models/structural/core.py
+++ b/pymc_extras/statespace/models/structural/core.py
@@ -18,6 +18,7 @@
     join_tensors_by_dim_labels,
     make_default_coords,
 )
+from pymc_extras.statespace.utils.component_parsing import restructure_components_idata
 from pymc_extras.statespace.utils.constants import (
     ALL_STATE_AUX_DIM,
     ALL_STATE_DIM,
@@ -208,7 +209,7 @@ def __init__(
 
         self._component_info = component_info.copy()
         self._name_to_variable = name_to_variable.copy()
-        self._name_to_data = name_to_data.copy()
+        self._name_to_data = name_to_data.copy() if name_to_data is not None else {}
 
         self._exog_names = exog_names.copy()
         self._needs_exog_data = len(exog_names) > 0
@@ -318,9 +319,18 @@ def _hidden_states_from_data(self, data):
             if info[name]["combine_hidden_states"]:
                 sum_idx_joined = np.flatnonzero(obs_idx)
-                sum_idx_split = np.split(sum_idx_joined, info[name]["k_endog"])
-                for sum_idx in sum_idx_split:
-                    result.append(X[..., sum_idx].sum(axis=-1)[..., None])
+                k_endog = info[name]["k_endog"]
+
+                if info[name]["share_states"]:
+                    # sum once and replicate for each endogenous variable
+                    shared_sum = X[..., sum_idx_joined].sum(axis=-1)[..., None]
+                    for _ in range(k_endog):
+                        result.append(shared_sum)
+                else:
+                    # states are separate
+                    sum_idx_split = np.split(sum_idx_joined, k_endog)
+                    for sum_idx in sum_idx_split:
+                        result.append(X[..., sum_idx].sum(axis=-1)[..., None])
             else:
                 n_components = len(self.state_names[s])
                 for j in range(n_components):
@@ -350,20 +360,27 @@ def _get_subcomponent_names(self):
             result.extend([f"{name}[{comp_name}]" for comp_name in comp_names])
         return result
 
-    def extract_components_from_idata(self, idata: xr.Dataset) -> xr.Dataset:
+    def extract_components_from_idata(
+        self, idata: xr.Dataset, restructure: bool = False
+    ) -> xr.Dataset:
         r"""
         Extract interpretable hidden states from an InferenceData returned by a PyMCStateSpace sampling method
 
         Parameters
         ----------
-        idata: Dataset
+        idata : Dataset
            A Dataset object, returned by a PyMCStateSpace sampling method
+        restructure : bool, default False
+            Whether to restructure the state coordinates as a multi-index for easier component selection.
+            When True, enables selections like `idata.sel(component='level')` and `idata.sel(observed='gdp')`.
+            Particularly useful for multivariate models with multiple observed states.
 
         Returns
         -------
-        idata: Dataset
+        idata : Dataset
            A Dataset object with hidden states transformed to represent only the "interpretable" subcomponents
-            of the structural model.
+            of the structural model. If `restructure=True`, the state coordinate will be a multi-index with
+            levels ['component', 'observed'] for easier selection.
 
         Notes
         -----
@@ -383,9 +400,12 @@ def extract_components_from_idata(self, idata: xr.Dataset) -> xr.Dataset:
         - :math:`\varepsilon_t` is the measurement error at time t
 
         In state space form, some or all of these components are represented as linear combinations of other
-        subcomponents, making interpretation of the outputs of the outputs difficult. The purpose of this function is
+        subcomponents, making interpretation of the outputs difficult. The purpose of this function is
         to take the expended statespace representation and return a "reduced form" of only the components shown in
         equation (1).
+
+        When `restructure=True`, the returned dataset allows for easy component selection, especially for
+        multivariate models with multiple observed states.
         """
 
         def _extract_and_transform_variable(idata, new_state_names):
@@ -423,6 +443,17 @@ def _extract_and_transform_variable(idata, new_state_names):
                 for name in latent_names
             }
         )
+
+        if restructure:
+            try:
+                idata_new = restructure_components_idata(idata_new)
+            except Exception as e:
+                _log.warning(
+                    f"Failed to restructure components with multi-index: {e}. "
+                    "Returning dataset with original string-based state names. "
+                    "You can call restructure_components_idata() manually if needed."
+                )
+
         return idata_new
 
@@ -471,6 +502,10 @@ class Component:
    obs_state_idxs : np.ndarray | None, optional
        Indices indicating which states contribute to observed variables.
        If None, defaults to None.
+    share_states : bool, optional
+        Whether states are shared across multiple endogenous variables in multivariate
+        models. When True, the same latent states affect all observed variables.
+        Default is False.
 
     Examples
     --------
@@ -512,10 +547,12 @@ def __init__(
         combine_hidden_states=True,
         component_from_sum=False,
         obs_state_idxs=None,
+        share_states: bool = False,
     ):
         self.name = name
         self.k_endog = k_endog
         self.k_states = k_states
+        self.share_states = share_states
         self.k_posdef = k_posdef
 
         self.measurement_error = measurement_error
@@ -557,6 +594,7 @@ def __init__(
                 "observed_state_names": self.observed_state_names,
                 "combine_hidden_states": combine_hidden_states,
                 "obs_state_idx": obs_state_idxs,
+                "share_states": self.share_states,
             }
         }
diff --git a/pymc_extras/statespace/utils/__init__.py b/pymc_extras/statespace/utils/__init__.py
index e69de29b..0686ace6 100644
--- a/pymc_extras/statespace/utils/__init__.py
+++ b/pymc_extras/statespace/utils/__init__.py
@@ -0,0 +1,11 @@
+from .component_parsing import (
+    create_component_multiindex,
+    parse_component_state_name,
+    restructure_components_idata,
+)
+
+__all__ = [
+    "create_component_multiindex",
+    "parse_component_state_name",
+    "restructure_components_idata",
+]
diff --git a/pymc_extras/statespace/utils/component_parsing.py b/pymc_extras/statespace/utils/component_parsing.py
new file mode 100644
index 00000000..39a0b221
--- /dev/null
+++ b/pymc_extras/statespace/utils/component_parsing.py
@@ -0,0 +1,135 @@
+"""
+Parsing utilities for component state names in structural time series models.
+
+This module provides functionality to parse complex state names like 'trend[level[observed_state]]'
+into structured multi-index coordinates that enable easy component and state selection.
+
+NB: This is still a work in progress, and probably needs to be expanded to more complex cases.
+"""
+
+from __future__ import annotations
+
+import re
+
+from collections.abc import Sequence
+
+import pandas as pd
+import xarray as xr
+
+
+def parse_component_state_name(state_name: str) -> tuple[str, str]:
+    """
+    Parse a component state name into its constituent parts.
+
+    Extracts the actual interpretable state name and observed state from
+    various component naming formats.
+
+    Parameters
+    ----------
+    state_name : str
+        The state name to parse, e.g., 'trend[level[observed_state]]' or 'ar[observed_state]'
+
+    Returns
+    -------
+    tuple[str, str]
+        A tuple of (component, observed) where component is the interpretable component name
+        and observed is the observed state name
+
+    Examples
+    --------
+    >>> parse_component_state_name('trend[level[chirac2]]')
+    ('level', 'chirac2')
+    >>> parse_component_state_name('ar[macron]')
+    ('ar', 'macron')
+    """
+    # Handle the nested bracket pattern: component[state[observed]]
+    # For these, we want the inner state name (level, trend, etc.)
+    # because the first level is redundant with the component name
+    nested_pattern = r"^([^[]+)\[([^[]+)\[([^]]+)\]\]$"
+    nested_match = re.match(nested_pattern, state_name)
+
+    if nested_match:
+        # Return the inner state name and observed state
+        return nested_match.group(2), nested_match.group(3)
+
+    # Handle the simple bracket pattern: component[observed]
+    # For these, we want the component name directly
+    simple_pattern = r"^([^[]+)\[([^]]+)\]$"
+    simple_match = re.match(simple_pattern, state_name)
+
+    if simple_match:
+        # Return the component name and observed state
+        return simple_match.group(1), simple_match.group(2)
+
+    # If no pattern matches, treat the whole string as a state name
+    # This is a fallback for edge cases
+    return state_name, "default"
+
+
+def create_component_multiindex(
+    state_names: Sequence[str], coord_name: str = "state"
+) -> xr.Coordinates:
+    """
+    Create xarray coordinates with multi-index from component state names.
+
+    Parameters
+    ----------
+    state_names : Sequence[str]
+        List of state names to parse into multi-index
+    coord_name : str, default "state"
+        Name for the coordinate dimension to transform into a multi-index
+
+    Returns
+    -------
+    xr.Coordinates
+        xarray coordinates with multi-index structure
+
+    Examples
+    --------
+    >>> state_names = ['trend[level[observed_state]]', 'trend[trend[observed_state]]', 'ar[observed_state]']
+    >>> coords = create_component_multiindex(state_names)
+    >>> coords.to_index().names
+    ['component', 'observed']
+    >>> coords.to_index().values
+    [('level', 'observed_state'), ('trend', 'observed_state'), ('ar', 'observed_state')]
+    """
+    tuples = [parse_component_state_name(name) for name in state_names]
+    midx = pd.MultiIndex.from_tuples(tuples, names=["component", "observed"])
+
+    return xr.Coordinates.from_pandas_multiindex(midx, dim=coord_name)
+
+
+def restructure_components_idata(idata: xr.Dataset) -> xr.Dataset:
+    """
+    Restructure idata with multi-index coordinates for easier component selection.
+
+    Parameters
+    ----------
+    idata : xr.Dataset
+        Dataset with component state names as coordinates
+
+    Returns
+    -------
+    xr.Dataset
+        Dataset with restructured multi-index coordinates
+
+    Examples
+    --------
+    >>> # After calling extract_components_from_idata from core.py
+    >>> restructured = restructure_components_idata(components_idata)
+    >>> # Now you can select by component or observed state
+    >>> level_data = restructured.sel(component='level')  # All level components
+    >>> gdp_data = restructured.sel(observed='gdp')  # All gdp data
+    >>> level_gdp = restructured.sel(component='level', observed='gdp')  # Specific combination
+    """
+    # name of the coordinate containing state names
+    # should be `state`, by default, as users don't access it directly
+    # would need to be updated if we want to support custom names
+    state_coord_name = "state"
+    if state_coord_name not in idata.coords:
+        raise ValueError(f"Coordinate '{state_coord_name}' not found in dataset")
+
+    state_names = idata.coords[state_coord_name].values
+    mindex_coords = create_component_multiindex(state_names, state_coord_name)
+
+    return idata.assign_coords(mindex_coords)
diff --git a/tests/statespace/utils/test_component_parsing.py b/tests/statespace/utils/test_component_parsing.py
new file mode 100644
index 00000000..7ab6d381
--- /dev/null
+++ b/tests/statespace/utils/test_component_parsing.py
@@ -0,0 +1,230 @@
+"""
+Tests for component state name parsing utilities.
+""" + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from pymc_extras.statespace.utils.component_parsing import ( + create_component_multiindex, + parse_component_state_name, + restructure_components_idata, +) + + +class TestParseComponentStateName: + """Test the core parsing function for component state names.""" + + def test_nested_pattern(self): + """Test parsing of nested bracket patterns like 'component[state[observed]]'.""" + result = parse_component_state_name("trend[level[chirac2]]") + assert result == ("level", "chirac2") + + result = parse_component_state_name("seasonal[coef_0[macron]]") + assert result == ("coef_0", "macron") + + def test_simple_pattern(self): + """Test parsing of simple bracket patterns like 'component[observed]'.""" + result = parse_component_state_name("ar[macron]") + assert result == ("ar", "macron") + + result = parse_component_state_name("measurement_error[hollande]") + assert result == ("measurement_error", "hollande") + + def test_complex_component_names(self): + """Test parsing with complex component names that might include special characters.""" + # Test with underscores + result = parse_component_state_name("level_trend[level[data_1]]") + assert result == ("level", "data_1") + + # Test with numbers + result = parse_component_state_name("ar2[lag1[series_1]]") + assert result == ("lag1", "series_1") + + def test_complex_observed_names(self): + """Test parsing with complex observed variable names.""" + result = parse_component_state_name("trend[level[data_var_1]]") + assert result == ("level", "data_var_1") + + result = parse_component_state_name("seasonal[coef[obs_2024_Q1]]") + assert result == ("coef", "obs_2024_Q1") + + def test_fallback_pattern(self): + """Test the fallback behavior for unusual patterns.""" + result = parse_component_state_name("simple_state_name") + assert result == ("simple_state_name", "default") + + result = parse_component_state_name("no_brackets") + assert result == ("no_brackets", "default") + + def test_edge_cases(self): + """Test edge cases and malformed inputs.""" + # Empty brackets - should fall back to treating as simple component name + result = parse_component_state_name("component[]") + assert result == ("component[]", "default") + + # Mismatched brackets - should be parsed as simple pattern component[state[observed] + result = parse_component_state_name("component[state[observed]") + assert result == ("component", "state[observed") + + +class TestCreateComponentMultiindex: + """Test the multi-index creation functionality.""" + + def test_basic_multiindex_creation(self): + """Test creating a multi-index from basic state names.""" + state_names = [ + "trend[level[chirac2]]", + "trend[trend[chirac2]]", + "ar[chirac2]", + "trend[level[macron]]", + "ar[macron]", + ] + + coords = create_component_multiindex(state_names) + + index = coords.to_index() + assert isinstance(index, pd.MultiIndex) + assert index.names == ["component", "observed"] + assert "state" in coords.dims + + def test_custom_coord_name(self): + """Test creating multi-index with custom coordinate name.""" + state_names = ["trend[level[data]]", "ar[data]"] + coords = create_component_multiindex(state_names, coord_name="custom_state") + + assert "custom_state" in coords.dims + + def test_mixed_patterns(self): + """Test with a mix of nested and simple patterns.""" + state_names = [ + "trend[level[obs1]]", + "ar[obs1]", + "seasonal[coef_1[obs2]]", + "measurement_error[obs2]", + ] + + coords = 
create_component_multiindex(state_names) + index = coords.to_index() + + # check we get right structure + expected_tuples = [ + ("level", "obs1"), + ("ar", "obs1"), + ("coef_1", "obs2"), + ("measurement_error", "obs2"), + ] + + for i, expected in enumerate(expected_tuples): + assert index[i] == expected + + def test_empty_input(self): + """Test with empty state names list.""" + coords = create_component_multiindex([]) + index = coords.to_index() + assert len(index) == 0 + assert index.names == ["component", "observed"] + + +class TestRestructureComponentsIdata: + """Test the idata restructuring functionality.""" + + @staticmethod + def create_sample_idata(state_names): + n_chains, n_draws, n_time, n_states = 2, 100, 50, len(state_names) + data = np.random.normal(size=(n_chains, n_draws, n_time, n_states)) + return xr.Dataset( + { + "filtered_latent": xr.DataArray( + data, + dims=["chain", "draw", "time", "state"], + coords={ + "chain": range(n_chains), + "draw": range(n_draws), + "time": range(n_time), + "state": state_names, + }, + ) + } + ) + + def test_basic_restructuring(self): + state_names = [ + "trend[level[chirac2]]", + "trend[trend[chirac2]]", + "ar[chirac2]", + "trend[level[macron]]", + "ar[macron]", + ] + + idata = self.create_sample_idata(state_names) + restructured = restructure_components_idata(idata) + + state_index = restructured.coords["state"].to_index() + assert isinstance(state_index, pd.MultiIndex) + assert state_index.names == ["component", "observed"] + + # check we can select by component + level_data = restructured.sel(component="level") + assert "level" in restructured.coords["component"].values + assert ( + "ar" in restructured.coords["component"].values + ) # ar should be here from ar[chirac2] and ar[macron] + + # check we can select by observed state + chirac_data = restructured.sel(observed="chirac2") + assert "chirac2" in restructured.coords["observed"].values + + def test_missing_coordinate_error(self): + """Test error handling when coordinate doesn't exist.""" + idata = xr.Dataset({"data": xr.DataArray([1, 2, 3], dims=["x"], coords={"x": [1, 2, 3]})}) + + with pytest.raises(ValueError, match="Coordinate 'state' not found"): + restructure_components_idata(idata) + + +class TestIntegrationScenarios: + def test_real_world_example(self): + state_names = [ + "trend[level[chirac2]]", + "trend[trend[chirac2]]", + "trend[level[sarkozy]]", + "trend[trend[sarkozy]]", + "trend[level[hollande]]", + "trend[trend[hollande]]", + "trend[level[macron]]", + "trend[trend[macron]]", + "trend[level[macron2]]", + "trend[trend[macron2]]", + "ar[chirac2]", + "ar[sarkozy]", + "ar[hollande]", + "ar[macron]", + "ar[macron2]", + ] + n_chains, n_draws, n_time = 4, 500, 100 + data = np.random.normal(size=(n_chains, n_draws, n_time, len(state_names))) + idata = xr.Dataset( + { + "filtered_latent": xr.DataArray( + data, dims=["chain", "draw", "time", "state"], coords={"state": state_names} + ) + } + ) + restructured = restructure_components_idata(idata) + + macron_data = restructured.sel(observed="macron") + assert macron_data.filtered_latent.shape == ( + 4, + 500, + 100, + 3, + ) # 3 because trend level, trend trend, ar + + ar_data = restructured.sel(component="ar") + assert ar_data.filtered_latent.shape == (4, 500, 100, 5) # 5 observed states + + level_macron = restructured.sel(component="level", observed="macron") + assert level_macron.filtered_latent.shape == (4, 500, 100) # single level component
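
Usage sketch (not part of the patch; mirrors the tests above). The "gdp" series name and the toy array shape are illustrative only; in practice the dataset would come from `extract_components_from_idata`, or the multi-index would be applied automatically by passing `restructure=True`.

import numpy as np
import xarray as xr

from pymc_extras.statespace.utils import restructure_components_idata

# Toy stand-in for the output of extract_components_from_idata: the "state"
# coordinate holds flat names such as "trend[level[gdp]]" and "ar[gdp]".
state_names = ["trend[level[gdp]]", "trend[trend[gdp]]", "ar[gdp]"]
data = np.random.normal(size=(2, 10, 5, len(state_names)))  # (chain, draw, time, state)
components = xr.Dataset(
    {
        "filtered_latent": xr.DataArray(
            data,
            dims=["chain", "draw", "time", "state"],
            coords={"state": state_names},
        )
    }
)

restructured = restructure_components_idata(components)

# The "state" coordinate is now a ("component", "observed") multi-index,
# so individual pieces can be selected directly:
level_gdp = restructured.sel(component="level", observed="gdp")
ar_gdp = restructured.sel(component="ar", observed="gdp")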