Skip to content

Commit f2579f1

Browse files
committed
Integration with state space
1 parent da47d90 commit f2579f1

File tree

4 files changed

+672
-14281
lines changed

4 files changed

+672
-14281
lines changed

causalpy/experiments/structural_time_series.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ def __init__(
159159
"datetime_index": self.datapre.index, # For BSTS
160160
}
161161
self.pre_pred = self.model.predict(X=X_pre_predict, coords=pre_pred_coords)
162+
if not isinstance(self.pre_pred, az.InferenceData):
163+
self.pre_pred = az.InferenceData(posterior_predictive=self.pre_pred)
162164
elif isinstance(self.model, RegressorMixin):
163165
self.pre_pred = self.model.predict(X=X_pre_predict)
164166
else:
@@ -173,8 +175,10 @@ def __init__(
173175
"datetime_index": self.datapost.index, # For BSTS
174176
}
175177
self.post_pred = self.model.predict(
176-
X=X_post_predict, coords=post_pred_coords
178+
X=X_post_predict, coords=post_pred_coords, out_of_sample=True
177179
)
180+
if not isinstance(self.post_pred, az.InferenceData):
181+
self.post_pred = az.InferenceData(posterior_predictive=self.post_pred)
178182
elif isinstance(self.model, RegressorMixin):
179183
self.post_pred = self.model.predict(X=X_post_predict)
180184
else:

causalpy/pymc_models.py

Lines changed: 265 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,13 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
122122
)
123123
return self.idata
124124

125-
def predict(self, X, coords: Optional[Dict[str, Any]] = None, **kwargs):
125+
def predict(
126+
self,
127+
X,
128+
coords: Optional[Dict[str, Any]] = None,
129+
out_of_sample: Optional[bool] = False,
130+
**kwargs,
131+
):
126132
"""
127133
Predict data given input data `X`
128134
@@ -983,6 +989,7 @@ def predict(
983989
self,
984990
X: Optional[np.ndarray],
985991
coords: Dict[str, Any], # Must contain "datetime_index" for prediction period
992+
out_of_sample: Optional[bool] = False,
986993
):
987994
"""
988995
Predict data given input X and coords for prediction period.
@@ -1018,3 +1025,260 @@ def score(
10181025
).T.values
10191026
# Note: First argument must be a 1D array
10201027
return r2_score(y.flatten(), mu_pred)
1028+
1029+
1030+
class StateSpaceTimeSeries(PyMCModel):
    """
    State-space time series model using ``pymc_extras.statespace.structural``.

    Combines a level/trend component with a frequency-domain seasonal
    component into one structural state-space model, attaches it to observed
    data via ``build_statespace_graph``, and samples the posterior with PyMC.

    Parameters
    ----------
    level_order : int, optional
        Order of the local level/trend component. Defaults to 2.
    seasonal_length : int, optional
        Seasonal period (e.g., 12 for monthly data with annual seasonality).
        Defaults to 12.
    trend_component : optional
        Custom state-space trend component; must expose an ``apply`` method.
    seasonality_component : optional
        Custom state-space seasonal component; must expose an ``apply`` method.
    sample_kwargs : dict, optional
        Kwargs passed to ``pm.sample``.
    mode : str, optional
        Mode passed to ``build_statespace_graph`` (e.g., "JAX").
    """

    def __init__(
        self,
        level_order: int = 2,
        seasonal_length: int = 12,
        trend_component: Optional[Any] = None,
        seasonality_component: Optional[Any] = None,
        sample_kwargs: Optional[Dict[str, Any]] = None,
        mode: str = "JAX",
    ):
        super().__init__(sample_kwargs=sample_kwargs)
        self._custom_trend_component = trend_component
        self._custom_seasonality_component = seasonality_component
        self.level_order = level_order
        self.seasonal_length = seasonal_length
        self.mode = mode
        # Built lazily by `build_model`.
        self.ss_mod = None
        self._validate_and_initialize_components()

    def _validate_and_initialize_components(self) -> None:
        """
        Validate and initialize trend and seasonality components.

        This separates validation from model building for cleaner code.

        Raises
        ------
        ImportError
            If a default component is needed but pymc-extras is not installed.
        ValueError
            If a custom component lacks the required ``apply`` method.
        """
        # pymc-extras is only required when falling back to default components.
        if (
            self._custom_trend_component is None
            or self._custom_seasonality_component is None
        ):
            try:
                from pymc_extras.statespace import structural as st

                self._PymcExtrasLevelTrendComponent = st.LevelTrendComponent
                self._PymcExtrasFrequencySeasonality = st.FrequencySeasonality
            except ImportError as err:
                # Chain the original exception so the missing-module
                # context is preserved for debugging.
                raise ImportError(
                    "pymc-extras is required when using default trend or seasonality components. "
                    "Please install it with `conda install -c conda-forge pymc-extras` or provide custom components."
                ) from err

        # Validate custom components have required methods.
        if self._custom_trend_component is not None:
            if not hasattr(self._custom_trend_component, "apply"):
                raise ValueError(
                    "Custom trend_component must have an 'apply' method that accepts time data "
                    "and returns a PyMC tensor."
                )

        if self._custom_seasonality_component is not None:
            if not hasattr(self._custom_seasonality_component, "apply"):
                raise ValueError(
                    "Custom seasonality_component must have an 'apply' method that accepts time data "
                    "and returns a PyMC tensor."
                )

        # Default components are created lazily on first access.
        self._trend_component = None
        self._seasonality_component = None

    def _get_trend_component(self):
        """Get the trend component, creating the default one if needed."""
        if self._custom_trend_component is not None:
            return self._custom_trend_component

        # Create (and cache) the default trend component.
        if self._trend_component is None:
            self._trend_component = self._PymcExtrasLevelTrendComponent(
                order=self.level_order
            )
        return self._trend_component

    def _get_seasonality_component(self):
        """Get the seasonality component, creating the default one if needed."""
        if self._custom_seasonality_component is not None:
            return self._custom_seasonality_component

        # Create (and cache) the default seasonality component.
        if self._seasonality_component is None:
            self._seasonality_component = self._PymcExtrasFrequencySeasonality(
                season_length=self.seasonal_length, name="freq"
            )
        return self._seasonality_component

    def build_model(
        self, X: Optional[np.ndarray], y: np.ndarray, coords: Dict[str, Any]
    ) -> None:
        """
        Build the PyMC state-space model.

        `coords` must include 'datetime_index': a ``pandas.DatetimeIndex``
        matching `y`. `X` is unused by this model but kept for interface
        compatibility with other PyMCModel subclasses.

        Raises
        ------
        ValueError
            If 'datetime_index' is missing, of the wrong type, or its length
            does not match `y`.
        """
        coords = coords.copy()
        datetime_index = coords.pop("datetime_index", None)
        if not isinstance(datetime_index, pd.DatetimeIndex):
            raise ValueError(
                "coords must contain 'datetime_index' of type pandas.DatetimeIndex."
            )
        # Fail early with a clear message rather than an opaque pandas error
        # when constructing the observed-data DataFrame below.
        if len(datetime_index) != len(y):
            raise ValueError(
                "'datetime_index' length must match the length of y."
            )
        # Remember the training index; forecasting starts after its last point.
        self._train_index = datetime_index

        # Instantiate components and build state-space object
        trend = self._get_trend_component()
        season = self._get_seasonality_component()
        combined = trend + season
        self.ss_mod = combined.build()

        # Extract parameter dims (order: initial_trend, sigma_trend, seasonal, P0)
        initial_trend_dims, sigma_trend_dims, annual_dims, P0_dims = (
            self.ss_mod.param_dims.values()
        )
        coordinates = {**coords, **self.ss_mod.coords}

        # Build model
        with pm.Model(coords=coordinates) as self.second_model:
            # Priors for the state-space parameters; the statespace machinery
            # picks them up by name, so the leading-underscore locals are
            # intentionally unused.
            P0_diag = pm.Gamma("P0_diag", alpha=2, beta=1, dims=P0_dims[0])
            _P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims)
            _initial_trend = pm.Normal(
                "initial_trend", sigma=50, dims=initial_trend_dims
            )
            _annual_seasonal = pm.ZeroSumNormal("freq", sigma=80, dims=annual_dims)

            _sigma_trend = pm.Gamma(
                "sigma_trend", alpha=2, beta=5, dims=sigma_trend_dims
            )
            _sigma_monthly_season = pm.Gamma("sigma_freq", alpha=2, beta=1)

            # Attach the state-space graph using the observed data
            df = pd.DataFrame({"y": y.flatten()}, index=datetime_index)
            self.ss_mod.build_statespace_graph(df[["y"]], mode=self.mode)

    def fit(
        self, X: Optional[np.ndarray], y: np.ndarray, coords: Dict[str, Any]
    ) -> az.InferenceData:
        """
        Fit the model, drawing posterior samples.

        Returns the InferenceData with parameter draws and a smoothed
        'posterior_predictive' group containing 'y_hat' and 'mu'.
        """
        self.build_model(X, y, coords)
        with self.second_model:
            self.idata = pm.sample(**self.sample_kwargs)
            self.idata.extend(
                pm.sample_posterior_predictive(
                    self.idata,
                )
            )
        self.conditional_idata = self._smooth()
        return self._prepare_idata()

    def _prepare_idata(self) -> az.InferenceData:
        """
        Build an InferenceData whose posterior_predictive holds the smoothed
        in-sample predictions as 'y_hat' (and 'mu'), summed over states and
        indexed by 'obs_ind'.
        """
        if self.idata is None:
            raise RuntimeError("Model must be fit before smoothing.")

        new_idata = self.idata.copy()
        # Get smoothed posterior and sum over state dimension
        smoothed = self.conditional_idata.rename({"smoothed_posterior": "y_hat"})
        y_hat_summed = smoothed.y_hat.sum(dim="state")

        # Rename 'time' to 'obs_ind' to match CausalPy conventions
        if "time" in y_hat_summed.dims:
            y_hat_final = y_hat_summed.rename({"time": "obs_ind"})
        else:
            y_hat_final = y_hat_summed

        new_idata["posterior_predictive"]["y_hat"] = y_hat_final
        # 'mu' mirrors 'y_hat' for consistency with other CausalPy models.
        new_idata["posterior_predictive"]["mu"] = y_hat_final

        return new_idata

    def _smooth(self) -> xr.Dataset:
        """
        Run the Kalman smoother / conditional posterior sampler.
        Returns an xarray Dataset with 'smoothed_posterior'.
        """
        if self.idata is None:
            raise RuntimeError("Model must be fit before smoothing.")
        return self.ss_mod.sample_conditional_posterior(self.idata)

    def _forecast(self, start: pd.Timestamp, periods: int) -> xr.Dataset:
        """
        Forecast future values.
        `start` is the timestamp of the last observed point, and `periods` is
        the number of steps ahead.
        Returns an xarray Dataset with 'forecast_observed'.
        """
        if self.idata is None:
            raise RuntimeError("Model must be fit before forecasting.")
        return self.ss_mod.forecast(self.idata, start=start, periods=periods)

    def predict(
        self,
        X: Optional[np.ndarray],
        coords: Dict[str, Any],
        out_of_sample: Optional[bool] = False,
    ):
        """
        Predict in-sample (smoothed) or out-of-sample (forecast) values.

        Parameters
        ----------
        X : np.ndarray, optional
            Unused; kept for interface compatibility.
        coords : dict
            For out-of-sample prediction, must contain 'datetime_index'
            (a pandas.DatetimeIndex of future points).
        out_of_sample : bool, optional
            When False (default), return the smoothed in-sample
            ``az.InferenceData``. When True, return an ``xr.Dataset`` of
            forecasts with 'y_hat' and 'mu' variables.
        """
        if not out_of_sample:
            # In-sample: smoothed posterior predictions (coords unused here).
            return self._prepare_idata()
        else:
            idx = coords.get("datetime_index")
            if not isinstance(idx, pd.DatetimeIndex):
                raise ValueError(
                    "coords must contain 'datetime_index' for prediction period."
                )
            last = self._train_index[-1]  # start forecasting after the last observed
            temp_idata = self._forecast(start=last, periods=len(idx))
            new_idata = temp_idata.copy()

            # Rename 'time' to 'obs_ind' to match CausalPy conventions
            if "time" in new_idata.dims:
                new_idata = new_idata.rename({"time": "obs_ind"})

            # Extract the forecasted observed data and assign it to 'y_hat'
            new_idata["y_hat"] = new_idata["forecast_observed"].isel(observed_state=0)

            # Assign 'y_hat' to 'mu' for consistency
            new_idata["mu"] = new_idata["y_hat"]

            return new_idata

    def score(
        self, X: Optional[np.ndarray], y: np.ndarray, coords: Dict[str, Any]
    ) -> pd.Series:
        """
        Compute Bayesian R^2 between observed values and the in-sample
        posterior predictions, using all posterior samples.
        """
        pred = self.predict(X, coords)
        fc = pred["posterior_predictive"]["y_hat"]

        # fc has shape (chain, draw, time); stack to (time, n_samples) as
        # required by arviz-style r2_score (first arg must be 1D).
        fc_samples = fc.stack(
            sample=["chain", "draw"]
        ).T.values  # Shape: (time, n_samples)

        # Use arviz.r2_score to get both r2 and r2_std
        return r2_score(y.flatten(), fc_samples)

0 commit comments

Comments
 (0)