-
Notifications
You must be signed in to change notification settings - Fork 270
Support infer_discrete for Predictive #1086
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,12 +11,13 @@ | |
| from jax import device_get, jacfwd, lax, random, value_and_grad | ||
| from jax.flatten_util import ravel_pytree | ||
| import jax.numpy as jnp | ||
| from jax.tree_util import tree_map | ||
|
|
||
| import numpyro | ||
| from numpyro.distributions import constraints | ||
| from numpyro.distributions.transforms import biject_to | ||
| from numpyro.distributions.util import is_identically_one, sum_rightmost | ||
| from numpyro.handlers import replay, seed, substitute, trace | ||
| from numpyro.handlers import condition, replay, seed, substitute, trace | ||
| from numpyro.infer.initialization import init_to_uniform, init_to_value | ||
| from numpyro.util import not_jax_tracer, soft_vmap, while_loop | ||
|
|
||
|
|
@@ -673,17 +674,47 @@ def _predictive( | |
| posterior_samples, | ||
| batch_shape, | ||
| return_sites=None, | ||
| infer_discrete=False, | ||
| temperature=1, | ||
fehiepsi marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| parallel=True, | ||
| model_args=(), | ||
| model_kwargs={}, | ||
| ): | ||
| model = numpyro.handlers.mask(model, mask=False) | ||
| masked_model = numpyro.handlers.mask(model, mask=False) | ||
| if infer_discrete: | ||
| # inspect the model to get some structure | ||
| rng_key, subkey = random.split(rng_key) | ||
| batch_ndim = len(batch_shape) | ||
| prototype_sample = tree_map( | ||
| lambda x: jnp.reshape(x, (-1,) + jnp.shape(x)[batch_ndim:])[0], | ||
| posterior_samples, | ||
| ) | ||
| prototype_trace = trace( | ||
| seed(substitute(masked_model, prototype_sample), subkey) | ||
| ).get_trace(*model_args, **model_kwargs) | ||
| first_available_dim = -_guess_max_plate_nesting(prototype_trace) - 1 | ||
|
|
||
| def single_prediction(val): | ||
| rng_key, samples = val | ||
| model_trace = trace(seed(substitute(model, samples), rng_key)).get_trace( | ||
| *model_args, **model_kwargs | ||
| ) | ||
| if infer_discrete: | ||
| from numpyro.contrib.funsor import config_enumerate | ||
| from numpyro.contrib.funsor.discrete import _sample_posterior | ||
|
|
||
| model_trace = prototype_trace | ||
| pred_samples = _sample_posterior( | ||
| config_enumerate(condition(model, samples)), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm surprised you're automatically configuring the model for enumeration here. In Pyro we let the user decide which variables are enumerated (though maybe this is too burdensome). That way a guide might sample some discrete latents, and the model might enumerate others. I would expect I guess if in NumPyro you automatically wrap models with
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
We'll aim for this except for algorithms that do not work with discrete latent sites (in this case, it makes sense to enable enumeration by default and raise errors for invalid models). For example, currently, DiscreteHMCGibbs will perform a Gibbs update for discrete latent sites that are not marked "enumerated". Then posterior samples will include all latent variables except for those marked "enumerated". When `infer_discrete=True`, the remaining discrete sites are sampled from the posterior in Predictive. |
||
| first_available_dim, | ||
| temperature, | ||
| rng_key, | ||
| *model_args, | ||
| **model_kwargs, | ||
| ) | ||
| else: | ||
| model_trace = trace( | ||
| seed(substitute(masked_model, samples), rng_key) | ||
| ).get_trace(*model_args, **model_kwargs) | ||
| pred_samples = {name: site["value"] for name, site in model_trace.items()} | ||
|
|
||
| if return_sites is not None: | ||
| if return_sites == "": | ||
| sites = { | ||
|
|
@@ -698,9 +729,7 @@ def single_prediction(val): | |
| if (site["type"] == "sample" and k not in samples) | ||
| or (site["type"] == "deterministic") | ||
| } | ||
| return { | ||
| name: site["value"] for name, site in model_trace.items() if name in sites | ||
| } | ||
| return {name: value for name, value in pred_samples.items() if name in sites} | ||
|
|
||
| num_samples = int(np.prod(batch_shape)) | ||
| if num_samples > 1: | ||
|
|
@@ -729,6 +758,15 @@ class Predictive(object): | |
| :param int num_samples: number of samples | ||
| :param list return_sites: sites to return; by default only sample sites not present | ||
| in `posterior_samples` are returned. | ||
| :param bool infer_discrete: whether or not to sample discrete sites from the | ||
| posterior, conditioned on observations and other latent values in | ||
| ``posterior_samples``. Under the hood, those sites will be marked with | ||
| ``site["infer"]["enumerate"] = "parallel"``. See how `infer_discrete` works at | ||
| the `Pyro enumeration tutorial <https://pyro.ai/examples/enumeration.html>`_. | ||
| This feature requires ``funsor`` installation. | ||
| :param int temperature: Either 1 (sample via forward-filter backward-sample) | ||
| or 0 (optimize via Viterbi-like MAP inference). Defaults to 1 (sample). | ||
| This argument only takes effect when ``infer_discrete=True``. | ||
| :param bool parallel: whether to predict in parallel using JAX vectorized map :func:`jax.vmap`. | ||
| Defaults to False. | ||
| :param batch_ndims: the number of batch dimensions in posterior samples. Some usages: | ||
|
|
@@ -749,10 +787,13 @@ def __init__( | |
| self, | ||
| model, | ||
| posterior_samples=None, | ||
| *, | ||
| guide=None, | ||
| params=None, | ||
| num_samples=None, | ||
| return_sites=None, | ||
| infer_discrete=False, | ||
| temperature=1, | ||
| parallel=False, | ||
| batch_ndims=1, | ||
| ): | ||
|
|
@@ -801,6 +842,8 @@ def __init__( | |
| self.num_samples = num_samples | ||
| self.guide = guide | ||
| self.params = {} if params is None else params | ||
| self.infer_discrete = infer_discrete | ||
| self.temperature = temperature | ||
| self.return_sites = return_sites | ||
| self.parallel = parallel | ||
| self.batch_ndims = batch_ndims | ||
|
|
@@ -838,6 +881,8 @@ def __call__(self, rng_key, *args, **kwargs): | |
| posterior_samples, | ||
| self._batch_shape, | ||
| return_sites=self.return_sites, | ||
| infer_discrete=self.infer_discrete, | ||
| temperature=self.temperature, | ||
| parallel=self.parallel, | ||
| model_args=args, | ||
| model_kwargs=kwargs, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I revised this a bit to use this function in Predictive (and to avoid tracing the model twice)