Skip to content

Commit 34709f2

Browse files
committed
provide the setup and breakdown for pytest fixtures
1 parent 5884a39 commit 34709f2

File tree

2 files changed

+67
-1
lines changed

2 files changed

+67
-1
lines changed

pymc/testing.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,3 +1034,54 @@ def mock_pymc_sample():
10341034
if "prior_predictive" in idata:
10351035
del idata["prior_predictive"]
10361036
return idata
1037+
1038+
1039+
def mock_sample_setup_and_breakdown():
1040+
"""Set up and tear down mocking of PyMC sampling functions for testing.
1041+
1042+
This function is designed to be used with pytest fixtures to temporarily replace
1043+
PyMC's sampling functionality with faster alternatives for testing purposes.
1044+
1045+
Effects during the fixture's active period:
1046+
* Replaces pm.sample with mock_sample, which uses prior predictive sampling
1047+
instead of MCMC
1048+
* Replaces pm.Flat with pm.Normal to avoid issues with unbounded priors
1049+
* Replaces pm.HalfFlat with pm.HalfNormal to avoid issues with semi-bounded priors
1050+
* Automatically restores all original functions after the test completes
1051+
1052+
Examples
1053+
--------
1054+
.. code-block:: python
1055+
1056+
import pytest
1057+
import pymc as pm
1058+
from pymc.testing import mock_sample_setup_and_breakdown
1059+
1060+
# Register as a pytest fixture
1061+
mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_breakdown)
1062+
1063+
1064+
# Use in a test function
1065+
def test_model_inference(mock_pymc_sample):
1066+
with pm.Model() as model:
1067+
x = pm.Normal("x", 0, 1)
1068+
# This will use mock_sample instead of actual MCMC
1069+
idata = pm.sample()
1070+
# Test with the inference data...
1071+
1072+
"""
1073+
import pymc as pm
1074+
1075+
original_flat = pm.Flat
1076+
original_half_flat = pm.HalfFlat
1077+
original_sample = pm.sample
1078+
1079+
pm.sample = mock_sample
1080+
pm.Flat = pm.Normal
1081+
pm.HalfFlat = pm.HalfNormal
1082+
1083+
yield
1084+
1085+
pm.sample = original_sample
1086+
pm.Flat = original_flat
1087+
pm.HalfFlat = original_half_flat

tests/test_testing.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
import pytest
1717

18-
from pymc.testing import Domain, mock_sample
18+
import pymc as pm
19+
20+
from pymc.testing import Domain, mock_sample, mock_sample_setup_and_breakdown
1921
from tests.models import simple_normal
2022

2123

@@ -56,3 +58,16 @@ def test_mock_sample(args, kwargs, expected_draws) -> None:
5658
assert "sample_stats" not in idata
5759

5860
assert idata.posterior.sizes == {"chain": 1, "draw": expected_draws}
61+
62+
63+
mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_breakdown)
64+
65+
66+
def test_fixture(mock_pymc_sample) -> None:
67+
# This has Flat distribution
68+
_, model, _ = simple_normal(bounded_prior=False)
69+
70+
with model:
71+
idata = pm.sample()
72+
73+
assert idata.posterior.sizes == {"chain": 1, "draw": 10}

0 commit comments

Comments
 (0)