Skip to content

Commit 1b0b9cb

Browse files
kylejcaronkylejcaron
andauthored
Infer NumPyro InferenceData dims automatically (#2441)
* from_numpyro infers dims by default * linted and formatted * removed unneeded print statement * caught edge case for inferring numpyro dims, changed failing test to expected behavior * improved documentation * better type hints * updated changelog with az.from_numpyro change * coords no longer required for infer_dims in NumPyroConverter. fixed infer_pred_dims logic --------- Co-authored-by: kylejcaron <[email protected]>
1 parent d4de8b6 commit 1b0b9cb

File tree

3 files changed

+243
-7
lines changed

3 files changed

+243
-7
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## Unreleased
44

55
### New features
6+
- Make `arviz.from_numpyro(..., dims=None)` automatically infer dims from the NumPyro model, based on its `numpyro.plate` structure
67

78
### Maintenance and fixes
89

arviz/data/io_numpyro.py

Lines changed: 112 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""NumPyro-specific conversion code."""
22

3+
from collections import defaultdict
34
import logging
4-
from typing import Callable, Optional
5+
from typing import Any, Callable, Optional, Dict, List, Tuple
56

67
import numpy as np
78

@@ -13,6 +14,70 @@
1314
_log = logging.getLogger(__name__)
1415

1516

17+
def _add_dims(dims_a: Dict[str, List[str]], dims_b: Dict[str, List[str]]) -> Dict[str, List[str]]:
18+
merged = defaultdict(list)
19+
20+
for k, v in dims_a.items():
21+
merged[k].extend(v)
22+
23+
for k, v in dims_b.items():
24+
merged[k].extend(v)
25+
26+
# Convert back to a regular dict
27+
return dict(merged)
28+
29+
30+
def infer_dims(
    model: Callable,
    model_args: Optional[Tuple[Any, ...]] = None,
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, List[str]]:
    """Infer named dims for the sample and deterministic sites of a NumPyro model.

    The model is traced under ``jax.eval_shape`` so that no array computation is
    performed. For each ``sample`` or ``deterministic`` site, the names of the plates
    in its ``cond_indep_stack`` (sorted by plate ``dim``) are collected, followed by
    any names given at the site through ``infer={"event_dims": [...]}``.

    Parameters
    ----------
    model : callable
        The NumPyro model to trace.
    model_args : tuple, optional
        Positional arguments passed to ``model``. ``None`` means no arguments.
    model_kwargs : dict, optional
        Keyword arguments passed to ``model``. ``None`` means no keyword arguments.

    Returns
    -------
    dict of {str : list of str}
        Maps site name to its dim names (plate names first, then event dim names).
        Sites with neither plates nor event dims are omitted.
    """
    import jax
    from numpyro import distributions as dist
    from numpyro import handlers
    from numpyro.infer.initialization import init_to_sample
    from numpyro.ops.pytree import PytreeTrace

    model_args = tuple() if model_args is None else model_args
    # Each default must be resolved against its own argument; testing ``model_args``
    # here would leave ``model_kwargs`` as None and break the ``**model_kwargs`` call.
    model_kwargs = dict() if model_kwargs is None else model_kwargs

    def _get_dist_name(fn):
        # Unwrap wrapper distributions until the underlying distribution is reached.
        if isinstance(fn, (dist.Independent, dist.ExpandedDistribution, dist.MaskedDistribution)):
            return _get_dist_name(fn.base_dist)
        return type(fn).__name__

    def get_trace():
        # We use `init_to_sample` to get around ImproperUniform distribution,
        # which does not have `sample` method.
        subs_model = handlers.substitute(
            handlers.seed(model, 0),
            substitute_fn=init_to_sample,
        )
        trace = handlers.trace(subs_model).get_trace(*model_args, **model_kwargs)
        # Work around an issue where jax.eval_shape does not work
        # for distribution output (e.g. the function `lambda: dist.Normal(0, 1)`)
        # Here we will remove `fn` and store its name in the trace.
        for site in trace.values():
            if site["type"] == "sample":
                site["fn_name"] = _get_dist_name(site.pop("fn"))
            elif site["type"] == "deterministic":
                site["fn_name"] = "Deterministic"
        return PytreeTrace(trace)

    # We use eval_shape to avoid any array computation.
    trace = jax.eval_shape(get_trace).trace

    named_dims = {}
    for name, site in trace.items():
        if site["type"] not in ("sample", "deterministic"):
            continue
        # Ascending plate dim (e.g. -2 before -1) puts the leftmost axis first.
        frames = sorted(site["cond_indep_stack"], key=lambda frame: frame.dim)
        batch_dims = [frame.name for frame in frames]
        event_dims = list(site.get("infer", {}).get("event_dims", []))
        if batch_dims or event_dims:
            named_dims[name] = batch_dims + event_dims

    return named_dims
79+
80+
1681
class NumPyroConverter:
1782
"""Encapsulate NumPyro specific logic."""
1883

@@ -36,6 +101,7 @@ def __init__(
36101
coords=None,
37102
dims=None,
38103
pred_dims=None,
104+
extra_event_dims=None,
39105
num_chains=1,
40106
):
41107
"""Convert NumPyro data into an InferenceData object.
@@ -58,9 +124,12 @@ def __init__(
58124
coords : dict[str] -> list[str]
59125
Map of dimensions to coordinates
60126
dims : dict[str] -> list[str]
61-
Map variable names to their coordinates
127+
Map variable names to their coordinates. Will be inferred if they are not provided.
62128
pred_dims: dict
63129
Dims for predictions data. Map variable names to their coordinates.
130+
extra_event_dims: dict
131+
Extra event dims for deterministic sites. Maps event dims that couldn't be inferred to
132+
their coordinates.
64133
num_chains: int
65134
Number of chains used for sampling. Ignored if posterior is present.
66135
"""
@@ -80,6 +149,7 @@ def __init__(
80149
self.coords = coords
81150
self.dims = dims
82151
self.pred_dims = pred_dims
152+
self.extra_event_dims = extra_event_dims
83153
self.numpyro = numpyro
84154

85155
def arbitrary_element(dct):
@@ -107,6 +177,10 @@ def arbitrary_element(dct):
107177
# model arguments and keyword arguments
108178
self._args = self.posterior._args # pylint: disable=protected-access
109179
self._kwargs = self.posterior._kwargs # pylint: disable=protected-access
180+
self.dims = self.dims if self.dims is not None else self.infer_dims()
181+
self.pred_dims = (
182+
self.pred_dims if self.pred_dims is not None else self.infer_pred_dims()
183+
)
110184
else:
111185
self.nchains = num_chains
112186
get_from = None
@@ -325,6 +399,23 @@ def to_inference_data(self):
325399
}
326400
)
327401

402+
@requires("posterior")
@requires("model")
def infer_dims(self) -> Dict[str, List[str]]:
    """Infer posterior dims from the model and append any ``extra_event_dims``.

    Calls the module-level ``infer_dims`` with the stored model and its call
    arguments. When ``self.extra_event_dims`` is truthy, its entries are merged
    onto the inferred dims with ``_add_dims``.
    """
    inferred = infer_dims(self.model, self._args, self._kwargs)
    if not self.extra_event_dims:
        return inferred
    return _add_dims(inferred, self.extra_event_dims)
409+
410+
@requires("posterior")
@requires("model")
@requires("predictions")
def infer_pred_dims(self) -> Dict[str, List[str]]:
    """Infer predictions dims from the model and append any ``extra_event_dims``.

    Uses the same inference as the posterior dims: the module-level ``infer_dims``
    is called with the stored model and its call arguments, and a truthy
    ``self.extra_event_dims`` is merged on top with ``_add_dims``.
    """
    model_dims = infer_dims(self.model, self._args, self._kwargs)
    extra = self.extra_event_dims
    return _add_dims(model_dims, extra) if extra else model_dims
418+
328419

329420
def from_numpyro(
330421
posterior=None,
@@ -339,10 +430,25 @@ def from_numpyro(
339430
coords=None,
340431
dims=None,
341432
pred_dims=None,
433+
extra_event_dims=None,
342434
num_chains=1,
343435
):
344436
"""Convert NumPyro data into an InferenceData object.
345437
438+
If no dims are provided, this will infer batch dim names from NumPyro model plates.
439+
For event dim names, such as with the ZeroSumNormal, `infer={"event_dims":dim_names}`
440+
can be provided in numpyro.sample, i.e.::
441+
442+
# equivalent to dims entry, {"gamma": ["groups"]}
443+
gamma = numpyro.sample(
444+
"gamma",
445+
dist.ZeroSumNormal(1, event_shape=(n_groups,)),
446+
infer={"event_dims":["groups"]}
447+
)
448+
449+
There is also an additional `extra_event_dims` input to cover any edge cases, for instance
450+
deterministic sites with event dims (which don't have an `infer` argument to provide metadata).
451+
346452
For a usage example read the
347453
:ref:`Creating InferenceData section on from_numpyro <creating_InferenceData>`
348454
@@ -364,9 +470,10 @@ def from_numpyro(
364470
coords : dict[str] -> list[str]
365471
Map of dimensions to coordinates
366472
dims : dict[str] -> list[str]
367-
Map variable names to their coordinates
473+
Map variable names to their coordinates. Will be inferred if they are not provided.
368474
pred_dims: dict
369-
Dims for predictions data. Map variable names to their coordinates.
475+
Dims for predictions data. Map variable names to their coordinates. Default behavior is to
476+
infer dims if this is not provided
370477
num_chains: int
371478
Number of chains used for sampling. Ignored if posterior is present.
372479
"""
@@ -382,5 +489,6 @@ def from_numpyro(
382489
coords=coords,
383490
dims=dims,
384491
pred_dims=pred_dims,
492+
extra_event_dims=extra_event_dims,
385493
num_chains=num_chains,
386494
).to_inference_data()

arviz/tests/external_tests/test_data_numpyro.py

Lines changed: 130 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def predictions_data(self, data, predictions_params):
4646
)
4747
return predictions
4848

49-
def get_inference_data(self, data, eight_schools_params, predictions_data, predictions_params):
49+
def get_inference_data(
50+
self, data, eight_schools_params, predictions_data, predictions_params, infer_dims=False
51+
):
5052
posterior_samples = data.obj.get_samples()
5153
model = data.obj.sampler.model
5254
posterior_predictive = Predictive(model, posterior_samples)(
@@ -55,6 +57,12 @@ def get_inference_data(self, data, eight_schools_params, predictions_data, predi
5557
prior = Predictive(model, num_samples=500)(
5658
PRNGKey(2), eight_schools_params["J"], eight_schools_params["sigma"]
5759
)
60+
dims = {"theta": ["school"], "eta": ["school"], "obs": ["school"]}
61+
pred_dims = {"theta": ["school_pred"], "eta": ["school_pred"], "obs": ["school_pred"]}
62+
if infer_dims:
63+
dims = None
64+
pred_dims = None
65+
5866
predictions = predictions_data
5967
return from_numpyro(
6068
posterior=data.obj,
@@ -65,8 +73,8 @@ def get_inference_data(self, data, eight_schools_params, predictions_data, predi
6573
"school": np.arange(eight_schools_params["J"]),
6674
"school_pred": np.arange(predictions_params["J"]),
6775
},
68-
dims={"theta": ["school"], "eta": ["school"], "obs": ["school"]},
69-
pred_dims={"theta": ["school_pred"], "eta": ["school_pred"], "obs": ["school_pred"]},
76+
dims=dims,
77+
pred_dims=pred_dims,
7078
)
7179

7280
def test_inference_data_namedtuple(self, data):
@@ -77,6 +85,7 @@ def test_inference_data_namedtuple(self, data):
7785
data.obj.get_samples = lambda *args, **kwargs: data_namedtuple
7886
inference_data = from_numpyro(
7987
posterior=data.obj,
88+
dims={}, # This mock test needs to turn off autodims like so or mock group_by_chain
8089
)
8190
assert isinstance(data.obj.get_samples(), Samples)
8291
data.obj.get_samples = _old_fn
@@ -282,3 +291,121 @@ def model():
282291
mcmc.run(PRNGKey(0))
283292
inference_data = from_numpyro(mcmc)
284293
assert inference_data.observed_data
294+
295+
def test_mcmc_infer_dims(self):
    """Dims inferred from nested plates follow the plate dim order."""
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS

    def model():
        # nesting order decides the plate dims: group2 -> dim=-1, group1 -> dim=-2
        with numpyro.plate("group2", 5), numpyro.plate("group1", 10):
            numpyro.sample("param", dist.Normal(0, 1))

    sampler = MCMC(NUTS(model), num_warmup=10, num_samples=10)
    sampler.run(PRNGKey(0))
    coords = {"group1": np.arange(10), "group2": np.arange(5)}
    idata = from_numpyro(sampler, coords=coords)

    param = idata.posterior.param
    assert param.dims == ("chain", "draw", "group1", "group2")
    for dim in ("group1", "group2"):
        assert dim in param.coords
312+
313+
def test_mcmc_infer_unsorted_dims(self):
    """Explicit plate dims win over the order in which the plates are entered."""
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS

    def model():
        group1_plate = numpyro.plate("group1", 10, dim=-1)
        group2_plate = numpyro.plate("group2", 5, dim=-2)

        # Entering the contexts in the opposite order to their declared dims must
        # still work: the trace carries each plate's dim explicitly.
        with group2_plate, group1_plate:
            numpyro.sample("param", dist.Normal(0, 1))

    sampler = MCMC(NUTS(model), num_warmup=10, num_samples=10)
    sampler.run(PRNGKey(0))
    coords = {"group1": np.arange(10), "group2": np.arange(5)}
    idata = from_numpyro(sampler, coords=coords)

    param = idata.posterior.param
    assert param.dims == ("chain", "draw", "group2", "group1")
    for dim in ("group1", "group2"):
        assert dim in param.coords
334+
335+
def test_mcmc_infer_dims_no_coords(self):
    """Dims are inferred from the plate even when no coords are supplied."""
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS

    def model():
        with numpyro.plate("group", 5):
            numpyro.sample("param", dist.Normal(0, 1))

    sampler = MCMC(NUTS(model), num_warmup=10, num_samples=10)
    sampler.run(PRNGKey(0))
    idata = from_numpyro(sampler)

    assert idata.posterior.param.dims == ("chain", "draw", "group")
348+
349+
def test_mcmc_event_dims(self):
    """Event dims named through ``infer={"event_dims": ...}`` are picked up."""
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS

    def model():
        numpyro.sample(
            "gamma",
            dist.ZeroSumNormal(1, event_shape=(10,)),
            infer={"event_dims": ["groups"]},
        )

    sampler = MCMC(NUTS(model), num_warmup=10, num_samples=10)
    sampler.run(PRNGKey(0))
    idata = from_numpyro(sampler, coords={"groups": np.arange(10)})

    gamma = idata.posterior.gamma
    assert gamma.dims == ("chain", "draw", "groups")
    assert "groups" in gamma.coords
364+
365+
@pytest.mark.xfail
def test_mcmc_inferred_dims_univariate(self):
    """A plated deterministic that is not broadcast to the plate shape is expected to fail."""
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS
    import jax.numpy as jnp

    # NOTE(review): indentation was lost in the extracted source; the observed site `y`
    # is assumed to be sampled inside the `obs_idx` plate. Confirm against the repository.
    def model():
        alpha = numpyro.sample("alpha", dist.Normal(0, 1))
        sigma = numpyro.sample("sigma", dist.HalfNormal(1))
        with numpyro.plate("obs_idx", 3):
            # `mu` sits under the obs_idx plate but keeps alpha's scalar shape,
            # so the inferred dim cannot match the data; a failure is the expected outcome
            mu = numpyro.deterministic("mu", alpha)
            return numpyro.sample("y", dist.Normal(mu, sigma), obs=jnp.array([-1, 0, 1]))

    sampler = MCMC(NUTS(model), num_warmup=10, num_samples=10)
    sampler.run(PRNGKey(0))
    idata = from_numpyro(sampler, coords={"obs_idx": np.arange(3)})

    mu_da = idata.posterior.mu
    assert mu_da.dims == ("chain", "draw", "obs_idx")
    assert "obs_idx" in mu_da.coords
386+
387+
def test_mcmc_extra_event_dims(self):
    """``extra_event_dims`` names the event dim of a deterministic site."""
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS

    def model():
        gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(10,)))
        numpyro.deterministic("gamma_plus1", gamma + 1)

    sampler = MCMC(NUTS(model), num_warmup=10, num_samples=10)
    sampler.run(PRNGKey(0))
    idata = from_numpyro(
        sampler,
        coords={"groups": np.arange(10)},
        extra_event_dims={"gamma_plus1": ["groups"]},
    )

    shifted = idata.posterior.gamma_plus1
    assert shifted.dims == ("chain", "draw", "groups")
    assert "groups" in shifted.coords
403+
404+
def test_mcmc_predictions_infer_dims(
    self, data, eight_schools_params, predictions_data, predictions_params
):
    """With no dims given, the predictions group gets its dims from the model."""
    idata = self.get_inference_data(
        data,
        eight_schools_params,
        predictions_data,
        predictions_params,
        infer_dims=True,
    )
    obs = idata.predictions.obs
    assert obs.dims == ("chain", "draw", "J")
    assert "J" in obs.coords

0 commit comments

Comments
 (0)