diff --git a/pymc/testing.py b/pymc/testing.py index 886177ef02..86acff4c19 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -1011,10 +1011,12 @@ def mock_pymc_sample(): model = kwargs.get("model", None) draws = kwargs.get("draws", draws) n_chains = kwargs.get("chains", 1) + var_names = kwargs.get("var_names", None) idata: InferenceData = pm.sample_prior_predictive( model=model, random_seed=random_seed, draws=draws, + var_names=var_names, ) idata.add_groups( diff --git a/tests/test_testing.py b/tests/test_testing.py index 105e2f6209..b38eb7e6fe 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -81,3 +81,18 @@ def test_fixture(mock_pymc_sample, dummy_model) -> None: posterior = idata.posterior assert posterior.sizes == {"chain": 1, "draw": 10} assert (posterior["half_flat"] >= 0).all() + + +def test_mock_pymc_sample_var_names(mock_pymc_sample): + with pm.Model() as model: + pm.Flat("flat") + pm.HalfFlat("half_flat") + pm.Flat("other_flat") + + with model: + idata = pm.sample(var_names=["flat", "half_flat"]) + assert set(idata.posterior.data_vars) == {"flat", "half_flat"} + + with model: + idata = pm.sample() + assert set(idata.posterior.data_vars) == {"flat", "half_flat", "other_flat"}