Skip to content

Improve xarray display of structural components for multivariate time series #555

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 45 additions & 27 deletions pymc_extras/statespace/models/structural/components/seasonality.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def __init__(
state_names: list | None = None,
remove_first_state: bool = True,
observed_state_names: list[str] | None = None,
share_states: bool = False,
):
if observed_state_names is None:
observed_state_names = ["data"]
Expand All @@ -261,6 +262,7 @@ def __init__(
)
state_names = state_names.copy()

self.share_states = share_states
self.innovations = innovations
self.duration = duration
self.remove_first_state = remove_first_state
Expand All @@ -281,44 +283,53 @@ def __init__(
super().__init__(
name=name,
k_endog=k_endog,
k_states=k_states * k_endog,
k_posdef=k_posdef * k_endog,
k_states=k_states if share_states else k_states * k_endog,
k_posdef=k_posdef if share_states else k_posdef * k_endog,
observed_state_names=observed_state_names,
measurement_error=False,
combine_hidden_states=True,
obs_state_idxs=np.tile(np.array([1.0] + [0.0] * (k_states - 1)), k_endog),
obs_state_idxs=np.tile(
np.array([1.0] + [0.0] * (k_states - 1)), 1 if share_states else k_endog
),
share_states=share_states,
)

def populate_component_properties(self):
k_states = self.k_states // self.k_endog
k_endog = self.k_endog
k_endog_effective = 1 if self.share_states else k_endog
k_states = self.k_states // k_endog_effective

self.state_names = [
f"{state_name}[{endog_name}]"
for endog_name in self.observed_state_names
for state_name in self.provided_state_names
]
if self.share_states:
self.state_names = [
f"{state_name}[{self.name}_shared]" for state_name in self.provided_state_names
]
else:
self.state_names = [
f"{state_name}[{endog_name}]"
for endog_name in self.observed_state_names
for state_name in self.provided_state_names
]
self.param_names = [f"params_{self.name}"]

self.param_info = {
f"params_{self.name}": {
"shape": (k_states,) if k_endog == 1 else (k_endog, k_states),
"shape": (k_states,) if k_endog_effective == 1 else (k_endog_effective, k_states),
"constraints": None,
"dims": (f"state_{self.name}",)
if k_endog == 1
if k_endog_effective == 1
else (f"endog_{self.name}", f"state_{self.name}"),
}
}

self.param_dims = {
f"params_{self.name}": (f"state_{self.name}",)
if k_endog == 1
if k_endog_effective == 1
else (f"endog_{self.name}", f"state_{self.name}")
}

self.coords = (
{f"state_{self.name}": self.provided_state_names}
if k_endog == 1
if k_endog_effective == 1
else {
f"endog_{self.name}": self.observed_state_names,
f"state_{self.name}": self.provided_state_names,
Expand All @@ -327,21 +338,26 @@ def populate_component_properties(self):

if self.innovations:
self.param_names += [f"sigma_{self.name}"]
self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]
self.param_info[f"sigma_{self.name}"] = {
"shape": () if k_endog == 1 else (k_endog,),
"shape": () if k_endog_effective == 1 else (k_endog_effective,),
"constraints": "Positive",
"dims": None if k_endog == 1 else (f"endog_{self.name}",),
"dims": None if k_endog_effective == 1 else (f"endog_{self.name}",),
}
if k_endog > 1:
if k_endog_effective > 1:
self.param_dims[f"sigma_{self.name}"] = (f"endog_{self.name}",)

if self.share_states:
self.shock_names = [f"{self.name}[shared]"]
else:
self.shock_names = [f"{self.name}[{name}]" for name in self.observed_state_names]

def make_symbolic_graph(self) -> None:
k_states = self.k_states // self.k_endog
k_endog = self.k_endog
k_endog_effective = 1 if self.share_states else k_endog
k_states = self.k_states // k_endog_effective
duration = self.duration
k_unique_states = k_states // duration
k_posdef = self.k_posdef // self.k_endog
k_endog = self.k_endog
k_posdef = self.k_posdef // k_endog_effective

if self.remove_first_state:
# In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
Expand Down Expand Up @@ -373,16 +389,18 @@ def make_symbolic_graph(self) -> None:
T = pt.eye(k_states, k=1)
T = pt.set_subtensor(T[-1, 0], 1)

self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog)])
self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog_effective)])

Z = pt.zeros((1, k_states))[0, 0].set(1)
self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog)])
self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog_effective)])

initial_states = self.make_and_register_variable(
f"params_{self.name}",
shape=(k_unique_states,) if k_endog == 1 else (k_endog, k_unique_states),
shape=(k_unique_states,)
if k_endog_effective == 1
else (k_endog_effective, k_unique_states),
)
if k_endog == 1:
if k_endog_effective == 1:
self.ssm["initial_state", :] = pt.extra_ops.repeat(initial_states, duration, axis=0)
else:
self.ssm["initial_state", :] = pt.extra_ops.repeat(
Expand All @@ -391,11 +409,11 @@ def make_symbolic_graph(self) -> None:

if self.innovations:
R = pt.zeros((k_states, k_posdef))[0, 0].set(1.0)
self.ssm["selection", :, :] = pt.join(0, *[R for _ in range(k_endog)])
self.ssm["selection", :, :] = pt.join(0, *[R for _ in range(k_endog_effective)])
season_sigma = self.make_and_register_variable(
f"sigma_{self.name}", shape=() if k_endog == 1 else (k_endog,)
f"sigma_{self.name}", shape=() if k_endog_effective == 1 else (k_endog_effective,)
)
cov_idx = ("state_cov", *np.diag_indices(k_posdef * k_endog))
cov_idx = ("state_cov", *np.diag_indices(k_posdef * k_endog_effective))
self.ssm[cov_idx] = season_sigma**2


Expand Down
56 changes: 47 additions & 9 deletions pymc_extras/statespace/models/structural/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
join_tensors_by_dim_labels,
make_default_coords,
)
from pymc_extras.statespace.utils.component_parsing import restructure_components_idata
from pymc_extras.statespace.utils.constants import (
ALL_STATE_AUX_DIM,
ALL_STATE_DIM,
Expand Down Expand Up @@ -208,7 +209,7 @@ def __init__(
self._component_info = component_info.copy()

self._name_to_variable = name_to_variable.copy()
self._name_to_data = name_to_data.copy()
self._name_to_data = name_to_data.copy() if name_to_data is not None else {}

self._exog_names = exog_names.copy()
self._needs_exog_data = len(exog_names) > 0
Expand Down Expand Up @@ -318,9 +319,18 @@ def _hidden_states_from_data(self, data):

if info[name]["combine_hidden_states"]:
sum_idx_joined = np.flatnonzero(obs_idx)
sum_idx_split = np.split(sum_idx_joined, info[name]["k_endog"])
for sum_idx in sum_idx_split:
result.append(X[..., sum_idx].sum(axis=-1)[..., None])
k_endog = info[name]["k_endog"]

if info[name]["share_states"]:
# sum once and replicate for each endogenous variable
shared_sum = X[..., sum_idx_joined].sum(axis=-1)[..., None]
for _ in range(k_endog):
result.append(shared_sum)
else:
# states are separate
sum_idx_split = np.split(sum_idx_joined, k_endog)
for sum_idx in sum_idx_split:
result.append(X[..., sum_idx].sum(axis=-1)[..., None])
else:
n_components = len(self.state_names[s])
for j in range(n_components):
Expand Down Expand Up @@ -350,20 +360,27 @@ def _get_subcomponent_names(self):
result.extend([f"{name}[{comp_name}]" for comp_name in comp_names])
return result

def extract_components_from_idata(self, idata: xr.Dataset) -> xr.Dataset:
def extract_components_from_idata(
self, idata: xr.Dataset, restructure: bool = False
) -> xr.Dataset:
r"""
Extract interpretable hidden states from an InferenceData returned by a PyMCStateSpace sampling method

Parameters
----------
idata: Dataset
idata : Dataset
A Dataset object, returned by a PyMCStateSpace sampling method
restructure : bool, default False
Whether to restructure the state coordinates as a multi-index for easier component selection.
When True, enables selections like `idata.sel(component='level')` and `idata.sel(observed='gdp')`.
Particularly useful for multivariate models with multiple observed states.

Returns
-------
idata: Dataset
idata : Dataset
A Dataset object with hidden states transformed to represent only the "interpretable" subcomponents
of the structural model.
of the structural model. If `restructure=True`, the state coordinate will be a multi-index with
levels ['component', 'observed'] for easier selection.

Notes
-----
Expand All @@ -383,9 +400,12 @@ def extract_components_from_idata(self, idata: xr.Dataset) -> xr.Dataset:
- :math:`\varepsilon_t` is the measurement error at time t

In state space form, some or all of these components are represented as linear combinations of other
subcomponents, making interpretation of the outputs of the outputs difficult. The purpose of this function is
subcomponents, making interpretation of the outputs difficult. The purpose of this function is
to take the expanded statespace representation and return a "reduced form" of only the components shown in
equation (1).

When `restructure=True`, the returned dataset allows for easy component selection, especially for
multivariate models with multiple observed states.
"""

def _extract_and_transform_variable(idata, new_state_names):
Expand Down Expand Up @@ -423,6 +443,17 @@ def _extract_and_transform_variable(idata, new_state_names):
for name in latent_names
}
)

if restructure:
try:
idata_new = restructure_components_idata(idata_new)
except Exception as e:
_log.warning(
f"Failed to restructure components with multi-index: {e}. "
"Returning dataset with original string-based state names. "
"You can call restructure_components_idata() manually if needed."
)

return idata_new


Expand Down Expand Up @@ -471,6 +502,10 @@ class Component:
obs_state_idxs : np.ndarray | None, optional
Indices indicating which states contribute to observed variables.
Default is None.
share_states : bool, optional
Whether states are shared across multiple endogenous variables in multivariate
models. When True, the same latent states affect all observed variables.
Default is False.

Examples
--------
Expand Down Expand Up @@ -512,10 +547,12 @@ def __init__(
combine_hidden_states=True,
component_from_sum=False,
obs_state_idxs=None,
share_states: bool = False,
):
self.name = name
self.k_endog = k_endog
self.k_states = k_states
self.share_states = share_states
self.k_posdef = k_posdef
self.measurement_error = measurement_error

Expand Down Expand Up @@ -557,6 +594,7 @@ def __init__(
"observed_state_names": self.observed_state_names,
"combine_hidden_states": combine_hidden_states,
"obs_state_idx": obs_state_idxs,
"share_states": self.share_states,
}
}

Expand Down
11 changes: 11 additions & 0 deletions pymc_extras/statespace/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Re-export the component-parsing helpers at package level so callers can
# import them from this package directly instead of reaching into the
# `component_parsing` submodule.
from .component_parsing import (
create_component_multiindex,
parse_component_state_name,
restructure_components_idata,
)

# Explicit public API: only these three names are exported by
# `from <this package> import *`.
__all__ = [
"create_component_multiindex",
"parse_component_state_name",
"restructure_components_idata",
]
Loading
Loading