diff --git a/pymc_experimental/statespace/core/statespace.py b/pymc_experimental/statespace/core/statespace.py index 0b163dbda..c87e2ffdb 100644 --- a/pymc_experimental/statespace/core/statespace.py +++ b/pymc_experimental/statespace/core/statespace.py @@ -27,6 +27,7 @@ ) from pymc_experimental.statespace.filters.distributions import ( LinearGaussianStateSpace, + MvNormalSVD, SequenceMvNormal, ) from pymc_experimental.statespace.filters.utilities import stabilize @@ -876,9 +877,8 @@ def build_statespace_graph( cov_jitter=cov_jitter, ) - outputs = filter_outputs - logp = outputs.pop(-1) - states, covs = outputs[:3], outputs[3:] + logp = filter_outputs.pop(-1) + states, covs = filter_outputs[:3], filter_outputs[3:] filtered_states, predicted_states, observed_states = states filtered_covariances, predicted_covariances, observed_covariances = covs if save_kalman_filter_outputs_in_idata: @@ -976,6 +976,7 @@ def _kalman_filter_outputs_from_dummy_graph( self, data: pt.TensorLike | None = None, data_dims: str | tuple[str] | list[str] | None = None, + scenario: dict[str, pd.DataFrame] | pd.DataFrame | None = None, ) -> tuple[list[pt.TensorVariable], list[tuple[pt.TensorVariable, pt.TensorVariable]]]: """ Builds a Kalman filter graph using "dummy" pm.Flat distributions for the model variables and sorts the returns @@ -997,6 +998,9 @@ def _kalman_filter_outputs_from_dummy_graph( grouped_outputs: list of tuple of tensors A list of tuples, each containing the mean and covariance of the filtered, predicted, and smoothed states. """ + if scenario is None: + scenario = dict() + pm_mod = modelcontext(None) self._build_dummy_graph() self._insert_random_variables() @@ -1007,6 +1011,10 @@ def _kalman_filter_outputs_from_dummy_graph( self._insert_data_variables() + for name in self.data_names: + if name in scenario.keys(): + pm.set_data({name: scenario[name]}) + x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace() if data is None: @@ -1504,14 +1512,384 @@ def sample_statespace_matrices( return matrix_idata + @staticmethod + def _validate_forecast_args( + time_index: pd.RangeIndex | pd.DatetimeIndex, + start: int | pd.Timestamp, + periods: int | None = None, + end: int | pd.Timestamp = None, + scenario: pd.DataFrame | np.ndarray | None = None, + use_scenario_index: bool = False, + verbose: bool = True, + ): + if isinstance(start, pd.Timestamp) and start not in time_index: + raise ValueError("Datetime start must be in the data index used to fit the model.") + elif isinstance(start, int): + if abs(start) > len(time_index): + raise ValueError( + "Integer start must be within the range of the data index used to fit the model." + ) + if periods is None and end is None: + raise ValueError("Must specify one of either periods or end") + if periods is not None and end is not None: + raise ValueError("Must specify exactly one of either periods or end") + if scenario is None and use_scenario_index: + raise ValueError("use_scenario_index=True requires a scenario to be provided.") + if scenario is not None and use_scenario_index: + if isinstance(scenario, dict): + first_df = next( + (df for df in scenario.values() if isinstance(df, pd.DataFrame | pd.Series)), + None, + ) + if first_df is None: + raise ValueError( + "use_scenario_index=True requires a scenario to be a DataFrame or Series." + ) + elif not isinstance(scenario, pd.DataFrame | pd.Series): + raise ValueError( + "use_scenario_index=True requires a scenario to be a DataFrame or Series." + ) + if use_scenario_index and any(arg is not None for arg in [start, end, periods]) and verbose: + _log.warning( + "start, end, and periods arguments are ignored when use_scenario_index is True. Pass only " + "one or the other to avoid this warning, or pass verbose = False." + ) + + def _get_fit_time_index(self) -> pd.RangeIndex | pd.DatetimeIndex: + time_index = self._fit_coords.get(TIME_DIM, None) if self._fit_coords is not None else None + if time_index is None: + raise ValueError( + "No time dimension found on coordinates used to fit the model. Has this model been fit?" + ) + + if isinstance(time_index[0], pd.Timestamp): + time_index = pd.DatetimeIndex(time_index) + time_index.freq = time_index.inferred_freq + else: + time_index = np.array(time_index) + + return time_index + + def _validate_scenario_data( + self, + scenario: pd.DataFrame | np.ndarray | dict[str, pd.DataFrame | np.ndarray] | None, + name: str | None = None, + verbose=True, + ): + """ + Validate the scenario data provided to the forecast method by checking that it has the correct shape and + dimensions. + + Parameters + ---------- + scenario + name + verbose + + Returns + ------- + scenario: pd.DataFrame | np.ndarray | dict[str, pd.DataFrame | np.ndarray] + Scenario data, validated and potentially modified. + + """ + if not self._needs_exog_data: + return scenario + + var_to_dims = {key: info["dims"][1:] for key, info in self.data_info.items()} + + if any(len(dims) > 1 for dims in var_to_dims.values()): + raise NotImplementedError(">2d exogenous data is not yet supported.") + coords = { + var: self._fit_coords[dim[0]] + for var, dim in var_to_dims.items() + if dim[0] in self._fit_coords + } + + if self._needs_exog_data and scenario is None: + exog_str = ",".join(self._exog_names) + suffix = "s" if len(exog_str) > 1 else "" + raise ValueError( + f"This model was fit using exogenous data. Forecasting cannot be performed without " + f"providing scenario data for the following variable{suffix}: {exog_str}" + ) + + if isinstance(scenario, dict): + for name, data in scenario.items(): + if name not in self._exog_names: + raise ValueError( + f"Scenario data provided for variable '{name}', which is not an exogenous variable " + f"used to fit the model." + ) + + # Recursively call this function to trigger the non-dictionary branch of the checks on each object + # inside the dictionary + scenario[name] = self._validate_scenario_data(data, name) + + # The provided dictionary might be a mix of numpy arrays and dataframes if the user is truly horrible. + # For checking shapes, the first object will always be good enough. But we also need to make sure all the + # indices agree, so we grab the first dataframe (which might not exist, but that's OK) + first_scenario = next(iter(scenario.values())) + first_df = next((df for df in scenario.values() if isinstance(df, pd.DataFrame)), None) + + if not all(data.shape[0] == first_scenario.shape[0] for data in scenario.values()): + raise ValueError( + "Scenario data must have the same number of time steps for all variables." + ) + + if first_df is not None and not all( + df.index.equals(first_df.index) + for df in scenario.values() + if isinstance(df, pd.DataFrame) + ): + raise ValueError("Scenario data must have the same index for all variables.") + + return scenario + + elif isinstance(scenario, pd.Series | pd.DataFrame | np.ndarray | list | tuple): + # A user might be lazy and pass a simple list when there is only one exogenous variable. + if isinstance(scenario, list | tuple) or ( + isinstance(scenario, np.ndarray) and scenario.ndim == 1 + ): + scenario = np.array(scenario).reshape(-1, 1) + + if name is None: + # name should only be None on the first non-recursive call. We only arrive to this branch in that case + # if a non-dictionary was passed, which in turn should only happen if only a single exogenous data + # needs to be set. + if len(self._exog_names) > 1: + raise ValueError( + "Multiple exogenous variables were used to fit the model. Provide a dictionary of " + "scenario data instead." + ) + name = self._exog_names[0] + + # Omit dataframe from this basic shape check so we can give more detailed information about missing columns + # in the next check + if not isinstance(scenario, pd.DataFrame | pd.Series) and scenario.shape[1] != len( + coords[name] + ): + raise ValueError( + f"Scenario data for variable '{name}' has the wrong number of columns. Expected " + f"{len(coords[name])}, got {scenario.shape[1]}" + ) + + if isinstance(scenario, pd.Series): + if len(coords[name]) > 1: + raise ValueError( + f"Scenario data for variable '{name}' has the wrong number of columns. Expected " + f"{len(coords[name])}, got 1" + ) + + if isinstance(scenario, pd.DataFrame): + expected_cols = coords[name] + cols = scenario.columns + missing_columns = sorted(list(set(expected_cols) - set(cols))) + if len(missing_columns) > 0: + suffix = "s" if len(missing_columns) > 1 else "" + raise ValueError( + f"Scenario data for variable '{name}' is missing the following column{suffix}: " + f"{', '.join(missing_columns)}" + ) + + extra_columns = sorted(list(set(cols) - set(expected_cols))) + if len(extra_columns) > 0: + suffix = "s" if len(extra_columns) > 1 else "" + verb = "is" if len(extra_columns) == 1 else "are" + raise ValueError( + f"Scenario data for variable '{name}' contains the following extra column{suffix} " + f"that {verb} not used by the model: " + f"{', '.join(extra_columns)}" + ) + + if not (a == b for a, b in zip(expected_cols, cols)) and verbose: + _log.warning( + f"Scenario data for {name} has a different column order than the data used to fit the " + f"model. Columns will be automatically re-ordered. Ensure consistent ordering to avoid " + f"silent errors." + ) + scenario = scenario[expected_cols] + + return scenario + + @staticmethod + def _build_forecast_index( + time_index: pd.RangeIndex | pd.DatetimeIndex, + start: int | pd.Timestamp | None = None, + end: int | pd.Timestamp = None, + periods: int | None = None, + use_scenario_index: bool = False, + scenario: pd.DataFrame | np.ndarray | None = None, + ) -> tuple[int | pd.Timestamp, pd.RangeIndex | pd.DatetimeIndex]: + """ + Construct a pandas Index for the requested forecast horizon. + + Parameters + ---------- + time_index: pd.RangeIndex or pd.DatetimeIndex + Index of the data used to fit the model + start: int or pd.Timestamp, optional + Date from which to begin forecasting. If using a datetime index, integer start will be interpreted + as a positional index. Otherwise, start must be found inside the time_index + end: int or pd.Timestamp, optional + Date at which to end forecasting. If using a datetime index, end must be a timestamp. + periods: int, optional + Number of periods to forecast + scenario: pd.DataFrame, np.ndarray, optional + Scenario data to use for forecasting. If provided, the index of the scenario data will be used as the + forecast index. If provided, start, end, and periods will be ignored. + use_scenario_index: bool, default False + If True, the index of the scenario data will be used as the forecast index. + + + Returns + ------- + start: int | pd.TimeStamp + The starting date index or time step from which to generate the forecasts. + + forecast_index: pd.DatetimeIndex or pd.RangeIndex + Index for the forecast results + """ + + def get_or_create_index(x, time_index, start=None): + if isinstance(x, pd.DataFrame | pd.Series): + return x.index + elif isinstance(x, dict): + return get_or_create_index(next(iter(x.values())), time_index, start) + elif isinstance(x, np.ndarray | list | tuple): + if start is None: + raise ValueError( + "Provided scenario has no index and no start date was provided. This combination " + "is ambiguous. Please provide a start date, or add an index to the scenario." + ) + is_datetime_index = isinstance(time_index, pd.DatetimeIndex) + n = x.shape[0] if isinstance(x, np.ndarray) else len(x) + + if isinstance(start, int): + start = time_index[start] + if is_datetime_index: + return pd.date_range(start, periods=n, freq=time_index.freq) + return pd.RangeIndex(start, n + start, step=1, dtype="int") + + else: + raise ValueError(f"{type(x)} is not a valid type for scenario data.") + + x0_idx = None + + if use_scenario_index: + forecast_index = get_or_create_index(scenario, time_index, start) + is_datetime = isinstance(forecast_index, pd.DatetimeIndex) + + # If the user provided an index, we want to take it as-is (without removing the start value). Instead, + # step one back and use this as the start value. + delta = forecast_index.freq if is_datetime else 1 + x0_idx = forecast_index[0] - delta + + else: + # Otherwise, build an index. It will be a DateTime index if we have all the necessary information, otherwise + # use a range index. + is_datetime = isinstance(time_index, pd.DatetimeIndex) + forecast_index = None + + if is_datetime: + freq = time_index.freq + if isinstance(start, int): + start = time_index[start] + if isinstance(end, int): + raise ValueError( + "end must be a timestamp if using a datetime index. To specify a number of " + "timesteps from the start date, use the periods argument instead." + ) + if end is not None: + forecast_index = pd.date_range(start, end=end, freq=freq) + if periods is not None: + # date_range includes both the start and end date, but we're going to pop off the start later + # (it will be interpreted as x0). So we need to add 1 to the periods so the user gets "periods" + # number of forecasts back + forecast_index = pd.date_range(start, periods=periods + 1, freq=freq) + + else: + # If the user provided a positive integer as start, directly interpret it as the start time. If its + # negative, interpret it as a positional index. + if start < 0: + start = time_index[start] + if end is not None: + forecast_index = pd.RangeIndex(start, end, step=1, dtype="int") + if periods is not None: + forecast_index = pd.RangeIndex(start, start + periods + 1, step=1, dtype="int") + + if is_datetime: + if forecast_index.freq != time_index.freq: + raise ValueError( + "The frequency of the forecast index must match the frequency on the data used " + f"to fit the model. Got {forecast_index.freq}, expected {time_index.freq}" + ) + + if x0_idx is None: + x0_idx, forecast_index = forecast_index[0], forecast_index[1:] + if x0_idx in forecast_index: + raise ValueError("x0_idx should not be in the forecast index") + if x0_idx not in time_index: + raise ValueError("start must be in the data index used to fit the model.") + + # The starting value should not be included in the forecast index. It will be used only to define x0 and P0, + # and no forecast will be associated with it. + return x0_idx, forecast_index + + def _finalize_scenario_initialization( + self, + scenario: pd.DataFrame | np.ndarray | dict[str, pd.DataFrame | np.ndarray] | None, + forecast_index: pd.RangeIndex | pd.DatetimeIndex, + name=None, + ): + try: + var_to_dims = {key: info["dims"][1:] for key, info in self.data_info.items()} + except NotImplementedError: + return scenario + + if any(len(dims) > 1 for dims in var_to_dims.values()): + raise NotImplementedError(">2d exogenous data is not yet supported.") + coords = { + var: self._fit_coords[dim[0]] + for var, dim in var_to_dims.items() + if dim[0] in self._fit_coords + } + + if scenario is None: + return scenario + + if isinstance(scenario, dict): + for name, data in scenario.items(): + scenario[name] = self._finalize_scenario_initialization(data, forecast_index, name) + return scenario + + # This was already checked as valid + name = self._exog_names[0] if name is None else name + + # Small tidying up in the case we just have a single scenario that's already a dataframe. + if isinstance(scenario, pd.DataFrame | pd.Series): + if isinstance(scenario, pd.Series): + scenario = scenario.to_frame(name=coords[name][0]) + if not scenario.index.equals(forecast_index): + scenario.index = forecast_index + + # lists and tuples were handled during validation, along with shape check, so just cast arrays to dataframes + # with the correct index and columns + if isinstance(scenario, np.ndarray): + scenario = pd.DataFrame(scenario, index=forecast_index, columns=coords[name]) + + return scenario + def forecast( self, idata: InferenceData, - start: int | pd.Timestamp, + start: int | pd.Timestamp | None = None, periods: int | None = None, end: int | pd.Timestamp = None, + scenario: pd.DataFrame | np.ndarray | dict[str, pd.DataFrame | np.ndarray] | None = None, + use_scenario_index: bool = False, filter_output="smoothed", random_seed: RandomState | None = None, + verbose: bool = True, **kwargs, ) -> InferenceData: """ @@ -1526,22 +1904,37 @@ def forecast( idata : InferenceData An Arviz InferenceData object containing the posterior distribution over model parameters. - start : Union[int, pd.Timestamp] + start : int or pd.Timestamp, optional The starting date index or time step from which to generate the forecasts. If the data provided to `PyMCStateSpace.build_statespace_graph` had a datetime index, `start` should be a datetime. If using integer time series, `start` should be an integer indicating the starting time step. In either case, `start` should be in the data index used to build the statespace graph. - periods : Optional[int], default=None + If start is None, the last value on the data's index will be used. + + periods : int, optional The number of time steps to forecast into the future. If `periods` is specified, the `end` parameter will be ignored. If `None`, then the `end` parameter must be provided. - end : Union[int, pd.Timestamp], default=None + end : int or pd.Timestamp, optional The ending date index or time step up to which to generate the forecasts. If the data provided to `PyMCStateSpace.build_statespace_graph` had a datetime index, `start` should be a datetime. If using integer time series, `end` should be an integer indicating the ending time step. If `end` is provided, the `periods` parameter will be ignored. + scenario: pd.Dataframe or np.ndarray, optional + Exogenous variables to use for scenario-based forecasting. Must be a 2d array-like, with second dimension + equal to the number of exogenous variables. If start, end, or periods are specified, the first dimension + must conform with these settings. Otherwise, the index of the scenario data will be used to set the + number of forecast steps. If the index of the forecast scenairo is a pandas DateTimeIndex, its frequency + must match the frequency of the data used to fit the model. Otherwise, dates will be based on the number + of forecast steps and the data. + + use_scenario_index: bool, default False + If True, the index of the scenario data will be used to determine the forecast period. In this case, + the start, end, and periods arguments will be ignored. If True, the scenario data must be a DataFrame, + otherwise an error will be raised. + filter_output : str, default="smoothed" The type of Kalman Filter output used to initialize the forecasts. The 0th timestep of the forecast will be sampled from x[0] ~ N(filter_output_mean[start], filter_output_covariance[start]). Default is "smoothed", @@ -1550,6 +1943,9 @@ def forecast( random_seed : int, RandomState or Generator, optional Seed for the random number generator. + verbose: bool, default=True + Whether to print diagnostic information about forecasting. + kwargs: Additional keyword arguments are passed to pymc.sample_posterior_predictive @@ -1566,51 +1962,56 @@ def forecast( the latent state trajectories: `y[t] = Z @ x[t] + nu[t]`, where `nu ~ N(0, H)`. """ + filter_time_dim = TIME_DIM + _validate_filter_arg(filter_output) - if periods is None and end is None: - raise ValueError("Must specify one of either periods or end") - if periods is not None and end is not None: - raise ValueError("Must specify exactly one of either periods or end") - if self._needs_exog_data: - raise ValueError( - "Scenario-based forcasting with exogenous variables not currently supported" + time_index = self._get_fit_time_index() + + if start is None and verbose: + _log.warning( + "No start date provided. Using the last date in the data index. To silence this warning, " + "explicitly pass a start date or set verbose = False" ) + start = time_index[-1] - temp_coords = self._fit_coords.copy() + if self._needs_exog_data and not isinstance(scenario, dict): + if len(self.data_names) > 1: + raise ValueError( + "Model needs more than one exogenous data to do forecasting. In this case, you must " + "pass a dictionary of scenario data." + ) + [data_name] = self.data_names + scenario = {data_name: scenario} + + scenario: dict = self._validate_scenario_data(scenario, verbose=verbose) + + self._validate_forecast_args( + time_index=time_index, + start=start, + end=end, + periods=periods, + scenario=scenario, + use_scenario_index=use_scenario_index, + verbose=verbose, + ) - filter_time_dim = TIME_DIM + t0, forecast_index = self._build_forecast_index( + time_index=time_index, + start=start, + end=end, + periods=periods, + scenario=scenario, + use_scenario_index=use_scenario_index, + ) + scenario = self._finalize_scenario_initialization(scenario, forecast_index) + temp_coords = self._fit_coords.copy() dims = None if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]): dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] - time_index = temp_coords[filter_time_dim] - - if start not in time_index: - raise ValueError("Start date is not in the provided data") - - is_datetime = isinstance(time_index[0], pd.Timestamp) - - forecast_index = None - - if is_datetime: - time_index = pd.DatetimeIndex(time_index) - freq = time_index.inferred_freq - - if end is not None: - forecast_index = pd.date_range(start, end=end, freq=freq) - if periods is not None: - forecast_index = pd.date_range(start, periods=periods, freq=freq) - t0 = forecast_index[0] - - else: - if end is not None: - forecast_index = np.arange(start, end, dtype="int") - if periods is not None: - forecast_index = np.arange(start, start + periods, dtype="int") - t0 = forecast_index[0] - t0_idx = np.flatnonzero(time_index == t0)[0] + temp_coords["data_time"] = time_index temp_coords[TIME_DIM] = forecast_index @@ -1620,22 +2021,11 @@ def forecast( cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM] with pm.Model(coords=temp_coords) as forecast_model: - ( - [ - x0, - P0, - c, - d, - T, - Z, - R, - H, - Q, - ], - grouped_outputs, - ) = self._kalman_filter_outputs_from_dummy_graph(data_dims=["data_time", OBS_STATE_DIM]) - group_idx = FILTER_OUTPUT_TYPES.index(filter_output) + (_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph( + data_dims=["data_time", OBS_STATE_DIM], + ) + group_idx = FILTER_OUTPUT_TYPES.index(filter_output) mu, cov = grouped_outputs[group_idx] x0 = pm.Deterministic( @@ -1645,22 +2035,28 @@ def forecast( "P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None ) + if scenario is not None: + sub_dict = { + forecast_model[data_name]: pt.as_tensor_variable( + scenario.get(data_name), name=data_name + ) + for data_name in self.data_names + } + + matrices = graph_replace(matrices, replace=sub_dict, strict=True) + [setattr(matrix, "name", name) for name, matrix in zip(MATRIX_NAMES[2:], matrices)] + _ = LinearGaussianStateSpace( "forecast", x0, P0, - c, - d, - T, - Z, - R, - H, - Q, - steps=len(forecast_index[:-1]), + *matrices, + steps=len(forecast_index), dims=dims, mode=self._fit_mode, sequence_names=self.kalman_filter.seq_names, k_endog=self.k_endog, + append_x0=False, ) forecast_model.rvs_to_initial_values = { @@ -1789,16 +2185,16 @@ def impulse_response_function( if use_posterior_cov: Q = post_Q if orthogonalize_shocks: - Q = pt.linalg.cholesky(Q) + Q = pt.linalg.cholesky(Q) / pt.diag(Q) elif shock_cov is not None: Q = pt.as_tensor_variable(shock_cov) if orthogonalize_shocks: - Q = pt.linalg.cholesky(Q) + Q = pt.linalg.cholesky(Q) / pt.diag(Q) if shock_trajectory is None: shock_trajectory = pt.zeros((n_steps, self.k_posdef)) if Q is not None: - init_shock = pm.MvNormal("initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM]) + init_shock = MvNormalSVD("initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM]) else: init_shock = pm.Deterministic( "initial_shock", diff --git a/pymc_experimental/statespace/filters/distributions.py b/pymc_experimental/statespace/filters/distributions.py index 1b4895081..d3b70c847 100644 --- a/pymc_experimental/statespace/filters/distributions.py +++ b/pymc_experimental/statespace/filters/distributions.py @@ -111,6 +111,7 @@ def __new__( steps=None, mode=None, sequence_names=None, + append_x0=True, **kwargs, ): # Ignore dims in support shape because they are just passed along to the "observed" and "latent" distributions @@ -138,12 +139,27 @@ def __new__( steps=steps, mode=mode, sequence_names=sequence_names, + append_x0=append_x0, **kwargs, ) @classmethod def dist( - cls, a0, P0, c, d, T, Z, R, H, Q, steps=None, mode=None, sequence_names=None, **kwargs + cls, + a0, + P0, + c, + d, + T, + Z, + R, + H, + Q, + steps=None, + mode=None, + sequence_names=None, + append_x0=True, + **kwargs, ): steps = get_support_shape_1d( support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=0 @@ -155,11 +171,31 @@ def dist( steps = pt.as_tensor_variable(intX(steps), ndim=0) return super().dist( - [a0, P0, c, d, T, Z, R, H, Q, steps], mode=mode, sequence_names=sequence_names, **kwargs + [a0, P0, c, d, T, Z, R, H, Q, steps], + mode=mode, + sequence_names=sequence_names, + append_x0=append_x0, + **kwargs, ) @classmethod - def rv_op(cls, a0, P0, c, d, T, Z, R, H, Q, steps, size=None, mode=None, sequence_names=None): + def rv_op( + cls, + a0, + P0, + c, + d, + T, + Z, + R, + H, + Q, + steps, + size=None, + mode=None, + sequence_names=None, + append_x0=True, + ): if sequence_names is None: sequence_names = [] @@ -239,8 +275,12 @@ def step_fn(*args): strict=True, ) - statespace_ = pt.concatenate([init_dist_[None], statespace], axis=0) - statespace_ = pt.specify_shape(statespace_, (steps + 1, None)) + if append_x0: + statespace_ = pt.concatenate([init_dist_[None], statespace], axis=0) + statespace_ = pt.specify_shape(statespace_, (steps + 1, None)) + else: + statespace_ = statespace + statespace_ = pt.specify_shape(statespace_, (steps, None)) (ss_rng,) = tuple(updates.values()) linear_gaussian_ss_op = LinearGaussianStateSpaceRV( @@ -276,6 +316,7 @@ def __new__( k_endog=None, sequence_names=None, mode=None, + append_x0=True, **kwargs, ): dims = kwargs.pop("dims", None) @@ -304,9 +345,10 @@ def __new__( steps=steps, mode=mode, sequence_names=sequence_names, + append_x0=append_x0, **kwargs, ) - latent_obs_combined = pt.specify_shape(latent_obs_combined, (steps + 1, None)) + latent_obs_combined = pt.specify_shape(latent_obs_combined, (steps + int(append_x0), None)) if k_endog is None: k_endog = cls._get_k_endog(H) latent_slice = slice(None, -k_endog) diff --git a/pymc_experimental/statespace/filters/kalman_filter.py b/pymc_experimental/statespace/filters/kalman_filter.py index c2bfd2f39..0d955d029 100644 --- a/pymc_experimental/statespace/filters/kalman_filter.py +++ b/pymc_experimental/statespace/filters/kalman_filter.py @@ -642,15 +642,15 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag): PZT = P.dot(Z.T) F = Z.dot(PZT) + stabilize(H, self.cov_jitter) - F_inv = pt.linalg.solve(F, self.eye_endog, assume_a="pos", check_finite=False) - - K = PZT.dot(F_inv) + K = pt.linalg.solve(F.T, PZT.T, assume_a="pos", check_finite=False).T I_KZ = self.eye_states - K.dot(Z) a_filtered = a + K.dot(v) P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H) - inner_term = matrix_dot(v.T, F_inv, v) + F_inv_v = pt.linalg.solve(F, v, assume_a="pos", check_finite=False) + inner_term = v.T @ F_inv_v + F_logdet = pt.log(pt.linalg.det(F)) ll = pt.switch( diff --git a/pymc_experimental/statespace/models/VARMAX.py b/pymc_experimental/statespace/models/VARMAX.py index 3dab978a0..16bc81097 100644 --- a/pymc_experimental/statespace/models/VARMAX.py +++ b/pymc_experimental/statespace/models/VARMAX.py @@ -220,11 +220,11 @@ def param_info(self) -> dict[str, dict[str, Any]]: "constraints": "Positive Semi-definite", }, "ar_params": { - "shape": (self.k_states, self.p, self.k_states), + "shape": (self.k_endog, self.p, self.k_endog), "constraints": "None", }, "ma_params": { - "shape": (self.k_states, self.q, self.k_states), + "shape": (self.k_endog, self.q, self.k_endog), "constraints": "None", }, } diff --git a/tests/statespace/test_statespace.py b/tests/statespace/test_statespace.py index e0062933a..83d0babca 100644 --- a/tests/statespace/test_statespace.py +++ b/tests/statespace/test_statespace.py @@ -1,4 +1,7 @@ +from functools import partial + import numpy as np +import pandas as pd import pymc as pm import pytensor import pytensor.tensor as pt @@ -14,7 +17,7 @@ MATRIX_NAMES, SMOOTHER_OUTPUT_NAMES, ) -from tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import +from tests.statespace.utilities.shared_fixtures import ( rng, ) from tests.statespace.utilities.test_helpers import ( @@ -28,18 +31,26 @@ ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES -def make_statespace_mod(k_endog, k_states, k_posdef, filter_type, verbose=False): +def make_statespace_mod(k_endog, k_states, k_posdef, filter_type, verbose=False, data_info=None): class StateSpace(PyMCStateSpace): def make_symbolic_graph(self): pass - return StateSpace( + @property + def data_info(self): + return data_info + + ss = StateSpace( k_states=k_states, k_endog=k_endog, k_posdef=k_posdef, filter_type=filter_type, verbose=verbose, ) + ss._needs_exog_data = data_info is not None + ss._exog_names = list(data_info.keys()) if data_info is not None else [] + + return ss @pytest.fixture(scope="session") @@ -103,6 +114,18 @@ def pymc_mod(ss_mod): return pymc_mod +@pytest.fixture(scope="session") +def ss_mod_no_exog(rng): + ll = st.LevelTrendComponent(order=2, innovations_order=1) + return ll.build() + + +@pytest.fixture(scope="session") +def ss_mod_no_exog_dt(rng): + ll = st.LevelTrendComponent(order=2, innovations_order=1) + return ll.build() + + @pytest.fixture(scope="session") def exog_ss_mod(rng): ll = st.LevelTrendComponent() @@ -132,6 +155,42 @@ def exog_pymc_mod(exog_ss_mod, rng): return m +@pytest.fixture(scope="session") +def pymc_mod_no_exog(ss_mod_no_exog, rng): + y = pd.DataFrame(rng.normal(size=(100, 1)).astype(floatX), columns=["y"]) + + with pm.Model(coords=ss_mod_no_exog.coords) as m: + initial_trend = pm.Normal("initial_trend", dims=["trend_state"]) + P0_sigma = pm.Exponential("P0_sigma", 1) + P0 = pm.Deterministic( + "P0", pt.eye(ss_mod_no_exog.k_states) * P0_sigma, dims=["state", "state_aux"] + ) + sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"]) + ss_mod_no_exog.build_statespace_graph(y) + + return m + + +@pytest.fixture(scope="session") +def pymc_mod_no_exog_dt(ss_mod_no_exog_dt, rng): + y = pd.DataFrame( + rng.normal(size=(100, 1)).astype(floatX), + columns=["y"], + index=pd.date_range("2020-01-01", periods=100, freq="D"), + ) + + with pm.Model(coords=ss_mod_no_exog_dt.coords) as m: + initial_trend = pm.Normal("initial_trend", dims=["trend_state"]) + P0_sigma = pm.Exponential("P0_sigma", 1) + P0 = pm.Deterministic( + "P0", pt.eye(ss_mod_no_exog_dt.k_states) * P0_sigma, dims=["state", "state_aux"] + ) + sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"]) + ss_mod_no_exog_dt.build_statespace_graph(y) + + return m + + @pytest.fixture(scope="session") def idata(pymc_mod, rng): with pymc_mod: @@ -151,6 +210,24 @@ def idata_exog(exog_pymc_mod, rng): return idata +@pytest.fixture(scope="session") +def idata_no_exog(pymc_mod_no_exog, rng): + with pymc_mod_no_exog: + idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng) + idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng) + idata.extend(idata_prior) + return idata + + +@pytest.fixture(scope="session") +def idata_no_exog_dt(pymc_mod_no_exog_dt, rng): + with pymc_mod_no_exog_dt: + idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng) + idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng) + idata.extend(idata_prior) + return idata + + def test_invalid_filter_name_raises(): msg = "The following are valid filter types: " + ", ".join(list(FILTER_FACTORY.keys())) with pytest.raises(NotImplementedError, match=msg): @@ -278,25 +355,470 @@ def test_sampling_methods(group, kind, ss_mod, idata, rng): assert not np.any(np.isnan(test_idata[f"{group}_{output}"].values)) +def _make_time_idx(mod, use_datetime_index=True): + if use_datetime_index: + mod._fit_coords["time"] = nile.index + time_idx = nile.index + else: + mod._fit_coords["time"] = nile.reset_index().index + time_idx = pd.RangeIndex(start=0, stop=nile.shape[0], step=1) + + return time_idx + + +@pytest.mark.parametrize("use_datetime_index", [True, False]) +def test_bad_forecast_arguments(use_datetime_index, caplog): + ss_mod = make_statespace_mod( + k_endog=1, k_posdef=1, k_states=2, filter_type="standard", verbose=False + ) + + # Not-fit model raises + ss_mod._fit_coords = dict() + with pytest.raises(ValueError, match="Has this model been fit?"): + ss_mod._get_fit_time_index() + + time_idx = _make_time_idx(ss_mod, use_datetime_index) + + # Start value not in time index + match = ( + "Datetime start must be in the data index used to fit the model" + if use_datetime_index + else "Integer start must be within the range of the data index used to fit the model." + ) + with pytest.raises(ValueError, match=match): + start = time_idx.shift(10)[-1] if use_datetime_index else time_idx[-1] + 11 + ss_mod._validate_forecast_args(time_index=time_idx, start=start, periods=10) + + # End value cannot be inferred + with pytest.raises(ValueError, match="Must specify one of either periods or end"): + start = time_idx[-1] + ss_mod._validate_forecast_args(time_index=time_idx, start=start) + + # Unnecessary args warn on verbose + start = time_idx[-1] + forecast_idx = pd.date_range(start=start, periods=10, freq="YS-JAN") + scenario = pd.DataFrame(0, index=forecast_idx, columns=[0, 1, 2]) + + ss_mod._validate_forecast_args( + time_index=time_idx, start=start, periods=10, scenario=scenario, use_scenario_index=True + ) + last_message = caplog.messages[-1] + assert "start, end, and periods arguments are ignored" in last_message + + # Verbose=False silences warning + ss_mod._validate_forecast_args( + time_index=time_idx, + start=start, + periods=10, + scenario=scenario, + use_scenario_index=True, + verbose=False, + ) + assert len(caplog.messages) == 1 + + +@pytest.mark.parametrize("use_datetime_index", [True, False]) +def test_forecast_index(use_datetime_index): + ss_mod = make_statespace_mod( + k_endog=1, k_posdef=1, k_states=2, filter_type="standard", verbose=False + ) + ss_mod._fit_coords = dict() + time_idx = _make_time_idx(ss_mod, use_datetime_index) + + # From start and end + start = time_idx[-1] + delta = pd.DateOffset(years=10) if use_datetime_index else 11 + end = start + delta + + x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, end=end) + assert start not in forecast_idx + assert x0_index == start + assert forecast_idx.shape == (10,) + + # From start and periods + start = time_idx[-1] + periods = 10 + + x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, periods=periods) + assert start not in forecast_idx + assert x0_index == start + assert forecast_idx.shape == (10,) + + # From integer start + start = 10 + x0_index, forecast_idx = ss_mod._build_forecast_index(time_idx, start=start, periods=periods) + delta = forecast_idx.freq if use_datetime_index else 1 + + assert x0_index == time_idx[start] + assert forecast_idx.shape == (10,) + assert (forecast_idx == time_idx[start + 1 : start + periods + 1]).all() + + # From scenario index + scenario = pd.DataFrame(0, index=forecast_idx, columns=[0, 1, 2]) + new_start, forecast_idx = ss_mod._build_forecast_index( + time_index=time_idx, scenario=scenario, use_scenario_index=True + ) + assert x0_index not in forecast_idx + assert x0_index == (forecast_idx[0] - delta) + assert forecast_idx.shape == (10,) + assert forecast_idx.equals(scenario.index) + + # From dictionary of scenarios + scenario = {"a": pd.DataFrame(0, index=forecast_idx, columns=[0, 1, 2])} + x0_index, forecast_idx = ss_mod._build_forecast_index( + time_index=time_idx, scenario=scenario, use_scenario_index=True + ) + assert x0_index == (forecast_idx[0] - delta) + assert forecast_idx.shape == (10,) + assert forecast_idx.equals(scenario["a"].index) + + +@pytest.mark.parametrize( + "data_type", + [pd.Series, pd.DataFrame, np.array, list, tuple], + ids=["series", "dataframe", "array", "list", "tuple"], +) +def test_validate_scenario(data_type): + if data_type is pd.DataFrame: + # Ensure dataframes have the correct column name + data_type = partial(pd.DataFrame, columns=["column_1"]) + + # One data case + data_info = {"a": {"shape": (None, 1), "dims": ("time", "features_a")}} + ss_mod = make_statespace_mod( + k_endog=1, + k_posdef=1, + k_states=2, + filter_type="standard", + verbose=False, + data_info=data_info, + ) + ss_mod._fit_coords = dict(features_a=["column_1"]) + + scenario = data_type(np.zeros(10)) + scenario = ss_mod._validate_scenario_data(scenario) + + # Lists and tuples are cast to 2d arrays + if data_type in [tuple, list]: + assert isinstance(scenario, np.ndarray) + assert scenario.shape == (10, 1) + + # A one-item dictionary should also work + scenario = {"a": scenario} + ss_mod._validate_scenario_data(scenario) + + # Now data has to be a dictionary + data_info.update({"b": {"shape": (None, 1), "dims": ("time", "features_b")}}) + ss_mod = make_statespace_mod( + k_endog=1, + k_posdef=1, + k_states=2, + filter_type="standard", + verbose=False, + data_info=data_info, + ) + ss_mod._fit_coords = dict(features_a=["column_1"], features_b=["column_1"]) + + scenario = {"a": data_type(np.zeros(10)), "b": data_type(np.zeros(10))} + ss_mod._validate_scenario_data(scenario) + + # Mixed data types + data_info.update({"a": {"shape": (None, 10), "dims": ("time", "features_a")}}) + ss_mod = make_statespace_mod( + k_endog=1, + k_posdef=1, + k_states=2, + filter_type="standard", + verbose=False, + data_info=data_info, + ) + ss_mod._fit_coords = dict( + features_a=[f"column_{i}" for i in range(10)], features_b=["column_1"] + ) + + scenario = { + "a": pd.DataFrame(np.zeros((10, 10)), columns=ss_mod._fit_coords["features_a"]), + "b": data_type(np.arange(10)), + } + + ss_mod._validate_scenario_data(scenario) + + +@pytest.mark.parametrize( + "data_type", + [pd.Series, pd.DataFrame, np.array, list, tuple], + ids=["series", "dataframe", "array", "list", "tuple"], +) +@pytest.mark.parametrize("use_datetime_index", [True, False]) +def test_finalize_scenario_single(data_type, use_datetime_index): + if data_type is pd.DataFrame: + # Ensure dataframes have the correct column name + data_type = partial(pd.DataFrame, columns=["column_1"]) + + data_info = {"a": {"shape": (None, 1), "dims": ("time", "features_a")}} + ss_mod = make_statespace_mod( + k_endog=1, + k_posdef=1, + k_states=2, + filter_type="standard", + verbose=False, + data_info=data_info, + ) + ss_mod._fit_coords = dict(features_a=["column_1"]) + + time_idx = _make_time_idx(ss_mod, use_datetime_index) + + scenario = data_type(np.zeros((10,))) + + scenario = ss_mod._validate_scenario_data(scenario) + t0, forecast_idx = ss_mod._build_forecast_index(time_idx, start=time_idx[-1], periods=10) + scenario = ss_mod._finalize_scenario_initialization(scenario, forecast_index=forecast_idx) + + assert isinstance(scenario, pd.DataFrame) + assert scenario.index.equals(forecast_idx) + assert scenario.columns == ["column_1"] + + +@pytest.mark.parametrize( + "data_type", + [pd.Series, pd.DataFrame, np.array, list, tuple], + ids=["series", "dataframe", "array", "list", "tuple"], +) +@pytest.mark.parametrize("use_datetime_index", [True, False]) +@pytest.mark.parametrize("use_scenario_index", [True, False]) +def test_finalize_secenario_dict(data_type, use_datetime_index, use_scenario_index): + data_info = { + "a": {"shape": (None, 1), "dims": ("time", "features_a")}, + "b": {"shape": (None, 2), "dims": ("time", "features_b")}, + } + ss_mod = make_statespace_mod( + k_endog=1, + k_posdef=1, + k_states=2, + filter_type="standard", + verbose=False, + data_info=data_info, + ) + ss_mod._fit_coords = dict(features_a=["column_1"], features_b=["column_1", "column_2"]) + time_idx = _make_time_idx(ss_mod, use_datetime_index) + + initial_index = ( + pd.date_range(start=time_idx[-1], periods=10, freq=time_idx.freq) + if use_datetime_index + else pd.RangeIndex(time_idx[-1], time_idx[-1] + 10, 1) + ) + + if data_type is pd.DataFrame: + # Ensure dataframes have the correct column name + data_type = partial(pd.DataFrame, columns=["column_1"], index=initial_index) + elif data_type is pd.Series: + data_type = partial(pd.Series, index=initial_index) + + scenario = { + "a": data_type(np.zeros((10,))), + "b": pd.DataFrame( + np.zeros((10, 2)), columns=ss_mod._fit_coords["features_b"], index=initial_index + ), + } + + scenario = ss_mod._validate_scenario_data(scenario) + + if use_scenario_index and data_type not in [np.array, list, tuple]: + t0, forecast_idx = ss_mod._build_forecast_index( + time_idx, scenario=scenario, periods=10, use_scenario_index=True + ) + elif use_scenario_index and data_type in [np.array, list, tuple]: + t0, forecast_idx = ss_mod._build_forecast_index( + time_idx, scenario=scenario, start=-1, periods=10, use_scenario_index=True + ) + else: + t0, forecast_idx = ss_mod._build_forecast_index(time_idx, start=time_idx[-1], periods=10) + + scenario = ss_mod._finalize_scenario_initialization(scenario, forecast_index=forecast_idx) + + assert list(scenario.keys()) == ["a", "b"] + assert all(isinstance(value, pd.DataFrame) for value in scenario.values()) + assert all(value.index.equals(forecast_idx) for value in scenario.values()) + + +def test_invalid_scenarios(): + data_info = {"a": {"shape": (None, 1), "dims": ("time", "features_a")}} + ss_mod = make_statespace_mod( + k_endog=1, + k_posdef=1, + k_states=2, + filter_type="standard", + verbose=False, + data_info=data_info, + ) + ss_mod._fit_coords = dict(features_a=["column_1", "column_2"]) + + # Omitting the data raises + with pytest.raises( + ValueError, match="This model was fit using exogenous data. Forecasting cannot be performed" + ): + ss_mod._validate_scenario_data(None) + + # Giving a list, tuple, or Series when a matrix of data is expected should always raise + with pytest.raises( + ValueError, + match="Scenario data for variable 'a' has the wrong number of columns. " + "Expected 2, got 1", + ): + for data_type in [list, tuple, pd.Series]: + ss_mod._validate_scenario_data(data_type(np.zeros(10))) + ss_mod._validate_scenario_data({"a": data_type(np.zeros(10))}) + + # Providing irrevelant data raises + with pytest.raises( + ValueError, + match="Scenario data provided for variable 'jk lol', which is not an exogenous " "variable", + ): + ss_mod._validate_scenario_data({"jk lol": np.zeros(10)}) + + # Incorrect 2nd dimension of a non-dataframe + with pytest.raises( + ValueError, + match="Scenario data for variable 'a' has the wrong number of columns. Expected " + "2, got 1", + ): + scenario = np.zeros(10).tolist() + ss_mod._validate_scenario_data(scenario) + ss_mod._validate_scenario_data(tuple(scenario)) + + scenario = {"a": np.zeros(10).tolist()} + ss_mod._validate_scenario_data(scenario) + ss_mod._validate_scenario_data({"a": tuple(scenario["a"])}) + + # If a data frame is provided, it needs to have all columns + with pytest.raises( + ValueError, match="Scenario data for variable 'a' is missing the following column: column_2" + ): + scenario = pd.DataFrame(np.zeros((10, 1)), columns=["column_1"]) + ss_mod._validate_scenario_data(scenario) + + # Extra columns also raises + with pytest.raises( + ValueError, + match="Scenario data for variable 'a' contains the following extra columns " + "that are not used by the model: column_3, column_4", + ): + scenario = pd.DataFrame( + np.zeros((10, 4)), columns=["column_1", "column_2", "column_3", "column_4"] + ) + ss_mod._validate_scenario_data(scenario) + + # Wrong number of time steps raises + data_info = { + "a": {"shape": (None, 1), "dims": ("time", "features_a")}, + "b": {"shape": (None, 1), "dims": ("time", "features_b")}, + } + ss_mod = make_statespace_mod( + k_endog=1, + k_posdef=1, + k_states=2, + filter_type="standard", + verbose=False, + data_info=data_info, + ) + ss_mod._fit_coords = dict( + features_a=["column_1", "column_2"], features_b=["column_1", "column_2"] + ) + + with pytest.raises( + ValueError, match="Scenario data must have the same number of time steps for all variables" + ): + scenario = { + "a": pd.DataFrame(np.zeros((10, 2)), columns=ss_mod._fit_coords["features_a"]), + "b": pd.DataFrame(np.zeros((11, 2)), columns=ss_mod._fit_coords["features_b"]), + } + ss_mod._validate_scenario_data(scenario) + + +@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.") @pytest.mark.parametrize("filter_output", ["predicted", "filtered", "smoothed"]) -def test_forecast(filter_output, ss_mod, idata, rng): - time_idx = idata.posterior.coords["time"].values - forecast_idata = ss_mod.forecast( - idata, start=time_idx[-1], periods=10, filter_output=filter_output, random_seed=rng +@pytest.mark.parametrize( + "mod_name, idata_name, start, end, periods", + [ + ("ss_mod_no_exog", "idata_no_exog", None, None, 10), + ("ss_mod_no_exog", "idata_no_exog", -1, None, 10), + ("ss_mod_no_exog", "idata_no_exog", 10, None, 10), + ("ss_mod_no_exog", "idata_no_exog", 10, 21, None), + ("ss_mod_no_exog_dt", "idata_no_exog_dt", None, None, 10), + ("ss_mod_no_exog_dt", "idata_no_exog_dt", -1, None, 10), + ("ss_mod_no_exog_dt", "idata_no_exog_dt", 10, None, 10), + ("ss_mod_no_exog_dt", "idata_no_exog_dt", 10, "2020-01-21", None), + ("ss_mod_no_exog_dt", "idata_no_exog_dt", "2020-03-01", "2020-03-11", None), + ("ss_mod_no_exog_dt", "idata_no_exog_dt", "2020-03-01", None, 10), + ], + ids=[ + "range_default", + "range_negative", + "range_int", + "range_end", + "datetime_default", + "datetime_negative", + "datetime_int", + "datetime_int_end", + "datetime_datetime_end", + "datetime_datetime", + ], +) +def test_forecast(filter_output, mod_name, idata_name, start, end, periods, rng, request): + mod = request.getfixturevalue(mod_name) + idata = request.getfixturevalue(idata_name) + time_idx = mod._get_fit_time_index() + is_datetime = isinstance(time_idx, pd.DatetimeIndex) + + if isinstance(start, str): + t0 = pd.Timestamp(start) + elif isinstance(start, int): + t0 = time_idx[start] + else: + t0 = time_idx[-1] + + delta = time_idx.freq if is_datetime else 1 + + forecast_idata = mod.forecast( + idata, start=start, end=end, periods=periods, filter_output=filter_output, random_seed=rng ) - assert forecast_idata.coords["time"].values.shape == (10,) + forecast_idx = forecast_idata.coords["time"].values + forecast_idx = pd.DatetimeIndex(forecast_idx) if is_datetime else pd.Index(forecast_idx) + + assert forecast_idx.shape == (10,) assert forecast_idata.forecast_latent.dims == ("chain", "draw", "time", "state") assert forecast_idata.forecast_observed.dims == ("chain", "draw", "time", "observed_state") assert not np.any(np.isnan(forecast_idata.forecast_latent.values)) assert not np.any(np.isnan(forecast_idata.forecast_observed.values)) + assert forecast_idx[0] == (t0 + delta) + @pytest.mark.filterwarnings("ignore:No time index found on the supplied data.") -def test_forecast_fails_if_exog_needed(exog_ss_mod, idata_exog): - time_idx = idata_exog.observed_data.coords["time"].values - with pytest.xfail("Scenario-based forcasting with exogenous variables not currently supported"): - forecast_idata = exog_ss_mod.forecast( - idata_exog, start=time_idx[-1], periods=10, random_seed=rng - ) +@pytest.mark.parametrize("start", [None, -1, 10]) +def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start): + scenario = pd.DataFrame(np.zeros((10, 3)), columns=["a", "b", "c"]) + scenario.iloc[5, 0] = 1e9 + + forecast_idata = exog_ss_mod.forecast( + idata_exog, start=start, periods=10, random_seed=rng, scenario=scenario + ) + + components = exog_ss_mod.extract_components_from_idata(forecast_idata) + level = components.forecast_latent.sel(state="LevelTrend[level]") + betas = components.forecast_latent.sel(state=["exog[a]", "exog[b]", "exog[c]"]) + + scenario.index.name = "time" + scenario_xr = ( + scenario.unstack() + .to_xarray() + .rename({"level_0": "state"}) + .assign_coords(state=["exog[a]", "exog[b]", "exog[c]"]) + ) + + regression_effect = forecast_idata.forecast_observed.isel(observed_state=0) - level + regression_effect_expected = (betas * scenario_xr).sum(dim=["state"]) + + assert_allclose(regression_effect, regression_effect_expected)