@@ -277,11 +277,12 @@ def test__get_chains(self):
277
277
assert len (chain ) == 1
278
278
pass
279
279
280
- def test__to_inferencedata (self ):
280
+ @pytest .mark .parametrize ("tstatname" , ["tune" , "sampler__tune" , "nottune" ])
281
+ def test__to_inferencedata (self , tstatname , caplog ):
281
282
rmeta = make_runmeta (
282
283
flexibility = False ,
283
284
sample_stats = [
284
- Variable ("tune" , "bool" ),
285
+ Variable (tstatname , "bool" ),
285
286
Variable ("sampler_0__logp" , "float32" ),
286
287
Variable ("warning" , "str" ),
287
288
],
@@ -294,15 +295,22 @@ def test__to_inferencedata(self):
294
295
draws = [make_draw (rmeta .variables ) for _ in range (n )]
295
296
stats = [make_draw (rmeta .sample_stats ) for _ in range (n )]
296
297
for i , (d , s ) in enumerate (zip (draws , stats )):
297
- s ["tune" ] = i < 4
298
+ s [tstatname ] = i < 4
298
299
chain .append (d , s )
299
300
300
301
idata = run .to_inferencedata ()
301
302
assert isinstance (idata , arviz .InferenceData )
302
303
assert idata .warmup_posterior .dims ["chain" ] == 1
303
- assert idata .warmup_posterior .dims ["draw" ] == 4
304
304
assert idata .posterior .dims ["chain" ] == 1
305
- assert idata .posterior .dims ["draw" ] == 6
305
+ if tstatname == "nottune" :
306
+ # Splitting into warmup/posterior requires a tune stat!
307
+ assert any ("No 'tune' stat" in r .message for r in caplog .records )
308
+ assert idata .warmup_posterior .dims ["draw" ] == 0
309
+ assert idata .posterior .dims ["draw" ] == 10
310
+ else :
311
+ assert idata .warmup_posterior .dims ["draw" ] == 4
312
+ assert idata .posterior .dims ["draw" ] == 6
313
+
306
314
for var in rmeta .variables :
307
315
assert var .name in set (idata .posterior .keys ())
308
316
for svar in rmeta .sample_stats :
0 commit comments