Skip to content

Commit 8b3eab9

Browse files
committed
updated sample_dims
1 parent 7a63017 commit 8b3eab9

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

external_tests/test_numpyro.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def guide():
334334

335335
result = self._run_inference(model, svi=svi, guide_fn=guide_fn)
336336
from_numpyro_func = from_numpyro_svi if svi else from_numpyro
337-
sample_dims = ("samples",) if svi else ("chain", "draw")
337+
sample_dims = ("sample",) if svi else ("chain", "draw")
338338

339339
inference_data = from_numpyro_func(
340340
**result, coords={"group1": np.arange(10), "group2": np.arange(5)}
@@ -380,7 +380,7 @@ def guide():
380380

381381
result = self._run_inference(model, svi=svi, guide_fn=guide_fn)
382382
from_numpyro_func = from_numpyro_svi if svi else from_numpyro
383-
sample_dims = ("samples",) if svi else ("chain", "draw")
383+
sample_dims = ("sample",) if svi else ("chain", "draw")
384384

385385
inference_data = from_numpyro_func(
386386
**result, coords={"group1": np.arange(10), "group2": np.arange(5)}
@@ -417,7 +417,7 @@ def guide():
417417

418418
result = self._run_inference(model, svi=svi, guide_fn=guide_fn)
419419
from_numpyro_func = from_numpyro_svi if svi else from_numpyro
420-
sample_dims = ("samples",) if svi else ("chain", "draw")
420+
sample_dims = ("sample",) if svi else ("chain", "draw")
421421

422422
inference_data = from_numpyro_func(**result)
423423
assert inference_data.posterior.param.dims == sample_dims + ("group",)
@@ -453,7 +453,7 @@ def guide():
453453

454454
result = self._run_inference(model, svi=svi, guide_fn=guide_fn)
455455
from_numpyro_func = from_numpyro_svi if svi else from_numpyro
456-
sample_dims = ("samples",) if svi else ("chain", "draw")
456+
sample_dims = ("sample",) if svi else ("chain", "draw")
457457

458458
inference_data = from_numpyro_func(**result, coords={"groups": np.arange(10)})
459459
assert inference_data.posterior.gamma.dims == sample_dims + ("groups",)
@@ -535,7 +535,7 @@ def guide():
535535

536536
result = self._run_inference(model, svi=svi, guide_fn=guide_fn)
537537
from_numpyro_func = from_numpyro_svi if svi else from_numpyro
538-
sample_dims = ("samples",) if svi else ("chain", "draw")
538+
sample_dims = ("sample",) if svi else ("chain", "draw")
539539
inference_data = from_numpyro_func(
540540
**result, coords={"groups": np.arange(10)}, extra_event_dims={"gamma_plus1": ["groups"]}
541541
)
@@ -548,7 +548,7 @@ def test_predictions_infer_dims(
548548
inference_data = self.get_inference_data(
549549
data, eight_schools_params, predictions_data, predictions_params, infer_dims=True
550550
)
551-
sample_dims = ("samples",) if isinstance(data.obj, dict) else ("chain", "draw")
551+
sample_dims = ("sample",) if isinstance(data.obj, dict) else ("chain", "draw")
552552
assert inference_data.predictions.obs.dims == (sample_dims + ("J",))
553553
assert "J" in inference_data.predictions.obs.coords
554554

src/arviz_base/io_numpyro.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(
3333
self.num_samples = num_samples
3434
self.thinning = 1
3535
self.num_chains = 0
36-
self.sample_dims = ["samples"]
36+
self.sample_dims = ["sample"]
3737
self.kind = "svi"
3838

3939
self.numpyro = numpyro
@@ -703,7 +703,7 @@ def from_numpyro_svi(
703703
model_kwargs=model_kwargs,
704704
num_samples=num_samples,
705705
)
706-
with rc_context(rc={"data.sample_dims": ["samples"]}):
706+
with rc_context(rc={"data.sample_dims": ["sample"]}):
707707
return NumPyroConverter(
708708
posterior=posterior,
709709
prior=prior,

0 commit comments

Comments
 (0)