From 76c3ff0b8671184a8b12288ad083d3ca47377fd1 Mon Sep 17 00:00:00 2001 From: Goose Date: Tue, 4 Mar 2025 11:15:30 +1000 Subject: [PATCH] add return type overload for sample_posterior_predictive --- pymc/sampling/forward.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 91f1c6f69..b1f9c3989 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -493,6 +493,36 @@ def sample_prior_predictive( return pm.to_inference_data(prior=prior, **ikwargs) +@overload +def sample_posterior_predictive( + trace, + model: Model | None = None, + var_names: list[str] | None = None, + sample_dims: list[str] | None = None, + random_seed: RandomState = None, + progressbar: bool = True, + progressbar_theme: Theme | None = default_progress_theme, + return_inferencedata: Literal[True] = True, + extend_inferencedata: bool = False, + predictions: bool = False, + idata_kwargs: dict | None = None, + compile_kwargs: dict | None = None, +) -> InferenceData: ... +@overload +def sample_posterior_predictive( + trace, + model: Model | None = None, + var_names: list[str] | None = None, + sample_dims: list[str] | None = None, + random_seed: RandomState = None, + progressbar: bool = True, + progressbar_theme: Theme | None = default_progress_theme, + return_inferencedata: Literal[False] = False, + extend_inferencedata: bool = False, + predictions: bool = False, + idata_kwargs: dict | None = None, + compile_kwargs: dict | None = None, +) -> dict[str, np.ndarray]: ... def sample_posterior_predictive( trace, model: Model | None = None,