examples/annotation.py — 22 changes: 3 additions & 19 deletions

@@ -42,10 +42,9 @@

import numpyro
from numpyro import handlers
-from numpyro.contrib.funsor import config_enumerate, infer_discrete
from numpyro.contrib.indexing import Vindex
import numpyro.distributions as dist
-from numpyro.infer import MCMC, NUTS
+from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.reparam import LocScaleReparam


@@ -313,24 +312,9 @@ def main(args):
    mcmc.run(random.PRNGKey(0), *data)
    mcmc.print_summary()

-    def infer_discrete_model(rng_key, samples):
-        conditioned_model = handlers.condition(model, data=samples)
-        infer_discrete_model = infer_discrete(
-            config_enumerate(conditioned_model), rng_key=rng_key
-        )
-        with handlers.trace() as tr:
-            infer_discrete_model(*data)
-
-        return {
-            name: site["value"]
-            for name, site in tr.items()
-            if site["type"] == "sample" and site["infer"].get("enumerate") == "parallel"
-        }
-
    posterior_samples = mcmc.get_samples()
-    discrete_samples = vmap(infer_discrete_model)(
-        random.split(random.PRNGKey(1), args.num_samples), posterior_samples
-    )
+    predictive = Predictive(model, posterior_samples, infer_discrete=True)
+    discrete_samples = predictive(random.PRNGKey(1), *data)

    item_class = vmap(lambda x: jnp.bincount(x, length=4), in_axes=1)(
        discrete_samples["c"].squeeze(-1)
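Taken together, the example now reduces to an MCMC run followed by one Predictive call. A minimal sketch of that flow, assuming the `model` and `data` variables defined earlier in examples/annotation.py (the MCMC settings here are illustrative, not from the diff):

from jax import random

from numpyro.infer import MCMC, NUTS, Predictive

# `model` and `data` are the ones defined in examples/annotation.py.
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), *data)
posterior_samples = mcmc.get_samples()

# Replaces the hand-rolled vmap(infer_discrete_model) loop: discrete sites
# missing from `posterior_samples` are enumerated and sampled internally.
predictive = Predictive(model, posterior_samples, infer_discrete=True)
discrete_samples = predictive(random.PRNGKey(1), *data)

The bincount/squeeze post-processing later in the script then consumes discrete_samples["c"] exactly as before.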
numpyro/contrib/funsor/discrete.py — 15 changes: 10 additions & 5 deletions

@@ -118,8 +118,7 @@ def _sample_posterior(
        values = [v.reshape((-1,) + prototype_shape[1:]) for v in values]
        data[root_name] = jnp.concatenate(values)

-    with substitute(data=data):
-        return model(*args, **kwargs)
+    return data
@fehiepsi (Member, Author) commented:

    I revised this a bit so the function can be reused in Predictive (and to avoid tracing the model twice).


def infer_discrete(fn=None, first_available_dim=None, temperature=1, rng_key=None):
@@ -169,6 +168,12 @@ def viterbi_decoder(data, hidden_dim=10):
            temperature=temperature,
            rng_key=rng_key,
        )
-    return functools.partial(
-        _sample_posterior, fn, first_available_dim, temperature, rng_key
-    )
+
+    def wrap_fn(*args, **kwargs):
+        samples = _sample_posterior(
+            fn, first_available_dim, temperature, rng_key, *args, **kwargs
+        )
+        with substitute(data=samples):
+            return fn(*args, **kwargs)
+
+    return wrap_fn
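For reference, a minimal sketch of how the returned wrap_fn behaves from the caller's side: the wrapped model replays under substitute with the drawn discrete values, so those values can be read back with a trace handler. The toy model, site names, and first_available_dim=-1 below are assumptions for illustration (and funsor must be installed):

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.contrib.funsor import config_enumerate, infer_discrete

def toy_model(probs, obs):
    # Discrete latent site; enumerated during inference, then sampled
    # (temperature=1) or MAP-decoded (temperature=0) by infer_discrete.
    z = numpyro.sample("z", dist.Categorical(probs))
    numpyro.sample("x", dist.Normal(z.astype(jnp.float32), 1.0), obs=obs)

decoded_model = infer_discrete(
    config_enumerate(toy_model),
    first_available_dim=-1,
    temperature=0,
    rng_key=random.PRNGKey(0),
)

# wrap_fn substitutes the drawn value of "z" and replays the model,
# so a trace of the call exposes it.
with handlers.trace() as tr:
    decoded_model(jnp.array([0.3, 0.7]), jnp.array(1.5))
print(tr["z"]["value"])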
numpyro/infer/util.py — 55 changes: 47 additions & 8 deletions

@@ -11,12 +11,13 @@
from jax import device_get, jacfwd, lax, random, value_and_grad
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
+from jax.tree_util import tree_map

import numpyro
from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to
from numpyro.distributions.util import is_identically_one, sum_rightmost
-from numpyro.handlers import replay, seed, substitute, trace
+from numpyro.handlers import condition, replay, seed, substitute, trace
from numpyro.infer.initialization import init_to_uniform, init_to_value
from numpyro.util import not_jax_tracer, soft_vmap, while_loop

@@ -673,17 +674,47 @@ def _predictive(
    posterior_samples,
    batch_shape,
    return_sites=None,
+    infer_discrete=False,
    parallel=True,
    model_args=(),
    model_kwargs={},
):
-    model = numpyro.handlers.mask(model, mask=False)
+    masked_model = numpyro.handlers.mask(model, mask=False)
+    if infer_discrete:
+        # inspect the model to get some structure
+        rng_key, subkey = random.split(rng_key)
+        batch_ndim = len(batch_shape)
+        prototype_sample = tree_map(
+            lambda x: jnp.reshape(x, (-1,) + jnp.shape(x)[batch_ndim:])[0],
+            posterior_samples,
+        )
+        prototype_trace = trace(
+            seed(substitute(masked_model, prototype_sample), subkey)
+        ).get_trace(*model_args, **model_kwargs)
+        first_available_dim = -_guess_max_plate_nesting(prototype_trace) - 1

    def single_prediction(val):
        rng_key, samples = val
-        model_trace = trace(seed(substitute(model, samples), rng_key)).get_trace(
-            *model_args, **model_kwargs
-        )
+        if infer_discrete:
+            from numpyro.contrib.funsor import config_enumerate
+            from numpyro.contrib.funsor.discrete import _sample_posterior
+
+            model_trace = prototype_trace
+            temperature = 1
+            pred_samples = _sample_posterior(
+                config_enumerate(condition(model, samples)),
@fritzo (Member) commented on Jul 9, 2021:

    I'm surprised you're automatically configuring the model for enumeration here. In Pyro we let the user decide which variables are enumerated (though maybe this is too burdensome). That way a guide might sample some discrete latents, and the model might enumerate others. I would expect Predictive to use the guide's samples if provided, and only enumerate sites already marked for enumeration by the user.

    I guess if in NumPyro you automatically wrap models with @config_enumerate inside MCMC, then it would also make sense to automatically wrap them here. Still, it would be nice to support SVI with guides that sample some or all discrete latent variables.

@fehiepsi (Member, Author) replied on Jul 10, 2021:

    > let the user decide which variables are enumerated

    We'll aim for this, except for algorithms that do not work with discrete latent sites (there it makes sense to enable enumeration by default and raise errors for invalid models). For example, DiscreteHMCGibbs currently performs a Gibbs update for discrete latent sites that are not marked "enumerated"; posterior samples will then include all latent variables except those marked "enumerated".

    When infer_discrete=True, I assumed that the latent sites not available in posterior_samples or the guide are enumerated (those sites belong to the samples variable in the code above). So it makes sense to me to config_enumerate them by default (config_enumerate will skip observed sites, including the sites in samples). What do you think? (This does not seem to contradict the usage in your comment.)
+                first_available_dim,
+                temperature,
+                rng_key,
+                *model_args,
+                **model_kwargs,
+            )
+        else:
+            model_trace = trace(
+                seed(substitute(masked_model, samples), rng_key)
+            ).get_trace(*model_args, **model_kwargs)
+            pred_samples = {name: site["value"] for name, site in model_trace.items()}

        if return_sites is not None:
            if return_sites == "":
                sites = {
@@ -698,9 +729,7 @@ def single_prediction(val):
                if (site["type"] == "sample" and k not in samples)
                or (site["type"] == "deterministic")
            }
-        return {
-            name: site["value"] for name, site in model_trace.items() if name in sites
-        }
+        return {name: value for name, value in pred_samples.items() if name in sites}

    num_samples = int(np.prod(batch_shape))
    if num_samples > 1:
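To make the division of labor discussed in the review thread above concrete: sites supplied in posterior_samples (or by a guide) are conditioned on, and only the remaining discrete sites are enumerated by the automatic config_enumerate. A hedged sketch with a toy model (the model, site names, and shapes are illustrative assumptions, not part of the diff):

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import Predictive

def toy_model():
    # Continuous site: provided via `posterior_samples`, so it is conditioned.
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    # Discrete site: absent from `posterior_samples`, so it is the one marked
    # with infer={"enumerate": "parallel"} and sampled by the funsor machinery.
    z = numpyro.sample("z", dist.Bernoulli(0.5))
    numpyro.sample("x", dist.Normal(loc + z, 1.0), obs=jnp.array(0.3))

posterior_samples = {"loc": jnp.zeros(100)}  # e.g. MCMC draws
predictive = Predictive(toy_model, posterior_samples, infer_discrete=True)
samples = predictive(random.PRNGKey(0))
print(samples["z"].shape)  # one discrete draw per posterior sample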
@@ -729,6 +758,12 @@ class Predictive(object):
    :param int num_samples: number of samples
    :param list return_sites: sites to return; by default only sample sites not present
        in `posterior_samples` are returned.
+    :param bool infer_discrete: whether or not to sample discrete sites from the
+        posterior, conditioned on observations and other latent values in
+        ``posterior_samples``. Under the hood, those sites will be marked with
+        ``site["infer"]["enumerate"] = "parallel"``. See how `infer_discrete` works at
+        the `Pyro enumeration tutorial <https://pyro.ai/examples/enumeration.html>`_.
+        Note that this requires ``funsor`` to be installed.
    :param bool parallel: whether to predict in parallel using JAX vectorized map :func:`jax.vmap`.
        Defaults to False.
    :param batch_ndims: the number of batch dimensions in posterior samples. Some usages:
@@ -749,10 +784,12 @@ def __init__(
        self,
        model,
        posterior_samples=None,
+        *,
        guide=None,
        params=None,
        num_samples=None,
        return_sites=None,
+        infer_discrete=False,
        parallel=False,
        batch_ndims=1,
    ):
@@ -801,6 +838,7 @@ def __init__(
        self.num_samples = num_samples
        self.guide = guide
        self.params = {} if params is None else params
+        self.infer_discrete = infer_discrete
        self.return_sites = return_sites
        self.parallel = parallel
        self.batch_ndims = batch_ndims
@@ -838,6 +876,7 @@ def __call__(self, rng_key, *args, **kwargs):
            posterior_samples,
            self._batch_shape,
            return_sites=self.return_sites,
+            infer_discrete=self.infer_discrete,
            parallel=self.parallel,
            model_args=args,
            model_kwargs=kwargs,
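One small API note from this diff: the `*,` inserted into `Predictive.__init__` makes every argument after `posterior_samples` keyword-only. A hedged sketch of the resulting calling convention (the toy model and sample dict are illustrative assumptions):

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import Predictive

def toy_model():
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("x", dist.Normal(loc, 1.0))

samples = {"loc": jnp.zeros(10)}

# `model` and `posterior_samples` stay positional; everything else (guide,
# params, num_samples, return_sites, infer_discrete, parallel, batch_ndims)
# must now be passed by keyword.
predictive = Predictive(toy_model, samples, return_sites=["x"], parallel=True)
out = predictive(random.PRNGKey(0))
print(out["x"].shape)  # (10,)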