Skip to content

Commit 003424b

Browse files
authored
Support infer_discrete for Predictive (#1086)
* support infer_discrete for Predictive
* revise docs
* use infer_discrete_temperature
* use temperature=1 by default
1 parent 1b517b0 commit 003424b

File tree

3 files changed: +60 additions, -32 deletions (lines changed)

examples/annotation.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,9 @@
4242

4343
import numpyro
4444
from numpyro import handlers
45-
from numpyro.contrib.funsor import config_enumerate, infer_discrete
4645
from numpyro.contrib.indexing import Vindex
4746
import numpyro.distributions as dist
48-
from numpyro.infer import MCMC, NUTS
47+
from numpyro.infer import MCMC, NUTS, Predictive
4948
from numpyro.infer.reparam import LocScaleReparam
5049

5150

@@ -313,24 +312,9 @@ def main(args):
313312
mcmc.run(random.PRNGKey(0), *data)
314313
mcmc.print_summary()
315314

316-
def infer_discrete_model(rng_key, samples):
317-
conditioned_model = handlers.condition(model, data=samples)
318-
infer_discrete_model = infer_discrete(
319-
config_enumerate(conditioned_model), rng_key=rng_key
320-
)
321-
with handlers.trace() as tr:
322-
infer_discrete_model(*data)
323-
324-
return {
325-
name: site["value"]
326-
for name, site in tr.items()
327-
if site["type"] == "sample" and site["infer"].get("enumerate") == "parallel"
328-
}
329-
330315
posterior_samples = mcmc.get_samples()
331-
discrete_samples = vmap(infer_discrete_model)(
332-
random.split(random.PRNGKey(1), args.num_samples), posterior_samples
333-
)
316+
predictive = Predictive(model, posterior_samples, infer_discrete=True)
317+
discrete_samples = predictive(random.PRNGKey(1), *data)
334318

335319
item_class = vmap(lambda x: jnp.bincount(x, length=4), in_axes=1)(
336320
discrete_samples["c"].squeeze(-1)

numpyro/contrib/funsor/discrete.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,7 @@ def _sample_posterior(
118118
values = [v.reshape((-1,) + prototype_shape[1:]) for v in values]
119119
data[root_name] = jnp.concatenate(values)
120120

121-
with substitute(data=data):
122-
return model(*args, **kwargs)
121+
return data
123122

124123

125124
def infer_discrete(fn=None, first_available_dim=None, temperature=1, rng_key=None):
@@ -169,6 +168,12 @@ def viterbi_decoder(data, hidden_dim=10):
169168
temperature=temperature,
170169
rng_key=rng_key,
171170
)
172-
return functools.partial(
173-
_sample_posterior, fn, first_available_dim, temperature, rng_key
174-
)
171+
172+
def wrap_fn(*args, **kwargs):
173+
samples = _sample_posterior(
174+
fn, first_available_dim, temperature, rng_key, *args, **kwargs
175+
)
176+
with substitute(data=samples):
177+
return fn(*args, **kwargs)
178+
179+
return wrap_fn

numpyro/infer/util.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
from jax import device_get, jacfwd, lax, random, value_and_grad
1212
from jax.flatten_util import ravel_pytree
1313
import jax.numpy as jnp
14+
from jax.tree_util import tree_map
1415

1516
import numpyro
1617
from numpyro.distributions import constraints
1718
from numpyro.distributions.transforms import biject_to
1819
from numpyro.distributions.util import is_identically_one, sum_rightmost
19-
from numpyro.handlers import replay, seed, substitute, trace
20+
from numpyro.handlers import condition, replay, seed, substitute, trace
2021
from numpyro.infer.initialization import init_to_uniform, init_to_value
2122
from numpyro.util import not_jax_tracer, soft_vmap, while_loop
2223

@@ -673,17 +674,47 @@ def _predictive(
673674
posterior_samples,
674675
batch_shape,
675676
return_sites=None,
677+
infer_discrete=False,
676678
parallel=True,
677679
model_args=(),
678680
model_kwargs={},
679681
):
680-
model = numpyro.handlers.mask(model, mask=False)
682+
masked_model = numpyro.handlers.mask(model, mask=False)
683+
if infer_discrete:
684+
# inspect the model to get some structure
685+
rng_key, subkey = random.split(rng_key)
686+
batch_ndim = len(batch_shape)
687+
prototype_sample = tree_map(
688+
lambda x: jnp.reshape(x, (-1,) + jnp.shape(x)[batch_ndim:])[0],
689+
posterior_samples,
690+
)
691+
prototype_trace = trace(
692+
seed(substitute(masked_model, prototype_sample), subkey)
693+
).get_trace(*model_args, **model_kwargs)
694+
first_available_dim = -_guess_max_plate_nesting(prototype_trace) - 1
681695

682696
def single_prediction(val):
683697
rng_key, samples = val
684-
model_trace = trace(seed(substitute(model, samples), rng_key)).get_trace(
685-
*model_args, **model_kwargs
686-
)
698+
if infer_discrete:
699+
from numpyro.contrib.funsor import config_enumerate
700+
from numpyro.contrib.funsor.discrete import _sample_posterior
701+
702+
model_trace = prototype_trace
703+
temperature = 1
704+
pred_samples = _sample_posterior(
705+
config_enumerate(condition(model, samples)),
706+
first_available_dim,
707+
temperature,
708+
rng_key,
709+
*model_args,
710+
**model_kwargs,
711+
)
712+
else:
713+
model_trace = trace(
714+
seed(substitute(masked_model, samples), rng_key)
715+
).get_trace(*model_args, **model_kwargs)
716+
pred_samples = {name: site["value"] for name, site in model_trace.items()}
717+
687718
if return_sites is not None:
688719
if return_sites == "":
689720
sites = {
@@ -698,9 +729,7 @@ def single_prediction(val):
698729
if (site["type"] == "sample" and k not in samples)
699730
or (site["type"] == "deterministic")
700731
}
701-
return {
702-
name: site["value"] for name, site in model_trace.items() if name in sites
703-
}
732+
return {name: value for name, value in pred_samples.items() if name in sites}
704733

705734
num_samples = int(np.prod(batch_shape))
706735
if num_samples > 1:
@@ -729,6 +758,12 @@ class Predictive(object):
729758
:param int num_samples: number of samples
730759
:param list return_sites: sites to return; by default only sample sites not present
731760
in `posterior_samples` are returned.
761+
:param bool infer_discrete: whether or not to sample discrete sites from the
762+
posterior, conditioned on observations and other latent values in
763+
``posterior_samples``. Under the hood, those sites will be marked with
764+
``site["infer"]["enumerate"] = "parallel"``. See how `infer_discrete` works at
765+
the `Pyro enumeration tutorial <https://pyro.ai/examples/enumeration.html>`_.
766+
Note that this requires ``funsor`` installation.
732767
:param bool parallel: whether to predict in parallel using JAX vectorized map :func:`jax.vmap`.
733768
Defaults to False.
734769
:param batch_ndims: the number of batch dimensions in posterior samples. Some usages:
@@ -749,10 +784,12 @@ def __init__(
749784
self,
750785
model,
751786
posterior_samples=None,
787+
*,
752788
guide=None,
753789
params=None,
754790
num_samples=None,
755791
return_sites=None,
792+
infer_discrete=False,
756793
parallel=False,
757794
batch_ndims=1,
758795
):
@@ -801,6 +838,7 @@ def __init__(
801838
self.num_samples = num_samples
802839
self.guide = guide
803840
self.params = {} if params is None else params
841+
self.infer_discrete = infer_discrete
804842
self.return_sites = return_sites
805843
self.parallel = parallel
806844
self.batch_ndims = batch_ndims
@@ -838,6 +876,7 @@ def __call__(self, rng_key, *args, **kwargs):
838876
posterior_samples,
839877
self._batch_shape,
840878
return_sites=self.return_sites,
879+
infer_discrete=self.infer_discrete,
841880
parallel=self.parallel,
842881
model_args=args,
843882
model_kwargs=kwargs,

0 commit comments

Comments
 (0)