1313# limitations under the License.
1414from contextlib import ExitStack as does_not_raise
1515
16+ import numpy as np
1617import pytest
1718
1819import pymc as pm
@@ -38,28 +39,47 @@ def test_domain(values, edges, expectation):
3839
3940
4041@pytest .mark .parametrize (
41- "args, kwargs, expected_size" ,
42+ "args, kwargs, expected_size, sample_stats " ,
4243 [
43- pytest .param ((), {}, (1 , 10 ), id = "default" ),
44- pytest .param ((100 ,), {}, (1 , 100 ), id = "positional-draws" ),
45- pytest .param ((), {"draws" : 100 }, (1 , 100 ), id = "keyword-draws" ),
46- pytest .param ((100 ,), {"chains" : 6 }, (6 , 100 ), id = "chains" ),
44+ pytest .param ((), {}, (1 , 10 ), None , id = "default" ),
45+ pytest .param ((100 ,), {}, (1 , 100 ), None , id = "positional-draws" ),
46+ pytest .param ((), {"draws" : 100 }, (1 , 100 ), None , id = "keyword-draws" ),
47+ pytest .param ((100 ,), {"chains" : 6 }, (6 , 100 ), None , id = "chains" ),
48+ pytest .param (
49+ (100 ,),
50+ {"chains" : 6 },
51+ (6 , 100 ),
52+ {
53+ "diverging" : np .zeros ,
54+ "tree_depth" : lambda size : np .random .choice (range (2 , 10 ), size = size ),
55+ },
56+ id = "with_sample_stats" ,
57+ ),
4758 ],
4859)
49- def test_mock_sample (args , kwargs , expected_size ) -> None :
60+ def test_mock_sample (args , kwargs , expected_size , sample_stats ) -> None :
5061 expected_chains , expected_draws = expected_size
5162 _ , model , _ = simple_normal (bounded_prior = True )
5263
5364 with model :
54- idata = mock_sample (* args , ** kwargs )
65+ idata = mock_sample (* args , ** kwargs , sample_stats = sample_stats )
5566
5667 assert "posterior" in idata
5768 assert "observed_data" in idata
5869 assert "prior" not in idata
5970 assert "posterior_predictive" not in idata
60- assert "sample_stats" not in idata
6171
62- assert idata .posterior .sizes == {"chain" : expected_chains , "draw" : expected_draws }
72+ expected_sizes = {"chain" : expected_chains , "draw" : expected_draws }
73+
74+ if sample_stats :
75+ sample_stats_ds = idata ["sample_stats" ]
76+ for name in sample_stats .keys ():
77+ assert sample_stats_ds [name ].sizes == expected_sizes
78+
79+ else :
80+ assert "sample_stats" not in idata
81+
82+ assert idata .posterior .sizes == expected_sizes
6383
6484
6585mock_pymc_sample = pytest .fixture (scope = "function" )(mock_sample_setup_and_teardown )
0 commit comments