From b5127178219619b1a9933d21f19a063b2345af66 Mon Sep 17 00:00:00 2001 From: Goose Date: Mon, 3 Mar 2025 12:42:31 +1000 Subject: [PATCH 1/5] correct sampled variables in sample_prior_predictive log call & add return type overloads --- pymc/sampling/forward.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 5f3c983b58..0f47325fe7 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -22,6 +22,7 @@ Any, TypeAlias, cast, + overload, ) import numpy as np @@ -360,6 +361,28 @@ def observed_dependent_deterministics(model: Model, extra_observeds=None): ] +@overload +def sample_prior_predictive( + draws: int = 500, + model: Model | None = None, + var_names: Iterable[str] | None = None, + random_seed: RandomState = None, + return_inferencedata: bool = True, + idata_kwargs: dict | None = None, + compile_kwargs: dict | None = None, + samples: int | None = None, +) -> InferenceData: ... +@overload +def sample_prior_predictive( + draws: int = 500, + model: Model | None = None, + var_names: Iterable[str] | None = None, + random_seed: RandomState = None, + return_inferencedata: bool = False, + idata_kwargs: dict | None = None, + compile_kwargs: dict | None = None, + samples: int | None = None, +) -> dict[str, np.ndarray]: ... def sample_prior_predictive( draws: int = 500, model: Model | None = None, @@ -449,7 +472,7 @@ def sample_prior_predictive( ) # All model variables have a name, but mypy does not know this - _log.info(f"Sampling: {sorted(volatile_basic_rvs, key=lambda var: var.name)}") # type: ignore[arg-type, return-value] + _log.info(f"Sampling: {sorted(vars_to_sample, key=lambda var: var.name)}") # type: ignore[arg-type, return-value] values = zip(*(sampler_fn() for i in range(draws))) data = {k: np.stack(v) for k, v in zip(names, values)} From ec0597fc89cdf0e22e0b2e908e963f311374c8df Mon Sep 17 00:00:00 2001 From: Goose Date: Mon, 3 Mar 2025 13:08:13 +1000 Subject: [PATCH 2/5] correct overload type hinting --- pymc/sampling/forward.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 0f47325fe7..e44ffd6be5 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -20,6 +20,7 @@ from collections.abc import Callable, Iterable, Sequence from typing import ( Any, + Literal, TypeAlias, cast, overload, @@ -367,7 +368,7 @@ def sample_prior_predictive( model: Model | None = None, var_names: Iterable[str] | None = None, random_seed: RandomState = None, - return_inferencedata: bool = True, + return_inferencedata: Literal[True] = True, idata_kwargs: dict | None = None, compile_kwargs: dict | None = None, samples: int | None = None, @@ -378,7 +379,7 @@ def sample_prior_predictive( model: Model | None = None, var_names: Iterable[str] | None = None, random_seed: RandomState = None, - return_inferencedata: bool = False, + return_inferencedata: Literal[False] = False, idata_kwargs: dict | None = None, compile_kwargs: dict | None = None, samples: int | None = None, From 2d2c46c5e19dc936999b34a35d932bdf530bde03 Mon Sep 17 00:00:00 2001 From: Goose Date: Mon, 3 Mar 2025 13:57:18 +1000 Subject: [PATCH 3/5] don't list deterministics --- pymc/sampling/forward.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index e44ffd6be5..8c8f6c3d28 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -473,7 +473,9 @@ def sample_prior_predictive( ) # All model variables have a name, but mypy does not know this - _log.info(f"Sampling: {sorted(vars_to_sample, key=lambda var: var.name)}") # type: ignore[arg-type, return-value] + _log.info( + f"Sampling: {sorted(volatile_basic_rvs.intersection(vars_to_sample), key=lambda var: var.name)}" # type: ignore[arg-type, return-value] + ) values = zip(*(sampler_fn() for i in range(draws))) data = {k: np.stack(v) for k, v in zip(names, values)} From a7fe08c2c51219c0606dd7d25ef2a19a2f68a759 Mon Sep 17 00:00:00 2001 From: Goose Date: Mon, 3 Mar 2025 14:02:18 +1000 Subject: [PATCH 4/5] add test for predictive variable only --- tests/sampling/test_forward.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py index 7bbcdc42b1..8200858ff8 100644 --- a/tests/sampling/test_forward.py +++ b/tests/sampling/test_forward.py @@ -862,11 +862,18 @@ def test_logging_sampled_basic_rvs_prior(self, caplog): assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x, z]")] caplog.clear() + # RV only with m: pm.sample_prior_predictive(draws=1, var_names=["x"]) assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x]")] caplog.clear() + # observed only + with m: + pm.sample_prior_predictive(draws=1, var_names=["z"]) + assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [z]")] + caplog.clear() + def test_logging_sampled_basic_rvs_posterior(self, caplog): with pm.Model() as m: x = pm.Normal("x") From 36e861eb9901aeffcb4515099a09a2dfbb301082 Mon Sep 17 00:00:00 2001 From: Goose Date: Mon, 3 Mar 2025 20:51:58 +1000 Subject: [PATCH 5/5] implement feedback --- pymc/sampling/forward.py | 4 +--- tests/sampling/test_forward.py | 7 ++++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 8c8f6c3d28..91f1c6f691 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -473,9 +473,7 @@ def sample_prior_predictive( ) # All model variables have a name, but mypy does not know this - _log.info( - f"Sampling: {sorted(volatile_basic_rvs.intersection(vars_to_sample), key=lambda var: var.name)}" # type: ignore[arg-type, return-value] - ) + _log.info(f"Sampling: {sorted(volatile_basic_rvs, key=lambda var: var.name)}") # type: ignore[arg-type, return-value] values = zip(*(sampler_fn() for i in range(draws))) data = {k: np.stack(v) for k, v in zip(names, values)} diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py index 8200858ff8..d3b41bf667 100644 --- a/tests/sampling/test_forward.py +++ b/tests/sampling/test_forward.py @@ -857,21 +857,22 @@ def test_logging_sampled_basic_rvs_prior(self, caplog): y = pm.Deterministic("y", x + 1) z = pm.Normal("z", y, observed=0) + # all volatile RVs in model with m: pm.sample_prior_predictive(draws=1) assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x, z]")] caplog.clear() - # RV only + # `x` has no dependencies so will be sampled by itself with m: pm.sample_prior_predictive(draws=1, var_names=["x"]) assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x]")] caplog.clear() - # observed only + # `z` depends on `x` with m: pm.sample_prior_predictive(draws=1, var_names=["z"]) - assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [z]")] + assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x, z]")] caplog.clear() def test_logging_sampled_basic_rvs_posterior(self, caplog):