@@ -334,7 +334,7 @@ def guide():
334334
335335 result = self ._run_inference (model , svi = svi , guide_fn = guide_fn )
336336 from_numpyro_func = from_numpyro_svi if svi else from_numpyro
337- sample_dims = ("samples " ,) if svi else ("chain" , "draw" )
337+ sample_dims = ("sample " ,) if svi else ("chain" , "draw" )
338338
339339 inference_data = from_numpyro_func (
340340 ** result , coords = {"group1" : np .arange (10 ), "group2" : np .arange (5 )}
@@ -380,7 +380,7 @@ def guide():
380380
381381 result = self ._run_inference (model , svi = svi , guide_fn = guide_fn )
382382 from_numpyro_func = from_numpyro_svi if svi else from_numpyro
383- sample_dims = ("samples " ,) if svi else ("chain" , "draw" )
383+ sample_dims = ("sample " ,) if svi else ("chain" , "draw" )
384384
385385 inference_data = from_numpyro_func (
386386 ** result , coords = {"group1" : np .arange (10 ), "group2" : np .arange (5 )}
@@ -417,7 +417,7 @@ def guide():
417417
418418 result = self ._run_inference (model , svi = svi , guide_fn = guide_fn )
419419 from_numpyro_func = from_numpyro_svi if svi else from_numpyro
420- sample_dims = ("samples " ,) if svi else ("chain" , "draw" )
420+ sample_dims = ("sample " ,) if svi else ("chain" , "draw" )
421421
422422 inference_data = from_numpyro_func (** result )
423423 assert inference_data .posterior .param .dims == sample_dims + ("group" ,)
@@ -453,7 +453,7 @@ def guide():
453453
454454 result = self ._run_inference (model , svi = svi , guide_fn = guide_fn )
455455 from_numpyro_func = from_numpyro_svi if svi else from_numpyro
456- sample_dims = ("samples " ,) if svi else ("chain" , "draw" )
456+ sample_dims = ("sample " ,) if svi else ("chain" , "draw" )
457457
458458 inference_data = from_numpyro_func (** result , coords = {"groups" : np .arange (10 )})
459459 assert inference_data .posterior .gamma .dims == sample_dims + ("groups" ,)
@@ -535,7 +535,7 @@ def guide():
535535
536536 result = self ._run_inference (model , svi = svi , guide_fn = guide_fn )
537537 from_numpyro_func = from_numpyro_svi if svi else from_numpyro
538- sample_dims = ("samples " ,) if svi else ("chain" , "draw" )
538+ sample_dims = ("sample " ,) if svi else ("chain" , "draw" )
539539 inference_data = from_numpyro_func (
540540 ** result , coords = {"groups" : np .arange (10 )}, extra_event_dims = {"gamma_plus1" : ["groups" ]}
541541 )
@@ -548,7 +548,7 @@ def test_predictions_infer_dims(
548548 inference_data = self .get_inference_data (
549549 data , eight_schools_params , predictions_data , predictions_params , infer_dims = True
550550 )
551- sample_dims = ("samples " ,) if isinstance (data .obj , dict ) else ("chain" , "draw" )
551+ sample_dims = ("sample " ,) if isinstance (data .obj , dict ) else ("chain" , "draw" )
552552 assert inference_data .predictions .obs .dims == (sample_dims + ("J" ,))
553553 assert "J" in inference_data .predictions .obs .coords
554554
0 commit comments