@@ -38,14 +38,16 @@ def test_domain(values, edges, expectation):
3838
3939
4040@pytest .mark .parametrize (
41- "args, kwargs, expected_draws " ,
41+ "args, kwargs, expected_size " ,
4242 [
43- pytest .param ((), {}, 10 , id = "default" ),
44- pytest .param ((100 ,), {}, 100 , id = "positional-draws" ),
45- pytest .param ((), {"draws" : 100 }, 100 , id = "keyword-draws" ),
43+ pytest .param ((), {}, (1 , 10 ), id = "default" ),
44+ pytest .param ((100 ,), {}, (1 , 100 ), id = "positional-draws" ),
45+ pytest .param ((), {"draws" : 100 }, (1 , 100 ), id = "keyword-draws" ),
46+ pytest .param ((100 ,), {"chains" : 6 }, (6 , 100 ), id = "chains" ),
4647 ],
4748)
48- def test_mock_sample (args , kwargs , expected_draws ) -> None :
49+ def test_mock_sample (args , kwargs , expected_size ) -> None :
50+ expected_chains , expected_draws = expected_size
4951 _ , model , _ = simple_normal (bounded_prior = True )
5052
5153 with model :
@@ -57,7 +59,7 @@ def test_mock_sample(args, kwargs, expected_draws) -> None:
5759 assert "posterior_predictive" not in idata
5860 assert "sample_stats" not in idata
5961
60- assert idata .posterior .sizes == {"chain" : 1 , "draw" : expected_draws }
62+ assert idata .posterior .sizes == {"chain" : expected_chains , "draw" : expected_draws }
6163
6264
6365mock_pymc_sample = pytest .fixture (scope = "function" )(mock_sample_setup_and_teardown )
0 commit comments