Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
583 changes: 583 additions & 0 deletions notebooks/structural_components_dataclass.ipynb

Large diffs are not rendered by default.

241 changes: 241 additions & 0 deletions pymc_extras/statespace/core/properties.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
from __future__ import annotations

import warnings

from collections.abc import Iterator
from copy import deepcopy
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, Generic, Self, TypeVar

from pymc_extras.statespace.core import PyMCStateSpace
from pymc_extras.statespace.utils.constants import (
    ALL_STATE_AUX_DIM,
    ALL_STATE_DIM,
    OBS_STATE_AUX_DIM,
    OBS_STATE_DIM,
    SHOCK_AUX_DIM,
    SHOCK_DIM,
)

if TYPE_CHECKING:
    from pymc_extras.statespace.models.structural.core import Component


@dataclass(frozen=True)
class Property:
    """Frozen dataclass base that supplies a readable ``str()`` for subclasses.

    The class declares no fields itself; dataclass subclasses add them.
    """

    def __str__(self) -> str:
        """Return one ``field_name: value`` line per dataclass field, in declaration order."""
        rendered = []
        for spec in fields(self):
            value = getattr(self, spec.name)
            rendered.append(f"{spec.name}: {value}")
        return "\n".join(rendered)


# Type variable for the element type held by the generic ``Info`` container;
# bounded so that only ``Property`` subclasses are accepted by type checkers.
T = TypeVar("T", bound=Property)


@dataclass(frozen=True)
class Info(Generic[T]):
    """Immutable, ordered container of items that can also be looked up by key.

    The key of an item is the value of the attribute named by ``key_field``.
    Iteration and ``len()`` operate on ``items`` (duplicates included), whereas
    key-based access (``[]``, ``get``, ``in``, ``names``) goes through an index
    that keeps a single item per key.

    Parameters
    ----------
    items
        Items to store. Any iterable is accepted; it is materialised into a
        tuple so that the instance stays hashable and comparable.
    key_field
        Name of the attribute used as lookup key. Defaults to ``"name"``.
    _index
        Internal key -> item mapping. Whatever is passed is overwritten in
        ``__post_init__``; the parameter only exists for backward compatibility.

    Raises
    ------
    AttributeError
        If one or more items lack the ``key_field`` attribute.
    """

    items: tuple[T, ...]
    key_field: str = "name"
    # Derived lookup table. It is a dict (unhashable) and fully determined by
    # ``items``/``key_field``, so it is kept out of ``__eq__``, the generated
    # ``__hash__`` and ``__repr__``.
    _index: dict[str, T] | None = field(default=None, repr=False, compare=False)

    def __post_init__(self):
        """Validate the items and build the key -> item index."""
        # Materialise once: lists and generators are accepted, and the stored
        # value is always a tuple.
        items = tuple(self.items)

        missing_attr = [item for item in items if not hasattr(item, self.key_field)]
        if missing_attr:
            raise AttributeError(f"Items missing attribute '{self.key_field}': {missing_attr}")

        # Duplicate keys are deliberately tolerated (shared states rely on it);
        # when a key repeats, the last item wins in the index.
        index = {getattr(item, self.key_field): item for item in items}

        # The dataclass is frozen, so attributes must be set through ``object``.
        object.__setattr__(self, "items", items)
        object.__setattr__(self, "_index", index)

    def _key(self, item: T) -> str:
        """Return the lookup key of ``item``."""
        return getattr(item, self.key_field)

    def get(self, key: str, default=None) -> T | None:
        """Return the item stored under ``key``, or ``default`` if there is none."""
        return self._index.get(key, default)

    def __getitem__(self, key: str) -> T:
        """Return the item stored under ``key``.

        Raises
        ------
        KeyError
            If ``key`` is unknown; the message lists the available keys.
        """
        try:
            return self._index[key]
        except KeyError as e:
            available = ", ".join(self._index.keys())
            raise KeyError(f"No {self.key_field} '{key}'. Available: [{available}]") from e

    def __contains__(self, key: object) -> bool:
        """Return whether an item is stored under ``key``."""
        return key in self._index

    def __iter__(self) -> Iterator[T]:
        """Iterate over the items in insertion order (duplicates included)."""
        return iter(self.items)

    def __len__(self) -> int:
        """Return the number of stored items (duplicates included)."""
        return len(self.items)

    def __str__(self) -> str:
        return f"{self.key_field}s: {list(self._index.keys())}"

    def add(self, new_item: T) -> Self:
        """Return a new container of the same type with ``new_item`` appended."""
        return type(self)([*self.items, new_item])

    def merge(self, other: Self, allow_duplicates: bool = False) -> Self:
        """Return a new container holding the items of ``self`` followed by ``other``.

        Parameters
        ----------
        other
            Container to merge in; must be an instance of ``type(self)``.
        allow_duplicates
            If False (default), keys present in both containers are an error.

        Raises
        ------
        TypeError
            If ``other`` is not an instance of ``type(self)``.
        ValueError
            If keys overlap and ``allow_duplicates`` is False.
        """
        if not isinstance(other, type(self)):
            raise TypeError(f"Cannot merge {type(other).__name__} with {type(self).__name__}")

        overlapping = set(self.names) & set(other.names)
        if overlapping and not allow_duplicates:
            raise ValueError(f"Duplicate names found: {overlapping}")

        return type(self)(list(self.items) + list(other.items))

    @property
    def names(self) -> tuple[str, ...]:
        """Unique keys, in first-insertion order."""
        return tuple(self._index.keys())

    def copy(self) -> Self:
        """Return a deep copy of the container."""
        return deepcopy(self)


@dataclass(frozen=True)
class Parameter(Property):
    """Immutable record describing a single parameter.

    Attributes
    ----------
    name
        Parameter name; used as lookup key by ``ParameterInfo``.
    shape
        Tuple of integers giving the parameter shape.
    dims
        Tuple of dimension names, presumably one per entry of ``shape``
        (TODO confirm).
    constraints
        Optional string describing constraints on the parameter; ``None``
        when not given.
    """

    name: str
    shape: tuple[int, ...]
    dims: tuple[str, ...]
    constraints: str | None = None


@dataclass(frozen=True)
class ParameterInfo(Info[Parameter]):
    """Collection of ``Parameter`` records, keyed by ``name``."""

    def __init__(self, parameters: list[Parameter]):
        """Store ``parameters`` as an immutable tuple and index them by name."""
        frozen_items = tuple(parameters)
        super().__init__(items=frozen_items, key_field="name")


@dataclass(frozen=True)
class Data(Property):
    """Immutable record describing a single data variable.

    Attributes
    ----------
    name
        Data name; used as lookup key by ``DataInfo``.
    shape
        Tuple of integers giving the data shape.
    dims
        Tuple of dimension names, presumably one per entry of ``shape``
        (TODO confirm).
    is_exogenous
        Flag marking the data as exogenous; ``DataInfo`` filters on it.
    """

    name: str
    shape: tuple[int, ...]
    dims: tuple[str, ...]
    is_exogenous: bool


@dataclass(frozen=True)
class DataInfo(Info[Data]):
    """Collection of ``Data`` records, keyed by ``name``."""

    def __init__(self, data: list[Data]):
        """Store ``data`` as an immutable tuple and index the records by name."""
        entries = tuple(data)
        super().__init__(items=entries, key_field="name")

    @property
    def needs_exogenous_data(self) -> bool:
        """True when at least one stored record is flagged ``is_exogenous``."""
        for entry in self.items:
            if entry.is_exogenous:
                return True
        return False

    @property
    def exogenous_names(self) -> tuple[str, ...]:
        """Names of the records flagged ``is_exogenous``, in storage order."""
        exogenous = [entry for entry in self.items if entry.is_exogenous]
        return tuple(entry.name for entry in exogenous)

    def __str__(self) -> str:
        """Return the data names followed by the exogenous-data flag."""
        all_names = [entry.name for entry in self.items]
        return f"data: {all_names}\nneeds exogenous data: {self.needs_exogenous_data}"


@dataclass(frozen=True)
class Coord(Property):
    """Immutable record pairing a dimension name with its labels.

    Attributes
    ----------
    dimension
        Dimension name; used as lookup key by ``CoordInfo``.
    labels
        Tuple of labels along that dimension.
    """

    dimension: str
    labels: tuple[str, ...]


@dataclass(frozen=True)
class CoordInfo(Info[Coord]):
    """Collection of ``Coord`` records, keyed by ``dimension``."""

    def __init__(self, coords: list[Coord]):
        """Store ``coords`` as an immutable tuple and index them by dimension."""
        frozen_items = tuple(coords)
        super().__init__(items=frozen_items, key_field="dimension")

    def __str__(self) -> str:
        """Return a header line followed by one indented section per coordinate."""
        pieces = ["coordinates:"]
        for coord in self.items:
            section = "\n".join(" " + line for line in str(coord).splitlines())
            pieces.append("\n" + section + "\n")
        return "".join(pieces)

    # TODO: Need to figure out how to include Component type was causing circular import issues
    @classmethod
    def default_coords_from_model(cls, model: PyMCStateSpace | Component) -> Self:
        """Build the default coordinates from a model's state, observed-state and shock names.

        Each group of labels is registered under both its regular dimension and
        the matching auxiliary dimension.

        Parameters
        ----------
        model
            Object exposing ``state_names``, ``observed_states`` and ``shock_names``.

        Returns
        -------
        Self
            A new instance holding six ``Coord`` records.
        """
        label_groups = (
            ((ALL_STATE_DIM, ALL_STATE_AUX_DIM), tuple(model.state_names)),
            ((OBS_STATE_DIM, OBS_STATE_AUX_DIM), tuple(model.observed_states)),
            ((SHOCK_DIM, SHOCK_AUX_DIM), tuple(model.shock_names)),
        )

        coords = []
        for dimensions, labels in label_groups:
            for dimension in dimensions:
                coords.append(Coord(dimension=dimension, labels=labels))
        return cls(coords)

    def to_dict(self):
        """Return a ``{dimension: labels}`` dict, leaving out coordinates with no labels."""
        mapping = {}
        for coord in self.items:
            if len(coord.labels) > 0:
                mapping[coord.dimension] = coord.labels
        return mapping


@dataclass(frozen=True)
class State(Property):
    """Immutable record describing a single state.

    Attributes
    ----------
    name
        State name; used as lookup key by ``StateInfo``.
    observed
        Flag marking the state as observed; ``StateInfo`` filters on it.
    shared
        Flag marking the state as shared, presumably between components
        (TODO confirm).
    """

    name: str
    observed: bool
    shared: bool


@dataclass(frozen=True)
class StateInfo(Info[State]):
    """Collection of ``State`` records, keyed by ``name``."""

    def __init__(self, states: list[State]):
        """Store ``states`` as an immutable tuple and index them by name."""
        super().__init__(items=tuple(states), key_field="name")

    def __str__(self) -> str:
        """Return the state names followed by their ``observed`` flags."""
        return (
            f"states: {[s.name for s in self.items]}\nobserved: {[s.observed for s in self.items]}"
        )

    # TODO(review): the original author questioned whether this property is needed.
    @property
    def observed_states(self) -> tuple[State, ...]:
        """The ``State`` records flagged as observed, in storage order."""
        return tuple(s for s in self.items if s.observed)

    @property
    def observed_state_names(self) -> tuple[str, ...]:
        """Names of the states flagged as observed, in storage order."""
        return tuple(s.name for s in self.items if s.observed)

    @property
    def unobserved_state_names(self) -> tuple[str, ...]:
        """Names of the states not flagged as observed, in storage order."""
        return tuple(s.name for s in self.items if not s.observed)

    def merge(self, other: StateInfo, allow_duplicates: bool = False) -> StateInfo:
        """Combine states from two StateInfo objects.

        Unlike ``Info.merge``, overlapping names never raise.

        Parameters
        ----------
        other
            The ``StateInfo`` whose states are appended after those of ``self``.
        allow_duplicates
            If False (default) and names overlap, a ``UserWarning`` is issued and
            the overlapping states of ``other`` are dropped. If True, every state
            of ``other`` is kept, duplicates included.

        Returns
        -------
        StateInfo
            A new object; neither input is modified.

        Raises
        ------
        TypeError
            If ``other`` is not a ``StateInfo``.
        """
        if not isinstance(other, StateInfo):
            raise TypeError(f"Cannot merge {type(other).__name__} with StateInfo")

        overlapping = set(self.names) & set(other.names)
        if overlapping and not allow_duplicates:
            # This is necessary for shared states: keep the copy held by ``self``
            # and skip the duplicates coming from ``other``.
            warnings.warn(
                f"Duplicate state names found: {overlapping}. Merge will ONLY retain unique states",
                UserWarning,
                stacklevel=2,  # attribute the warning to the caller of ``merge``
            )
            return StateInfo(
                states=list(self.items)
                + [item for item in other.items if item.name not in overlapping]
            )

        return StateInfo(states=list(self.items) + list(other.items))


@dataclass(frozen=True)
class Shock(Property):
    """Immutable record describing a single shock.

    Attributes
    ----------
    name
        Shock name; used as lookup key by ``ShockInfo``.
    """

    name: str


@dataclass(frozen=True)
class ShockInfo(Info[Shock]):
    """Collection of ``Shock`` records, keyed by ``name``."""

    def __init__(self, shocks: list[Shock]):
        """Store ``shocks`` as an immutable tuple and index them by name."""
        frozen_items = tuple(shocks)
        super().__init__(items=frozen_items, key_field="name")
Loading
Loading