Skip to content

Commit 36da40a

Browse files
committed
use positional arg for draws like in actual sample
1 parent 6fd8f45 commit 36da40a

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

pymc/testing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,7 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None:
986986
raise AssertionError(f"RV found in graph: {rvs}")
987987

988988

989-
def mock_sample(*args, **kwargs):
989+
def mock_sample(draws: int = 10, **kwargs):
990990
"""Mock the pm.sample function by returning prior predictive samples as posterior.
991991
992992
Useful for testing models that use pm.sample without running MCMC sampling.
@@ -1015,7 +1015,7 @@ def mock_pymc_sample():
10151015
"""
10161016
random_seed = kwargs.get("random_seed", None)
10171017
model = kwargs.get("model", None)
1018-
draws = kwargs.get("draws", 10)
1018+
draws = kwargs.get("draws", draws)
10191019
n_chains = kwargs.get("chains", 1)
10201020
idata: InferenceData = pm.sample_prior_predictive(
10211021
model=model,

tests/test_testing.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,24 @@ def test_domain(values, edges, expectation):
3535
Domain(values, edges=edges)
3636

3737

38-
def test_mock_sample() -> None:
38+
@pytest.mark.parametrize(
39+
"args, kwargs, expected_draws",
40+
[
41+
pytest.param((), {}, 10, id="default"),
42+
pytest.param((100,), {}, 100, id="positional-draws"),
43+
pytest.param((), {"draws": 100}, 100, id="keyword-draws"),
44+
],
45+
)
46+
def test_mock_sample(args, kwargs, expected_draws) -> None:
3947
_, model, _ = simple_normal(bounded_prior=True)
4048

41-
idata = mock_sample(model=model)
49+
with model:
50+
idata = mock_sample(*args, **kwargs)
4251

4352
assert "posterior" in idata
4453
assert "observed_data" in idata
4554
assert "prior" not in idata
4655
assert "posterior_predictive" not in idata
4756
assert "sample_stats" not in idata
57+
58+
assert idata.posterior.sizes == {"chain": 1, "draw": expected_draws}

0 commit comments

Comments
 (0)