Skip to content

Commit 340e403

Browse files
authored
Allow for specification of 'var_names' in 'mock_sample' (#7906)
1 parent 9f653a6 commit 340e403

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

pymc/testing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,10 +1011,12 @@ def mock_pymc_sample():
10111011
model = kwargs.get("model", None)
10121012
draws = kwargs.get("draws", draws)
10131013
n_chains = kwargs.get("chains", 1)
1014+
var_names = kwargs.get("var_names", None)
10141015
idata: InferenceData = pm.sample_prior_predictive(
10151016
model=model,
10161017
random_seed=random_seed,
10171018
draws=draws,
1019+
var_names=var_names,
10181020
)
10191021

10201022
idata.add_groups(

tests/test_testing.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,18 @@ def test_fixture(mock_pymc_sample, dummy_model) -> None:
8181
posterior = idata.posterior
8282
assert posterior.sizes == {"chain": 1, "draw": 10}
8383
assert (posterior["half_flat"] >= 0).all()
84+
85+
86+
def test_mock_pymc_sample_var_names(mock_pymc_sample):
87+
with pm.Model() as model:
88+
pm.Flat("flat")
89+
pm.HalfFlat("half_flat")
90+
pm.Flat("other_flat")
91+
92+
with model:
93+
idata = pm.sample(var_names=["flat", "half_flat"])
94+
assert set(idata.posterior.data_vars) == {"flat", "half_flat"}
95+
96+
with model:
97+
idata = pm.sample()
98+
assert set(idata.posterior.data_vars) == {"flat", "half_flat", "other_flat"}

0 commit comments

Comments
 (0)