Skip to content

Commit a7a1b3f

Browse files
committed
test the testing
1 parent d7cedd7 commit a7a1b3f

File tree

1 file changed

+29
-9
lines changed

1 file changed

+29
-9
lines changed

tests/test_testing.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from contextlib import ExitStack as does_not_raise
1515

16+
import numpy as np
1617
import pytest
1718

1819
import pymc as pm
@@ -38,28 +39,47 @@ def test_domain(values, edges, expectation):
3839

3940

4041
@pytest.mark.parametrize(
41-
"args, kwargs, expected_size",
42+
"args, kwargs, expected_size, sample_stats",
4243
[
43-
pytest.param((), {}, (1, 10), id="default"),
44-
pytest.param((100,), {}, (1, 100), id="positional-draws"),
45-
pytest.param((), {"draws": 100}, (1, 100), id="keyword-draws"),
46-
pytest.param((100,), {"chains": 6}, (6, 100), id="chains"),
44+
pytest.param((), {}, (1, 10), None, id="default"),
45+
pytest.param((100,), {}, (1, 100), None, id="positional-draws"),
46+
pytest.param((), {"draws": 100}, (1, 100), None, id="keyword-draws"),
47+
pytest.param((100,), {"chains": 6}, (6, 100), None, id="chains"),
48+
pytest.param(
49+
(100,),
50+
{"chains": 6},
51+
(6, 100),
52+
{
53+
"diverging": np.zeros,
54+
"tree_depth": lambda size: np.random.choice(range(2, 10), size=size),
55+
},
56+
id="with_sample_stats",
57+
),
4758
],
4859
)
49-
def test_mock_sample(args, kwargs, expected_size) -> None:
60+
def test_mock_sample(args, kwargs, expected_size, sample_stats) -> None:
5061
expected_chains, expected_draws = expected_size
5162
_, model, _ = simple_normal(bounded_prior=True)
5263

5364
with model:
54-
idata = mock_sample(*args, **kwargs)
65+
idata = mock_sample(*args, **kwargs, sample_stats=sample_stats)
5566

5667
assert "posterior" in idata
5768
assert "observed_data" in idata
5869
assert "prior" not in idata
5970
assert "posterior_predictive" not in idata
60-
assert "sample_stats" not in idata
6171

62-
assert idata.posterior.sizes == {"chain": expected_chains, "draw": expected_draws}
72+
expected_sizes = {"chain": expected_chains, "draw": expected_draws}
73+
74+
if sample_stats:
75+
sample_stats_ds = idata["sample_stats"]
76+
for name in sample_stats.keys():
77+
assert sample_stats_ds[name].sizes == expected_sizes
78+
79+
else:
80+
assert "sample_stats" not in idata
81+
82+
assert idata.posterior.sizes == expected_sizes
6383

6484

6585
mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown)

0 commit comments

Comments
 (0)