@@ -122,7 +122,13 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
122122 )
123123 return self .idata
124124
125- def predict (self , X , coords : Optional [Dict [str , Any ]] = None , ** kwargs ):
125+ def predict (
126+ self ,
127+ X ,
128+ coords : Optional [Dict [str , Any ]] = None ,
129+ out_of_sample : Optional [bool ] = False ,
130+ ** kwargs ,
131+ ):
126132 """
127133 Predict data given input data `X`
128134
@@ -983,6 +989,7 @@ def predict(
983989 self ,
984990 X : Optional [np .ndarray ],
985991 coords : Dict [str , Any ], # Must contain "datetime_index" for prediction period
992+ out_of_sample : Optional [bool ] = False ,
986993 ):
987994 """
988995 Predict data given input X and coords for prediction period.
@@ -1018,3 +1025,260 @@ def score(
10181025 ).T .values
10191026 # Note: First argument must be a 1D array
10201027 return r2_score (y .flatten (), mu_pred )
1028+
1029+
1030+ class StateSpaceTimeSeries (PyMCModel ):
1031+ """
1032+ State-space time series model using pymc_extras.statespace.structural.
1033+
1034+ Parameters
1035+ ----------
1036+ level_order : int, optional
1037+ Order of the local level/trend component. Defaults to 2.
1038+ seasonal_length : int, optional
1039+ Seasonal period (e.g., 12 for monthly data with annual seasonality). Defaults to 12.
1040+ trend_component : optional
1041+ Custom state-space trend component.
1042+ seasonality_component : optional
1043+ Custom state-space seasonal component.
1044+ sample_kwargs : dict, optional
1045+ Kwargs passed to `pm.sample`.
1046+ mode : str, optional
1047+ Mode passed to `build_statespace_graph` (e.g., "JAX").
1048+ """
1049+
1050+ def __init__ (
1051+ self ,
1052+ level_order : int = 2 ,
1053+ seasonal_length : int = 12 ,
1054+ trend_component : Optional [Any ] = None ,
1055+ seasonality_component : Optional [Any ] = None ,
1056+ sample_kwargs : Optional [Dict [str , Any ]] = None ,
1057+ mode : str = "JAX" ,
1058+ ):
1059+ super ().__init__ (sample_kwargs = sample_kwargs )
1060+ self ._custom_trend_component = trend_component
1061+ self ._custom_seasonality_component = seasonality_component
1062+ self .level_order = level_order
1063+ self .seasonal_length = seasonal_length
1064+ self .mode = mode
1065+ self .ss_mod = None
1066+ self ._validate_and_initialize_components ()
1067+
1068+ def _validate_and_initialize_components (self ):
1069+ """
1070+ Validate and initialize trend and seasonality components.
1071+ This separates validation from model building for cleaner code.
1072+ """
1073+ # Validate pymc-extras availability if using default components
1074+ if (
1075+ self ._custom_trend_component is None
1076+ or self ._custom_seasonality_component is None
1077+ ):
1078+ try :
1079+ from pymc_extras .statespace import structural as st
1080+
1081+ self ._PymcExtrasLevelTrendComponent = st .LevelTrendComponent
1082+ self ._PymcExtrasFrequencySeasonality = st .FrequencySeasonality
1083+ except ImportError :
1084+ raise ImportError (
1085+ "pymc-extras is required when using default trend or seasonality components. "
1086+ "Please install it with `conda install -c conda-forge pymc-extras` or provide custom components."
1087+ )
1088+
1089+ # Validate custom components have required methods
1090+ if self ._custom_trend_component is not None :
1091+ if not hasattr (self ._custom_trend_component , "apply" ):
1092+ raise ValueError (
1093+ "Custom trend_component must have an 'apply' method that accepts time data "
1094+ "and returns a PyMC tensor."
1095+ )
1096+
1097+ if self ._custom_seasonality_component is not None :
1098+ if not hasattr (self ._custom_seasonality_component , "apply" ):
1099+ raise ValueError (
1100+ "Custom seasonality_component must have an 'apply' method that accepts time data "
1101+ "and returns a PyMC tensor."
1102+ )
1103+
1104+ # Initialize components
1105+ self ._trend_component = None
1106+ self ._seasonality_component = None
1107+
1108+ def _get_trend_component (self ):
1109+ """Get the trend component, creating default if needed."""
1110+ if self ._custom_trend_component is not None :
1111+ return self ._custom_trend_component
1112+
1113+ # Create default trend component
1114+ if self ._trend_component is None :
1115+ self ._trend_component = self ._PymcExtrasLevelTrendComponent (
1116+ order = self .level_order
1117+ )
1118+ return self ._trend_component
1119+
1120+ def _get_seasonality_component (self ):
1121+ """Get the seasonality component, creating default if needed."""
1122+ if self ._custom_seasonality_component is not None :
1123+ return self ._custom_seasonality_component
1124+
1125+ # Create default seasonality component
1126+ if self ._seasonality_component is None :
1127+ self ._seasonality_component = self ._PymcExtrasFrequencySeasonality (
1128+ season_length = self .seasonal_length , name = "freq"
1129+ )
1130+ return self ._seasonality_component
1131+
1132+ def build_model (
1133+ self , X : Optional [np .ndarray ], y : np .ndarray , coords : Dict [str , Any ]
1134+ ) -> None :
1135+ """
1136+ Build the PyMC state-space model. `coords` must include:
1137+ - 'datetime_index': a pandas.DatetimeIndex matching `y`.
1138+ """
1139+ coords = coords .copy ()
1140+ datetime_index = coords .pop ("datetime_index" , None )
1141+ if not isinstance (datetime_index , pd .DatetimeIndex ):
1142+ raise ValueError (
1143+ "coords must contain 'datetime_index' of type pandas.DatetimeIndex."
1144+ )
1145+ self ._train_index = datetime_index
1146+
1147+ # Instantiate components and build state-space object
1148+ trend = self ._get_trend_component ()
1149+ season = self ._get_seasonality_component ()
1150+ combined = trend + season
1151+ self .ss_mod = combined .build ()
1152+
1153+ # Extract parameter dims (order: initial_trend, sigma_trend, seasonal, P0)
1154+ initial_trend_dims , sigma_trend_dims , annual_dims , P0_dims = (
1155+ self .ss_mod .param_dims .values ()
1156+ )
1157+ coordinates = {** coords , ** self .ss_mod .coords }
1158+
1159+ # Build model
1160+ with pm .Model (coords = coordinates ) as self .second_model :
1161+ # Add coords for statespace (includes 'time' and 'state' dims)
1162+ P0_diag = pm .Gamma ("P0_diag" , alpha = 2 , beta = 1 , dims = P0_dims [0 ])
1163+ _P0 = pm .Deterministic ("P0" , pt .diag (P0_diag ), dims = P0_dims )
1164+ _initial_trend = pm .Normal (
1165+ "initial_trend" , sigma = 50 , dims = initial_trend_dims
1166+ )
1167+ _annual_seasonal = pm .ZeroSumNormal ("freq" , sigma = 80 , dims = annual_dims )
1168+
1169+ _sigma_trend = pm .Gamma (
1170+ "sigma_trend" , alpha = 2 , beta = 5 , dims = sigma_trend_dims
1171+ )
1172+ _sigma_monthly_season = pm .Gamma ("sigma_freq" , alpha = 2 , beta = 1 )
1173+
1174+ # Attach the state-space graph using the observed data
1175+ df = pd .DataFrame ({"y" : y .flatten ()}, index = datetime_index )
1176+ self .ss_mod .build_statespace_graph (df [["y" ]], mode = self .mode )
1177+
1178+ def fit (
1179+ self , X : Optional [np .ndarray ], y : np .ndarray , coords : Dict [str , Any ]
1180+ ) -> az .InferenceData :
1181+ """
1182+ Fit the model, drawing posterior samples.
1183+ Returns the InferenceData with parameter draws.
1184+ """
1185+ self .build_model (X , y , coords )
1186+ with self .second_model :
1187+ self .idata = pm .sample (** self .sample_kwargs )
1188+ self .idata .extend (
1189+ pm .sample_posterior_predictive (
1190+ self .idata ,
1191+ )
1192+ )
1193+ self .conditional_idata = self ._smooth ()
1194+ return self ._prepare_idata ()
1195+
1196+ def _prepare_idata (self ):
1197+ if self .idata is None :
1198+ raise RuntimeError ("Model must be fit before smoothing." )
1199+
1200+ new_idata = self .idata .copy ()
1201+ # Get smoothed posterior and sum over state dimension
1202+ smoothed = self .conditional_idata .rename ({"smoothed_posterior" : "y_hat" })
1203+ y_hat_summed = smoothed .y_hat .sum (dim = "state" )
1204+
1205+ # Rename 'time' to 'obs_ind' to match CausalPy conventions
1206+ if "time" in y_hat_summed .dims :
1207+ y_hat_final = y_hat_summed .rename ({"time" : "obs_ind" })
1208+ else :
1209+ y_hat_final = y_hat_summed
1210+
1211+ new_idata ["posterior_predictive" ]["y_hat" ] = y_hat_final
1212+ new_idata ["posterior_predictive" ]["mu" ] = y_hat_final
1213+
1214+ return new_idata
1215+
1216+ def _smooth (self ) -> xr .Dataset :
1217+ """
1218+ Run the Kalman smoother / conditional posterior sampler.
1219+ Returns an xarray Dataset with 'smoothed_posterior'.
1220+ """
1221+ if self .idata is None :
1222+ raise RuntimeError ("Model must be fit before smoothing." )
1223+ return self .ss_mod .sample_conditional_posterior (self .idata )
1224+
1225+ def _forecast (self , start : pd .Timestamp , periods : int ) -> xr .Dataset :
1226+ """
1227+ Forecast future values.
1228+ `start` is the timestamp of the last observed point, and `periods` is the number of steps ahead.
1229+ Returns an xarray Dataset with 'forecast_observed'.
1230+ """
1231+ if self .idata is None :
1232+ raise RuntimeError ("Model must be fit before forecasting." )
1233+ return self .ss_mod .forecast (self .idata , start = start , periods = periods )
1234+
1235+ def predict (
1236+ self ,
1237+ X : Optional [np .ndarray ],
1238+ coords : Dict [str , Any ],
1239+ out_of_sample : Optional [bool ] = False ,
1240+ ) -> xr .Dataset :
1241+ """
1242+ Wrapper around forecast: expects coords with 'datetime_index' of future points.
1243+ """
1244+ if not out_of_sample :
1245+ return self ._prepare_idata ()
1246+ else :
1247+ idx = coords .get ("datetime_index" )
1248+ if not isinstance (idx , pd .DatetimeIndex ):
1249+ raise ValueError (
1250+ "coords must contain 'datetime_index' for prediction period."
1251+ )
1252+ last = self ._train_index [- 1 ] # start forecasting after the last observed
1253+ temp_idata = self ._forecast (start = last , periods = len (idx ))
1254+ new_idata = temp_idata .copy ()
1255+
1256+ # Rename 'time' to 'obs_ind' to match CausalPy conventions
1257+ if "time" in new_idata .dims :
1258+ new_idata = new_idata .rename ({"time" : "obs_ind" })
1259+
1260+ # Extract the forecasted observed data and assign it to 'y_hat'
1261+ new_idata ["y_hat" ] = new_idata ["forecast_observed" ].isel (observed_state = 0 )
1262+
1263+ # Assign 'y_hat' to 'mu' for consistency
1264+ new_idata ["mu" ] = new_idata ["y_hat" ]
1265+
1266+ return new_idata
1267+
1268+ def score (
1269+ self , X : Optional [np .ndarray ], y : np .ndarray , coords : Dict [str , Any ]
1270+ ) -> pd .Series :
1271+ """
1272+ Compute R^2 between observed and mean forecast.
1273+ """
1274+ pred = self .predict (X , coords )
1275+ fc = pred ["posterior_predictive" ]["y_hat" ] # .isel(observed_state=0)
1276+
1277+ # Use all posterior samples to compute Bayesian R²
1278+ # fc has shape (chain, draw, time), we want (n_samples, time)
1279+ fc_samples = fc .stack (
1280+ sample = ["chain" , "draw" ]
1281+ ).T .values # Shape: (time, n_samples)
1282+
1283+ # Use arviz.r2_score to get both r2 and r2_std
1284+ return r2_score (y .flatten (), fc_samples )
0 commit comments