diff --git a/examples/annotation.py b/examples/annotation.py index da35e89ed..6b7ada33e 100644 --- a/examples/annotation.py +++ b/examples/annotation.py @@ -37,11 +37,12 @@ import numpy as np -from jax import nn, random +from jax import nn, random, vmap import jax.numpy as jnp import numpyro from numpyro import handlers +from numpyro.contrib.funsor import config_enumerate, infer_discrete from numpyro.contrib.indexing import Vindex import numpyro.distributions as dist from numpyro.infer import MCMC, NUTS @@ -312,6 +313,34 @@ def main(args): mcmc.run(random.PRNGKey(0), *data) mcmc.print_summary() + def infer_discrete_model(rng_key, samples): + conditioned_model = handlers.condition(model, data=samples) + infer_discrete_model = infer_discrete( + config_enumerate(conditioned_model), rng_key=rng_key + ) + with handlers.trace() as tr: + infer_discrete_model(*data) + + return { + name: site["value"] + for name, site in tr.items() + if site["type"] == "sample" and site["infer"].get("enumerate") == "parallel" + } + + posterior_samples = mcmc.get_samples() + discrete_samples = vmap(infer_discrete_model)( + random.split(random.PRNGKey(1), args.num_samples), posterior_samples + ) + + item_class = vmap(lambda x: jnp.bincount(x, length=4), in_axes=1)( + discrete_samples["c"].squeeze(-1) + ) + print("Histogram of the predicted class of each item:") + row_format = "{:>10}" * 5 + print(row_format.format("", *["c={}".format(i) for i in range(4)])) + for i, row in enumerate(item_class): + print(row_format.format(f"item[{i}]", *row)) + if __name__ == "__main__": assert numpyro.__version__.startswith("0.6.0") diff --git a/numpyro/contrib/funsor/discrete.py b/numpyro/contrib/funsor/discrete.py index 59ed4513a..72767ff84 100644 --- a/numpyro/contrib/funsor/discrete.py +++ b/numpyro/contrib/funsor/discrete.py @@ -1,16 +1,16 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import OrderedDict +from collections import OrderedDict, defaultdict import functools from jax import random +import jax.numpy as jnp import funsor -from numpyro.contrib.funsor.enum_messenger import enum, trace as packed_trace -from numpyro.contrib.funsor.infer_util import plate_to_enum_plate -from numpyro.distributions.util import is_identically_one -from numpyro.handlers import block, replay, seed, trace +from numpyro.contrib.funsor.enum_messenger import enum +from numpyro.contrib.funsor.infer_util import _enum_log_density, _get_shift, _shift_name +from numpyro.handlers import block, seed, substitute, trace from numpyro.infer.util import _guess_max_plate_nesting @@ -38,46 +38,6 @@ def _get_support_value_delta(funsor_dist, name, **kwargs): return OrderedDict(funsor_dist.terms)[name][0] -def terms_from_trace(tr): - """Helper function to extract elbo components from execution traces.""" - log_factors = {} - log_measures = {} - sum_vars, prod_vars = frozenset(), frozenset() - for site in tr.values(): - if site["type"] == "sample": - value = site["value"] - intermediates = site["intermediates"] - scale = site["scale"] - if intermediates: - log_prob = site["fn"].log_prob(value, intermediates) - else: - log_prob = site["fn"].log_prob(value) - - if (scale is not None) and (not is_identically_one(scale)): - log_prob = scale * log_prob - - dim_to_name = site["infer"]["dim_to_name"] - log_prob_factor = funsor.to_funsor( - log_prob, output=funsor.Real, dim_to_name=dim_to_name - ) - - if site["is_observed"]: - log_factors[site["name"]] = log_prob_factor - else: - log_measures[site["name"]] = log_prob_factor - sum_vars |= frozenset({site["name"]}) - prod_vars |= frozenset( - f.name for f in site["cond_indep_stack"] if f.dim is not None - ) - - return { - "log_factors": log_factors, - "log_measures": log_measures, - "measure_vars": sum_vars, - "plate_vars": prod_vars, - } - - def _sample_posterior( model, first_available_dim, temperature, rng_key, *args, **kwargs ): @@ -97,27 +57,14 @@ def _sample_posterior( model_trace = trace(seed(model, rng_key)).get_trace(*args, **kwargs) first_available_dim = -_guess_max_plate_nesting(model_trace) - 1 - with block(), enum(first_available_dim=first_available_dim): - with plate_to_enum_plate(): - model_tr = packed_trace(model).get_trace(*args, **kwargs) - - terms = terms_from_trace(model_tr) - # terms["log_factors"] = [log p(x) for each observed or latent sample site x] - # terms["log_measures"] = [log p(z) or other Dice factor - # for each latent sample site z] - - with funsor.interpretations.lazy: - log_prob = funsor.sum_product.sum_product( - sum_op, - prod_op, - list(terms["log_factors"].values()) + list(terms["log_measures"].values()), - eliminate=terms["measure_vars"] | terms["plate_vars"], - plates=terms["plate_vars"], - ) - log_prob = funsor.optimizer.apply_optimizer(log_prob) + with funsor.adjoint.AdjointTape() as tape: + with block(), enum(first_available_dim=first_available_dim): + log_prob, model_tr, log_measures = _enum_log_density( + model, args, kwargs, {}, sum_op, prod_op + ) with approx: - approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob) + approx_factors = tape.adjoint(sum_op, prod_op, log_prob) # construct a result trace to replay against the model sample_tr = model_tr.copy() @@ -138,13 +85,40 @@ def _sample_posterior( value, name_to_dim=node["infer"]["name_to_dim"] ) else: - log_measure = approx_factors[terms["log_measures"][name]] + log_measure = approx_factors[log_measures[name]] sample_subs[name] = _get_support_value(log_measure, name) node["value"] = funsor.to_data( sample_subs[name], name_to_dim=node["infer"]["name_to_dim"] ) - with replay(guide_trace=sample_tr): + data = { + name: site["value"] + for name, site in sample_tr.items() + if site["type"] == "sample" + } + + # concatenate _PREV_foo to foo + time_vars = defaultdict(list) + for name in data: + if name.startswith("_PREV_"): + root_name = _shift_name(name, -_get_shift(name)) + time_vars[root_name].append(name) + for name in time_vars: + if name in data: + time_vars[name].append(name) + time_vars[name] = sorted(time_vars[name], key=len, reverse=True) + + for root_name, vars in time_vars.items(): + prototype_shape = model_trace[root_name]["value"].shape + values = [data.pop(name) for name in vars] + if len(values) == 1: + data[root_name] = values[0].reshape(prototype_shape) + else: + assert len(prototype_shape) >= 1 + values = [v.reshape((-1,) + prototype_shape[1:]) for v in values] + data[root_name] = jnp.concatenate(values) + + with substitute(data=data): return model(*args, **kwargs) diff --git a/numpyro/contrib/funsor/infer_util.py b/numpyro/contrib/funsor/infer_util.py index 09d94b88f..5a65a2db5 100644 --- a/numpyro/contrib/funsor/infer_util.py +++ b/numpyro/contrib/funsor/infer_util.py @@ -100,6 +100,8 @@ def compute_markov_factors( sum_vars, prod_vars, history, + sum_op, + prod_op, ): """ :param dict time_to_factors: a map from time variable to the log prob factors. @@ -119,8 +121,8 @@ def compute_markov_factors( eliminate_vars = (sum_vars | prod_vars) - time_to_markov_dims[time_var] with funsor.interpretations.lazy: lazy_result = funsor.sum_product.sum_product( - funsor.ops.logaddexp, - funsor.ops.add, + sum_op, + prod_op, log_factors, eliminate=eliminate_vars, plates=prod_vars, @@ -136,7 +138,7 @@ def compute_markov_factors( ) markov_factors.append( funsor.sum_product.sarkka_bilmes_product( - funsor.ops.logaddexp, funsor.ops.add, trans, time_var, global_vars + sum_op, prod_op, trans, time_var, global_vars ) ) else: @@ -144,33 +146,14 @@ def compute_markov_factors( prev_to_curr = {k: _shift_name(k, -_get_shift(k)) for k in prev_vars} markov_factors.append( funsor.sum_product.sequential_sum_product( - funsor.ops.logaddexp, funsor.ops.add, trans, time_var, prev_to_curr + sum_op, prod_op, trans, time_var, prev_to_curr ) ) return markov_factors -def log_density(model, model_args, model_kwargs, params): - """ - Similar to :func:`numpyro.infer.util.log_density` but works for models - with discrete latent variables. Internally, this uses :mod:`funsor` - to marginalize discrete latent sites and evaluate the joint log probability. - - :param model: Python callable containing NumPyro primitives. Typically, - the model has been enumerated by using - :class:`~numpyro.contrib.funsor.enum_messenger.enum` handler:: - - def model(*args, **kwargs): - ... - - log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params) - - :param tuple model_args: args provided to the model. - :param dict model_kwargs: kwargs provided to the model. - :param dict params: dictionary of current parameter values keyed by site - name. - :return: log of joint density and a corresponding model trace - """ +def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op): + """Helper function to compute elbo and extract its components from execution traces.""" model = substitute(model, data=params) with plate_to_enum_plate(): model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs) @@ -180,6 +163,7 @@ def model(*args, **kwargs): time_to_markov_dims = defaultdict(frozenset) # dimensions at markov sites sum_vars, prod_vars = frozenset(), frozenset() history = 1 + log_measures = {} for site in model_trace.values(): if site["type"] == "sample": value = site["value"] @@ -214,7 +198,9 @@ def model(*args, **kwargs): log_factors.append(log_prob_factor) if not site["is_observed"]: + log_measures[site["name"]] = log_prob_factor sum_vars |= frozenset({site["name"]}) + prod_vars |= frozenset( f.name for f in site["cond_indep_stack"] if f.dim is not None ) @@ -236,13 +222,15 @@ def model(*args, **kwargs): sum_vars, prod_vars, history, + sum_op, + prod_op, ) log_factors = log_factors + markov_factors with funsor.interpretations.lazy: lazy_result = funsor.sum_product.sum_product( - funsor.ops.logaddexp, - funsor.ops.add, + sum_op, + prod_op, log_factors, eliminate=sum_vars | prod_vars, plates=prod_vars, @@ -255,4 +243,31 @@ def model(*args, **kwargs): result.data.shape, {k.split("__BOUND")[0] for k in result.inputs} ) ) + return result, model_trace, log_measures + + +def log_density(model, model_args, model_kwargs, params): + """ + Similar to :func:`numpyro.infer.util.log_density` but works for models + with discrete latent variables. Internally, this uses :mod:`funsor` + to marginalize discrete latent sites and evaluate the joint log probability. + + :param model: Python callable containing NumPyro primitives. Typically, + the model has been enumerated by using + :class:`~numpyro.contrib.funsor.enum_messenger.enum` handler:: + + def model(*args, **kwargs): + ... + + log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params) + + :param tuple model_args: args provided to the model. + :param dict model_kwargs: kwargs provided to the model. + :param dict params: dictionary of current parameter values keyed by site + name. + :return: log of joint density and a corresponding model trace + """ + result, model_trace, _ = _enum_log_density( + model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add + ) return result.data, model_trace diff --git a/numpyro/handlers.py b/numpyro/handlers.py index e4d736a27..53be5116d 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -196,14 +196,21 @@ class replay(Messenger): >>> assert replayed_trace['a']['value'] == exec_trace['a']['value'] """ - def __init__(self, fn=None, guide_trace=None): - assert guide_trace is not None - self.guide_trace = guide_trace + def __init__(self, fn=None, trace=None, guide_trace=None): + if guide_trace is not None: + warnings.warn( + "`guide_trace` argument is deprecated. Please replace it by `trace`.", + FutureWarning, + ) + if guide_trace is not None: + trace = guide_trace + assert trace is not None + self.trace = trace super(replay, self).__init__(fn) def process_message(self, msg): - if msg["type"] in ("sample", "plate") and msg["name"] in self.guide_trace: - msg["value"] = self.guide_trace[msg["name"]]["value"] + if msg["type"] in ("sample", "plate") and msg["name"] in self.trace: + msg["value"] = self.trace[msg["name"]]["value"] class block(Messenger): diff --git a/test/contrib/test_infer_discrete.py b/test/contrib/test_infer_discrete.py index 2de58ad77..dd364531e 100644 --- a/test/contrib/test_infer_discrete.py +++ b/test/contrib/test_infer_discrete.py @@ -12,6 +12,7 @@ import numpyro from numpyro import handlers, infer +from numpyro.contrib.control_flow import scan import numpyro.distributions as dist from numpyro.distributions.util import is_identically_one @@ -81,6 +82,50 @@ def hmm(data, hidden_dim=10): logger.info("inferred states: {}".format(list(map(int, inferred_states)))) +@pytest.mark.parametrize( + "length", + [ + 1, + 2, + pytest.param( + 10, + marks=pytest.mark.xfail( + reason="adjoint does not work with markov sum product yet." + ), + ), + ], +) +@pytest.mark.parametrize("temperature", [0, 1]) +def test_scan_hmm_smoke(length, temperature): + + # This should match the example in the infer_discrete docstring. + def hmm(data, hidden_dim=10): + transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim) + means = jnp.arange(float(hidden_dim)) + + def transition_fn(state, y): + state = numpyro.sample("states", dist.Categorical(transition[state])) + y = numpyro.sample("obs", dist.Normal(means[state], 1.0), obs=y) + return state, (state, y) + + _, (states, data) = scan(transition_fn, 0, data, length=length) + + return [0] + [s for s in states], data + + true_states, data = handlers.seed(hmm, 0)(None) + assert len(data) == length + assert len(true_states) == 1 + len(data) + + decoder = infer_discrete( + config_enumerate(hmm), temperature=temperature, rng_key=random.PRNGKey(1) + ) + inferred_states, _ = decoder(data) + assert len(inferred_states) == len(true_states) + + logger.info("true states: {}".format(list(map(int, true_states)))) + logger.info("inferred states: {}".format(list(map(int, inferred_states)))) + + def vectorize_model(model, size, dim): def fn(*args, **kwargs): with numpyro.plate("particles", size=size, dim=dim): diff --git a/test/test_handlers.py b/test/test_handlers.py index ccdfe9c66..a58926278 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -383,7 +383,7 @@ def test_subsample_replay(): with numpyro.plate("a", len(data), subsample_size=subsample_size): pass - with handlers.seed(rng_seed=1), handlers.replay(guide_trace=guide_trace): + with handlers.seed(rng_seed=1), handlers.replay(trace=guide_trace): with numpyro.plate("a", len(data)): subsample_data = numpyro.subsample(data, event_dim=0) assert subsample_data.shape == (subsample_size,)