106 changes: 40 additions & 66 deletions numpyro/contrib/funsor/discrete.py
@@ -1,16 +1,16 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict
from collections import OrderedDict, defaultdict
import functools

from jax import random
import jax.numpy as jnp

import funsor
from numpyro.contrib.funsor.enum_messenger import enum, trace as packed_trace
from numpyro.contrib.funsor.infer_util import plate_to_enum_plate
from numpyro.distributions.util import is_identically_one
from numpyro.handlers import block, replay, seed, trace
from numpyro.contrib.funsor.enum_messenger import enum
from numpyro.contrib.funsor.infer_util import _enum_log_density, _get_shift, _shift_name
from numpyro.handlers import block, seed, substitute, trace
from numpyro.infer.util import _guess_max_plate_nesting


@@ -38,46 +38,6 @@ def _get_support_value_delta(funsor_dist, name, **kwargs):
return OrderedDict(funsor_dist.terms)[name][0]


def terms_from_trace(tr):
"""Helper function to extract elbo components from execution traces."""
log_factors = {}
log_measures = {}
sum_vars, prod_vars = frozenset(), frozenset()
for site in tr.values():
if site["type"] == "sample":
value = site["value"]
intermediates = site["intermediates"]
scale = site["scale"]
if intermediates:
log_prob = site["fn"].log_prob(value, intermediates)
else:
log_prob = site["fn"].log_prob(value)

if (scale is not None) and (not is_identically_one(scale)):
log_prob = scale * log_prob

dim_to_name = site["infer"]["dim_to_name"]
log_prob_factor = funsor.to_funsor(
log_prob, output=funsor.Real, dim_to_name=dim_to_name
)

if site["is_observed"]:
log_factors[site["name"]] = log_prob_factor
else:
log_measures[site["name"]] = log_prob_factor
sum_vars |= frozenset({site["name"]})
prod_vars |= frozenset(
f.name for f in site["cond_indep_stack"] if f.dim is not None
)

return {
"log_factors": log_factors,
"log_measures": log_measures,
"measure_vars": sum_vars,
"plate_vars": prod_vars,
}


def _sample_posterior(
model, first_available_dim, temperature, rng_key, *args, **kwargs
):
@@ -97,27 +57,14 @@ def _sample_posterior(
model_trace = trace(seed(model, rng_key)).get_trace(*args, **kwargs)
first_available_dim = -_guess_max_plate_nesting(model_trace) - 1
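        # e.g. a deepest plate at dim=-2 gives first_available_dim=-3, so
        # enumeration dims sit strictly to the left of all plate dims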

with block(), enum(first_available_dim=first_available_dim):
with plate_to_enum_plate():
model_tr = packed_trace(model).get_trace(*args, **kwargs)

terms = terms_from_trace(model_tr)
# terms["log_factors"] = [log p(x) for each observed or latent sample site x]
# terms["log_measures"] = [log p(z) or other Dice factor
# for each latent sample site z]

with funsor.interpretations.lazy:
log_prob = funsor.sum_product.sum_product(
sum_op,
prod_op,
list(terms["log_factors"].values()) + list(terms["log_measures"].values()),
eliminate=terms["measure_vars"] | terms["plate_vars"],
plates=terms["plate_vars"],
)
log_prob = funsor.optimizer.apply_optimizer(log_prob)
with funsor.adjoint.AdjointTape() as tape:
with block(), enum(first_available_dim=first_available_dim):
log_prob, model_tr, log_measures = _enum_log_density(
model, args, kwargs, {}, sum_op, prod_op
)

with approx:
approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)
approx_factors = tape.adjoint(sum_op, prod_op, log_prob)
Review comment (Member): this is the line where Scatter and Deltas are introduced


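A minimal sketch of the structure that adjoint step produces, assuming funsor represents each lazily sampled point as a Delta whose terms are (name, (point, log_density)) pairs — this mirrors _get_support_value_delta above; the names and values are illustrative:

from collections import OrderedDict
import jax.numpy as jnp
import funsor

funsor.set_backend("jax")
point = funsor.Tensor(jnp.array(2.0))
d = funsor.delta.Delta("z", point)  # a lazy sample, represented as a Delta
support_value = OrderedDict(d.terms)["z"][0]  # recovers `point`, as _get_support_value_delta does
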
# construct a result trace to replay against the model
sample_tr = model_tr.copy()
@@ -138,13 +85,40 @@ def _sample_posterior(
value, name_to_dim=node["infer"]["name_to_dim"]
)
else:
log_measure = approx_factors[terms["log_measures"][name]]
log_measure = approx_factors[log_measures[name]]
sample_subs[name] = _get_support_value(log_measure, name)
node["value"] = funsor.to_data(
sample_subs[name], name_to_dim=node["infer"]["name_to_dim"]
)

with replay(guide_trace=sample_tr):
data = {
name: site["value"]
for name, site in sample_tr.items()
if site["type"] == "sample"
}

# concatenate _PREV_foo to foo
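    # e.g. with history 1, a scanned site "x" appears in the trace as
    # "_PREV_x" (the initial step) and "x" (the remaining steps); both are
    # grouped under the root name "x" and concatenated along the leading
    # time dimension below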
time_vars = defaultdict(list)
for name in data:
if name.startswith("_PREV_"):
root_name = _shift_name(name, -_get_shift(name))
time_vars[root_name].append(name)
for name in time_vars:
if name in data:
time_vars[name].append(name)
time_vars[name] = sorted(time_vars[name], key=len, reverse=True)

for root_name, vars in time_vars.items():
prototype_shape = model_trace[root_name]["value"].shape
values = [data.pop(name) for name in vars]
if len(values) == 1:
data[root_name] = values[0].reshape(prototype_shape)
else:
assert len(prototype_shape) >= 1
values = [v.reshape((-1,) + prototype_shape[1:]) for v in values]
data[root_name] = jnp.concatenate(values)

with substitute(data=data):
return model(*args, **kwargs)


69 changes: 42 additions & 27 deletions numpyro/contrib/funsor/infer_util.py
@@ -100,6 +100,8 @@ def compute_markov_factors(
sum_vars,
prod_vars,
history,
sum_op,
prod_op,
):
"""
:param dict time_to_factors: a map from time variable to the log prob factors.
@@ -119,8 +121,8 @@
eliminate_vars = (sum_vars | prod_vars) - time_to_markov_dims[time_var]
with funsor.interpretations.lazy:
lazy_result = funsor.sum_product.sum_product(
funsor.ops.logaddexp,
funsor.ops.add,
sum_op,
prod_op,
log_factors,
eliminate=eliminate_vars,
plates=prod_vars,
@@ -136,41 +138,22 @@
)
markov_factors.append(
funsor.sum_product.sarkka_bilmes_product(
funsor.ops.logaddexp, funsor.ops.add, trans, time_var, global_vars
sum_op, prod_op, trans, time_var, global_vars
)
)
else:
# remove `_PREV_` prefix to convert prev to curr
prev_to_curr = {k: _shift_name(k, -_get_shift(k)) for k in prev_vars}
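            # e.g. prev_to_curr == {"_PREV_x": "x"} for a history-1 chain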
markov_factors.append(
funsor.sum_product.sequential_sum_product(
funsor.ops.logaddexp, funsor.ops.add, trans, time_var, prev_to_curr
sum_op, prod_op, trans, time_var, prev_to_curr
)
)
return markov_factors
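
A hedged sketch of the history-1 path, assuming funsor.sum_product.sequential_sum_product contracts the time dimension of a Markov kernel via a parallel scan — the shapes, names, and uniform kernel here are illustrative:

from collections import OrderedDict
import jax.numpy as jnp
import funsor

funsor.set_backend("jax")
T, S = 5, 3
# uniform log-transition kernel over a length-T chain with S states
trans = funsor.Tensor(
    jnp.log(jnp.ones((T, S, S)) / S),
    OrderedDict(time=funsor.Bint[T], _PREV_x=funsor.Bint[S], x=funsor.Bint[S]),
)
chain = funsor.sum_product.sequential_sum_product(
    funsor.ops.logaddexp,
    funsor.ops.add,
    trans,
    funsor.Variable("time", funsor.Bint[T]),
    {"_PREV_x": "x"},
)
# `chain` now maps (initial state `_PREV_x`, final state `x`) to the
# log-probability of the whole chain; the time dimension is eliminated
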


def log_density(model, model_args, model_kwargs, params):
"""
Similar to :func:`numpyro.infer.util.log_density` but works for models
with discrete latent variables. Internally, this uses :mod:`funsor`
to marginalize discrete latent sites and evaluate the joint log probability.

:param model: Python callable containing NumPyro primitives. Typically,
the model has been enumerated by using
:class:`~numpyro.contrib.funsor.enum_messenger.enum` handler::

def model(*args, **kwargs):
...

log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params)

:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: dictionary of current parameter values keyed by site
name.
:return: log of joint density and a corresponding model trace
"""
def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
"""Helper function to compute elbo and extract its components from execution traces."""
model = substitute(model, data=params)
with plate_to_enum_plate():
model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
@@ -180,6 +163,7 @@ def model(*args, **kwargs):
time_to_markov_dims = defaultdict(frozenset) # dimensions at markov sites
sum_vars, prod_vars = frozenset(), frozenset()
history = 1
log_measures = {}
for site in model_trace.values():
if site["type"] == "sample":
value = site["value"]
@@ -214,7 +198,9 @@ def model(*args, **kwargs):
log_factors.append(log_prob_factor)

if not site["is_observed"]:
log_measures[site["name"]] = log_prob_factor
sum_vars |= frozenset({site["name"]})

prod_vars |= frozenset(
f.name for f in site["cond_indep_stack"] if f.dim is not None
)
@@ -236,13 +222,15 @@
sum_vars,
prod_vars,
history,
sum_op,
prod_op,
)
log_factors = log_factors + markov_factors

with funsor.interpretations.lazy:
lazy_result = funsor.sum_product.sum_product(
funsor.ops.logaddexp,
funsor.ops.add,
sum_op,
prod_op,
log_factors,
eliminate=sum_vars | prod_vars,
plates=prod_vars,
@@ -255,4 +243,31 @@
result.data.shape, {k.split("__BOUND")[0] for k in result.inputs}
)
)
return result, model_trace, log_measures


def log_density(model, model_args, model_kwargs, params):
"""
Similar to :func:`numpyro.infer.util.log_density` but works for models
with discrete latent variables. Internally, this uses :mod:`funsor`
to marginalize discrete latent sites and evaluate the joint log probability.

:param model: Python callable containing NumPyro primitives. Typically,
the model has been enumerated by using
:class:`~numpyro.contrib.funsor.enum_messenger.enum` handler::

def model(*args, **kwargs):
...

log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params)

:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: dictionary of current parameter values keyed by site
name.
:return: log of joint density and a corresponding model trace
"""
result, model_trace, _ = _enum_log_density(
model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
)
return result.data, model_trace
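
A minimal usage sketch for log_density on a hypothetical enumerated model — the toy model, seeding, and first_available_dim choice are illustrative assumptions, not part of this PR:

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.contrib.funsor import config_enumerate, enum
from numpyro.contrib.funsor.infer_util import log_density

def toy_model():
    x = numpyro.sample("x", dist.Bernoulli(0.7))  # discrete latent, enumerated
    numpyro.sample("y", dist.Normal(x * 1.0, 1.0), obs=jnp.array(0.5))

enum_model = enum(config_enumerate(toy_model), first_available_dim=-1)
seeded = handlers.seed(enum_model, random.PRNGKey(0))
log_joint, model_trace = log_density(seeded, (), {}, {})
# log_joint == logsumexp over x of [log p(x) + log p(y | x)], a scalar array
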
33 changes: 33 additions & 0 deletions test/contrib/test_infer_discrete.py
@@ -12,6 +12,7 @@

import numpyro
from numpyro import handlers, infer
from numpyro.contrib.control_flow import scan
import numpyro.distributions as dist
from numpyro.distributions.util import is_identically_one

@@ -81,6 +82,38 @@ def hmm(data, hidden_dim=10):
logger.info("inferred states: {}".format(list(map(int, inferred_states))))


@pytest.mark.parametrize("length", [1, 2, 10])
@pytest.mark.parametrize("temperature", [0, 1])
def test_scan_hmm_smoke(length, temperature):

# This should match the example in the infer_discrete docstring.
def hmm(data, hidden_dim=10):
transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim)
means = jnp.arange(float(hidden_dim))

def transition_fn(state, y):
state = numpyro.sample("states", dist.Categorical(transition[state]))
y = numpyro.sample("obs", dist.Normal(means[state], 1.0), obs=y)
return state, (state, y)

_, (states, data) = scan(transition_fn, 0, data, length=length)

return [0] + [s for s in states], data

true_states, data = handlers.seed(hmm, 0)(None)
assert len(data) == length
assert len(true_states) == 1 + len(data)

decoder = infer_discrete(
config_enumerate(hmm), temperature=temperature, rng_key=random.PRNGKey(1)
)
inferred_states, _ = decoder(data)
assert len(inferred_states) == len(true_states)

logger.info("true states: {}".format(list(map(int, true_states))))
logger.info("inferred states: {}".format(list(map(int, inferred_states))))


def vectorize_model(model, size, dim):
def fn(*args, **kwargs):
with numpyro.plate("particles", size=size, dim=dim):