Skip to content
90 changes: 80 additions & 10 deletions numpyro/infer/hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import copy
from functools import partial

import jax
import numpy as np

from jax import device_put, grad, jacfwd, random, value_and_grad
Expand Down Expand Up @@ -192,12 +193,22 @@ def __getstate__(self):


def _discrete_gibbs_proposal_body_fn(
z_init_flat, unravel_fn, pe_init, potential_fn, idx, i, val
z_init_flat,
unravel_fn,
pe_init,
potential_fn,
idx,
i,
val,
support_size,
support_enumerate,
):
rng_key, z, pe, log_weight_sum = val
rng_key, rng_transition = random.split(rng_key)
proposal = jnp.where(i >= z_init_flat[idx], i + 1, i)
z_new_flat = z_init_flat.at[idx].set(proposal)
proposal_index = jnp.where(
support_enumerate[i] == z_init_flat[idx], support_size - 1, i
)
z_new_flat = z_init_flat.at[idx].set(support_enumerate[proposal_index])
z_new = unravel_fn(z_new_flat)
pe_new = potential_fn(z_new)
log_weight_new = pe_init - pe_new
Expand All @@ -216,7 +227,9 @@ def _discrete_gibbs_proposal_body_fn(
return rng_key, z, pe, log_weight_sum


def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
def _discrete_gibbs_proposal(
rng_key, z_discrete, pe, potential_fn, idx, support_size, support_enumerate
):
# idx: current index of `z_discrete_flat` to update
# support_size: support size of z_discrete at the index idx

Expand All @@ -234,6 +247,8 @@ def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support
pe,
potential_fn,
idx,
support_size=support_size,
support_enumerate=support_enumerate,
)
init_val = (rng_key, z_discrete, pe, jnp.array(0.0))
rng_key, z_new, pe_new, _ = fori_loop(0, support_size - 1, body_fn, init_val)
Expand All @@ -242,7 +257,14 @@ def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support


def _discrete_modified_gibbs_proposal(
rng_key, z_discrete, pe, potential_fn, idx, support_size, stay_prob=0.0
rng_key,
z_discrete,
pe,
potential_fn,
idx,
support_size,
support_enumerate,
stay_prob=0.0,
):
assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1
z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
Expand All @@ -253,6 +275,8 @@ def _discrete_modified_gibbs_proposal(
pe,
potential_fn,
idx,
support_size=support_size,
support_enumerate=support_enumerate,
)
# like gibbs_step but here, weight of the current value is 0
init_val = (rng_key, z_discrete, pe, jnp.array(-jnp.inf))
Expand All @@ -276,28 +300,41 @@ def _discrete_modified_gibbs_proposal(
return rng_key, z_new, pe_new, log_accept_ratio


def _discrete_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
def _discrete_rw_proposal(
rng_key, z_discrete, pe, potential_fn, idx, support_size, support_enumerate
):
rng_key, rng_proposal = random.split(rng_key, 2)
z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)

proposal = random.randint(rng_proposal, (), minval=0, maxval=support_size)
z_new_flat = z_discrete_flat.at[idx].set(proposal)
z_new_flat = z_discrete_flat.at[idx].set(support_enumerate[proposal])
z_new = unravel_fn(z_new_flat)
pe_new = potential_fn(z_new)
log_accept_ratio = pe - pe_new
return rng_key, z_new, pe_new, log_accept_ratio


def _discrete_modified_rw_proposal(
rng_key, z_discrete, pe, potential_fn, idx, support_size, stay_prob=0.0
rng_key,
z_discrete,
pe,
potential_fn,
idx,
support_size,
support_enumerate,
stay_prob=0.0,
):
assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1
rng_key, rng_proposal, rng_stay = random.split(rng_key, 3)
z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)

i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
proposal_index = jnp.where(
support_enumerate[i] == z_discrete_flat[idx], support_size - 1, i
)
proposal = jnp.where(
random.bernoulli(rng_stay, stay_prob), idx, support_enumerate[proposal_index]
)
z_new_flat = z_discrete_flat.at[idx].set(proposal)
z_new = unravel_fn(z_new_flat)
pe_new = potential_fn(z_new)
Expand Down Expand Up @@ -434,6 +471,39 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
and site["fn"].has_enumerate_support
and not site["is_observed"]
}

# All support_enumerates should have the same length to be used in the loop
# Each support is padded with zeros to have the same length
# ravel is used to maintain a consistant behaviour with `support_sizes`

max_length_support_enumerates = max(
(
site["fn"].enumerate_support(False).shape[0]
for site in self._prototype_trace.values()
if site["type"] == "sample"
and site["fn"].has_enumerate_support
and not site["is_observed"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it is better to loop over support_sizes: for name, site in self._prototype_trace.items() if name in support_sizes

)
)

support_enumerates = {}
for name, support_size in self._support_sizes.items():
site = self._prototype_trace[name]
enumerate_support = site["fn"].enumerate_support(False)
padded_enumerate_support = np.pad(
enumerate_support,
(0, max_length_support_enumerates - enumerate_support.shape[0]),
)
padded_enumerate_support = np.broadcast_to(
padded_enumerate_support,
support_size.shape + (max_length_support_enumerates,),
)
support_enumerates[name] = padded_enumerate_support

self._support_enumerates = jax.vmap(
lambda x: ravel_pytree(x)[0] , in_axes=0, out_axes=1
)(support_enumerates)

self._gibbs_sites = [
name
for name, site in self._prototype_trace.items()
Expand Down
2 changes: 2 additions & 0 deletions numpyro/infer/mixed_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from jax import grad, jacfwd, lax, random
from jax.flatten_util import ravel_pytree
import jax
import jax.numpy as jnp

from numpyro.infer.hmc import momentum_generator
Expand Down Expand Up @@ -138,6 +139,7 @@ def update_discrete(
partial(potential_fn, z_hmc=hmc_state.z),
idx,
self._support_sizes_flat[idx],
self._support_enumerates[idx],
)
# Algo 1, line 20: depending on reject or refract, we will update
# the discrete variable and its corresponding kinetic energy. In case of
Expand Down