Skip to content

Commit 4445593

Browse files
Numpyro MCMC autodims migration (#88)
* migrated numpyro mcmc dim inferring * added comments for infer_dims * updated io_numpyro for docstub conventions * further improved docstub conventions for io_numpyro * added indentation in from_numpyro docstring to fix sphinx failure * improved io_numpyro test case for test_mcmc_inferred_dims_univariate * Update src/arviz_base/io_numpyro.py Fixed typo with model_kwargs Co-authored-by: Oriol Abril-Pla <[email protected]> * updated model_args variable length tuple type hint in io_numpyro --------- Co-authored-by: Oriol Abril-Pla <[email protected]>
1 parent fa9e1b5 commit 4445593

File tree

3 files changed

+305
-11
lines changed

3 files changed

+305
-11
lines changed

external_tests/test_numpyro.py

Lines changed: 127 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ def predictions_data(self, data, predictions_params):
4343
)
4444
return predictions
4545

46-
def get_inference_data(self, data, eight_schools_params, predictions_data, predictions_params):
46+
def get_inference_data(
47+
self, data, eight_schools_params, predictions_data, predictions_params, infer_dims=False
48+
):
4749
posterior_samples = data.obj.get_samples()
4850
model = data.obj.sampler.model
4951
posterior_predictive = Predictive(model, posterior_samples)(
@@ -52,6 +54,11 @@ def get_inference_data(self, data, eight_schools_params, predictions_data, predi
5254
prior = Predictive(model, num_samples=500)(
5355
PRNGKey(2), eight_schools_params["J"], eight_schools_params["sigma"]
5456
)
57+
dims = {"theta": ["school"], "eta": ["school"], "obs": ["school"]}
58+
pred_dims = {"theta": ["school_pred"], "eta": ["school_pred"], "obs": ["school_pred"]}
59+
if infer_dims:
60+
dims = pred_dims = None
61+
5562
predictions = predictions_data
5663
return from_numpyro(
5764
posterior=data.obj,
@@ -62,8 +69,8 @@ def get_inference_data(self, data, eight_schools_params, predictions_data, predi
6269
"school": np.arange(eight_schools_params["J"]),
6370
"school_pred": np.arange(predictions_params["J"]),
6471
},
65-
dims={"theta": ["school"], "eta": ["school"], "obs": ["school"]},
66-
pred_dims={"theta": ["school_pred"], "eta": ["school_pred"], "obs": ["school_pred"]},
72+
dims=dims,
73+
pred_dims=pred_dims,
6774
)
6875

6976
def test_inference_data_namedtuple(self, data):
@@ -74,6 +81,7 @@ def test_inference_data_namedtuple(self, data):
7481
data.obj.get_samples = lambda *args, **kwargs: data_namedtuple
7582
inference_data = from_numpyro(
7683
posterior=data.obj,
84+
dims={}, # This mock test needs to turn off autodims like so or mock group_by_chain
7785
)
7886
assert isinstance(data.obj.get_samples(), Samples)
7987
data.obj.get_samples = _old_fn
@@ -273,3 +281,119 @@ def model():
273281
mcmc.run(PRNGKey(0))
274282
inference_data = from_numpyro(mcmc)
275283
assert inference_data.observed_data
284+
285+
def test_mcmc_infer_dims(self):
286+
import numpyro
287+
import numpyro.distributions as dist
288+
from numpyro.infer import MCMC, NUTS
289+
290+
def model():
291+
# note: group2 gets assigned dim=-1 and group1 is assigned dim=-2
292+
with numpyro.plate("group2", 5), numpyro.plate("group1", 10):
293+
_ = numpyro.sample("param", dist.Normal(0, 1))
294+
295+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
296+
mcmc.run(PRNGKey(0))
297+
inference_data = from_numpyro(
298+
mcmc, coords={"group1": np.arange(10), "group2": np.arange(5)}
299+
)
300+
assert inference_data.posterior.param.dims == ("chain", "draw", "group1", "group2")
301+
assert all(dim in inference_data.posterior.param.coords for dim in ("group1", "group2"))
302+
303+
def test_mcmc_infer_unsorted_dims(self):
304+
import numpyro
305+
import numpyro.distributions as dist
306+
from numpyro.infer import MCMC, NUTS
307+
308+
def model():
309+
group1_plate = numpyro.plate("group1", 10, dim=-1)
310+
group2_plate = numpyro.plate("group2", 5, dim=-2)
311+
312+
# the plate contexts are entered in a different order than the pre-defined dims
313+
# we should make sure this still works because the trace has all of the info it needs
314+
with group2_plate, group1_plate:
315+
_ = numpyro.sample("param", dist.Normal(0, 1))
316+
317+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
318+
mcmc.run(PRNGKey(0))
319+
inference_data = from_numpyro(
320+
mcmc, coords={"group1": np.arange(10), "group2": np.arange(5)}
321+
)
322+
assert inference_data.posterior.param.dims == ("chain", "draw", "group2", "group1")
323+
assert all(dim in inference_data.posterior.param.coords for dim in ("group1", "group2"))
324+
325+
def test_mcmc_infer_dims_no_coords(self):
326+
import numpyro
327+
import numpyro.distributions as dist
328+
from numpyro.infer import MCMC, NUTS
329+
330+
def model():
331+
with numpyro.plate("group", 5):
332+
_ = numpyro.sample("param", dist.Normal(0, 1))
333+
334+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
335+
mcmc.run(PRNGKey(0))
336+
inference_data = from_numpyro(mcmc)
337+
assert inference_data.posterior.param.dims == ("chain", "draw", "group")
338+
339+
def test_mcmc_event_dims(self):
340+
import numpyro
341+
import numpyro.distributions as dist
342+
from numpyro.infer import MCMC, NUTS
343+
344+
def model():
345+
_ = numpyro.sample(
346+
"gamma", dist.ZeroSumNormal(1, event_shape=(10,)), infer={"event_dims": ["groups"]}
347+
)
348+
349+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
350+
mcmc.run(PRNGKey(0))
351+
inference_data = from_numpyro(mcmc, coords={"groups": np.arange(10)})
352+
assert inference_data.posterior.gamma.dims == ("chain", "draw", "groups")
353+
assert "groups" in inference_data.posterior.gamma.coords
354+
355+
def test_mcmc_inferred_dims_univariate(self):
356+
import jax.numpy as jnp
357+
import numpyro
358+
import numpyro.distributions as dist
359+
from numpyro.infer import MCMC, NUTS
360+
361+
def model():
362+
alpha = numpyro.sample("alpha", dist.Normal(0, 1))
363+
sigma = numpyro.sample("sigma", dist.HalfNormal(1))
364+
with numpyro.plate("obs_idx", 3):
365+
# mu is plated by obs_idx, but isnt broadcasted to the plate shape
366+
# the expected behavior is that this should cause a failure
367+
mu = numpyro.deterministic("mu", alpha)
368+
return numpyro.sample("y", dist.Normal(mu, sigma), obs=jnp.array([-1, 0, 1]))
369+
370+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
371+
mcmc.run(PRNGKey(0))
372+
with pytest.raises(ValueError):
373+
from_numpyro(mcmc, coords={"obs_idx": np.arange(3)})
374+
375+
def test_mcmc_extra_event_dims(self):
376+
import numpyro
377+
import numpyro.distributions as dist
378+
from numpyro.infer import MCMC, NUTS
379+
380+
def model():
381+
gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(10,)))
382+
_ = numpyro.deterministic("gamma_plus1", gamma + 1)
383+
384+
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
385+
mcmc.run(PRNGKey(0))
386+
inference_data = from_numpyro(
387+
mcmc, coords={"groups": np.arange(10)}, extra_event_dims={"gamma_plus1": ["groups"]}
388+
)
389+
assert inference_data.posterior.gamma_plus1.dims == ("chain", "draw", "groups")
390+
assert "groups" in inference_data.posterior.gamma_plus1.coords
391+
392+
def test_mcmc_predictions_infer_dims(
393+
self, data, eight_schools_params, predictions_data, predictions_params
394+
):
395+
inference_data = self.get_inference_data(
396+
data, eight_schools_params, predictions_data, predictions_params, infer_dims=True
397+
)
398+
assert inference_data.predictions.obs.dims == ("chain", "draw", "J")
399+
assert "J" in inference_data.predictions.obs.coords

0 commit comments

Comments
 (0)