 from jax import device_get, jacfwd, lax, random, value_and_grad
 from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
+from jax.tree_util import tree_map
 
 import numpyro
 from numpyro.distributions import constraints
 from numpyro.distributions.transforms import biject_to
 from numpyro.distributions.util import is_identically_one, sum_rightmost
-from numpyro.handlers import replay, seed, substitute, trace
+from numpyro.handlers import condition, replay, seed, substitute, trace
 from numpyro.infer.initialization import init_to_uniform, init_to_value
 from numpyro.util import not_jax_tracer, soft_vmap, while_loop
 
@@ -673,17 +674,47 @@ def _predictive(
     posterior_samples,
     batch_shape,
     return_sites=None,
+    infer_discrete=False,
     parallel=True,
     model_args=(),
     model_kwargs={},
 ):
-    model = numpyro.handlers.mask(model, mask=False)
+    masked_model = numpyro.handlers.mask(model, mask=False)
+    if infer_discrete:
+        # inspect the model to get some structure
+        rng_key, subkey = random.split(rng_key)
+        batch_ndim = len(batch_shape)
+        prototype_sample = tree_map(
+            lambda x: jnp.reshape(x, (-1,) + jnp.shape(x)[batch_ndim:])[0],
+            posterior_samples,
+        )
+        prototype_trace = trace(
+            seed(substitute(masked_model, prototype_sample), subkey)
+        ).get_trace(*model_args, **model_kwargs)
+        first_available_dim = -_guess_max_plate_nesting(prototype_trace) - 1
 
     def single_prediction(val):
         rng_key, samples = val
-        model_trace = trace(seed(substitute(model, samples), rng_key)).get_trace(
-            *model_args, **model_kwargs
-        )
+        if infer_discrete:
+            from numpyro.contrib.funsor import config_enumerate
+            from numpyro.contrib.funsor.discrete import _sample_posterior
+
+            model_trace = prototype_trace
+            temperature = 1
+            pred_samples = _sample_posterior(
+                config_enumerate(condition(model, samples)),
+                first_available_dim,
+                temperature,
+                rng_key,
+                *model_args,
+                **model_kwargs,
+            )
+        else:
+            model_trace = trace(
+                seed(substitute(masked_model, samples), rng_key)
+            ).get_trace(*model_args, **model_kwargs)
+            pred_samples = {name: site["value"] for name, site in model_trace.items()}
+
         if return_sites is not None:
             if return_sites == "":
                 sites = {
@@ -698,9 +729,7 @@ def single_prediction(val):
                 if (site["type"] == "sample" and k not in samples)
                 or (site["type"] == "deterministic")
             }
-        return {
-            name: site["value"] for name, site in model_trace.items() if name in sites
-        }
+        return {name: value for name, value in pred_samples.items() if name in sites}
 
     num_samples = int(np.prod(batch_shape))
     if num_samples > 1:
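The new branch in `single_prediction` leans on two handlers composed together: `condition` clamps every site named in the posterior draw, and `config_enumerate` marks the remaining discrete sites for parallel enumeration before `_sample_posterior` draws them. The following is a minimal sketch of that composition, not part of the diff; the toy model, its site names, and the value 0.6 are invented for illustration, and the contrib import assumes funsor is installed.

from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate  # requires funsor
from numpyro.handlers import condition, seed, trace

def toy_model():
    # one continuous latent and one discrete latent (illustrative only)
    p = numpyro.sample("p", dist.Beta(2.0, 2.0))
    numpyro.sample("z", dist.Bernoulli(p))

# clamp the continuous draw, then mark discrete sites for enumeration,
# mirroring config_enumerate(condition(model, samples)) in the diff above
wrapped = config_enumerate(condition(toy_model, {"p": 0.6}))
tr = trace(seed(wrapped, random.PRNGKey(0))).get_trace()
print(tr["p"]["is_observed"])         # True: fixed to the supplied posterior value
print(tr["z"]["infer"]["enumerate"])  # "parallel", the marker the new docstring below describes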
@@ -729,6 +758,12 @@ class Predictive(object):
     :param int num_samples: number of samples
     :param list return_sites: sites to return; by default only sample sites not present
         in `posterior_samples` are returned.
+    :param bool infer_discrete: whether or not to sample discrete sites from the
+        posterior, conditioned on observations and other latent values in
+        ``posterior_samples``. Under the hood, those sites will be marked with
+        ``site["infer"]["enumerate"] = "parallel"``. See how `infer_discrete` works at
+        the `Pyro enumeration tutorial <https://pyro.ai/examples/enumeration.html>`_.
+        Note that this requires ``funsor`` installation.
     :param bool parallel: whether to predict in parallel using JAX vectorized map :func:`jax.vmap`.
         Defaults to False.
     :param batch_ndims: the number of batch dimensions in posterior samples. Some usages:
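To make the documented option concrete, here is a hedged end-to-end sketch of how a user might call ``Predictive`` with ``infer_discrete=True``; it is not taken from the diff. The two-component mixture model, the synthetic data, and the sample counts are made up for illustration, and funsor must be installed for both the enumerated NUTS run and the discrete predictive sampling.

from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

def mixture_model(data=None):
    probs = numpyro.sample("probs", dist.Dirichlet(jnp.ones(2)))
    locs = numpyro.sample("locs", dist.Normal(0.0, 10.0).expand([2]).to_event(1))
    with numpyro.plate("N", 10):
        # discrete assignment; enumerated out during NUTS, recovered by Predictive
        z = numpyro.sample("z", dist.Categorical(probs), infer={"enumerate": "parallel"})
        numpyro.sample("obs", dist.Normal(locs[z], 1.0), obs=data)

data = jnp.concatenate(
    [random.normal(random.PRNGKey(0), (5,)) - 3.0,
     random.normal(random.PRNGKey(1), (5,)) + 3.0]
)
mcmc = MCMC(NUTS(mixture_model), num_warmup=200, num_samples=100)
mcmc.run(random.PRNGKey(2), data)

# "z" is absent from mcmc.get_samples(); infer_discrete=True asks Predictive to
# draw it from its conditional posterior given each row of continuous samples
predictive = Predictive(mixture_model, mcmc.get_samples(), infer_discrete=True)
pred = predictive(random.PRNGKey(3), data)
print(pred["z"].shape)  # expected (100, 10): one assignment per draw and data point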
@@ -749,10 +784,12 @@ def __init__(
         self,
         model,
         posterior_samples=None,
+        *,
         guide=None,
         params=None,
         num_samples=None,
         return_sites=None,
+        infer_discrete=False,
         parallel=False,
         batch_ndims=1,
     ):
@@ -801,6 +838,7 @@ def __init__(
         self.num_samples = num_samples
         self.guide = guide
         self.params = {} if params is None else params
+        self.infer_discrete = infer_discrete
         self.return_sites = return_sites
         self.parallel = parallel
         self.batch_ndims = batch_ndims
@@ -838,6 +876,7 @@ def __call__(self, rng_key, *args, **kwargs):
             posterior_samples,
             self._batch_shape,
             return_sites=self.return_sites,
+            infer_discrete=self.infer_discrete,
             parallel=self.parallel,
             model_args=args,
             model_kwargs=kwargs,