Skip to content
82 changes: 72 additions & 10 deletions numpyro/infer/hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,22 @@ def __getstate__(self):


def _discrete_gibbs_proposal_body_fn(
z_init_flat, unravel_fn, pe_init, potential_fn, idx, i, val
z_init_flat,
unravel_fn,
pe_init,
potential_fn,
idx,
i,
val,
support_size,
support_enumerate,
):
rng_key, z, pe, log_weight_sum = val
rng_key, rng_transition = random.split(rng_key)
proposal = jnp.where(i >= z_init_flat[idx], i + 1, i)
z_new_flat = z_init_flat.at[idx].set(proposal)
proposal_index = jnp.where(
support_enumerate[i] == z_init_flat[idx], support_size - 1, i
)
z_new_flat = z_init_flat.at[idx].set(support_enumerate[proposal_index])
z_new = unravel_fn(z_new_flat)
pe_new = potential_fn(z_new)
log_weight_new = pe_init - pe_new
Expand All @@ -216,7 +226,9 @@ def _discrete_gibbs_proposal_body_fn(
return rng_key, z, pe, log_weight_sum


def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
def _discrete_gibbs_proposal(
rng_key, z_discrete, pe, potential_fn, idx, support_size, support_enumerate
):
# idx: current index of `z_discrete_flat` to update
# support_size: support size of z_discrete at the index idx

Expand All @@ -234,6 +246,8 @@ def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support
pe,
potential_fn,
idx,
support_size=support_size,
support_enumerate=support_enumerate,
)
init_val = (rng_key, z_discrete, pe, jnp.array(0.0))
rng_key, z_new, pe_new, _ = fori_loop(0, support_size - 1, body_fn, init_val)
Expand All @@ -242,7 +256,14 @@ def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support


def _discrete_modified_gibbs_proposal(
rng_key, z_discrete, pe, potential_fn, idx, support_size, stay_prob=0.0
rng_key,
z_discrete,
pe,
potential_fn,
idx,
support_size,
support_enumerate,
stay_prob=0.0,
):
assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1
z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
Expand All @@ -253,6 +274,8 @@ def _discrete_modified_gibbs_proposal(
pe,
potential_fn,
idx,
support_size=support_size,
support_enumerate=support_enumerate,
)
# like gibbs_step but here, weight of the current value is 0
init_val = (rng_key, z_discrete, pe, jnp.array(-jnp.inf))
Expand All @@ -276,28 +299,41 @@ def _discrete_modified_gibbs_proposal(
return rng_key, z_new, pe_new, log_accept_ratio


def _discrete_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
def _discrete_rw_proposal(
rng_key, z_discrete, pe, potential_fn, idx, support_size, support_enumerate
):
rng_key, rng_proposal = random.split(rng_key, 2)
z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)

proposal = random.randint(rng_proposal, (), minval=0, maxval=support_size)
z_new_flat = z_discrete_flat.at[idx].set(proposal)
z_new_flat = z_discrete_flat.at[idx].set(support_enumerate[proposal])
z_new = unravel_fn(z_new_flat)
pe_new = potential_fn(z_new)
log_accept_ratio = pe - pe_new
return rng_key, z_new, pe_new, log_accept_ratio


def _discrete_modified_rw_proposal(
rng_key, z_discrete, pe, potential_fn, idx, support_size, stay_prob=0.0
rng_key,
z_discrete,
pe,
potential_fn,
idx,
support_size,
support_enumerate,
stay_prob=0.0,
):
assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1
rng_key, rng_proposal, rng_stay = random.split(rng_key, 3)
z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)

i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
proposal_index = jnp.where(
support_enumerate[i] == z_discrete_flat[idx], support_size - 1, i
)
proposal = jnp.where(
random.bernoulli(rng_stay, stay_prob), idx, support_enumerate[proposal_index]
)
z_new_flat = z_discrete_flat.at[idx].set(proposal)
z_new = unravel_fn(z_new_flat)
pe_new = potential_fn(z_new)
Expand Down Expand Up @@ -434,6 +470,32 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
and site["fn"].has_enumerate_support
and not site["is_observed"]
}
max_length_support_enumerates = max(
(
site["fn"].enumerate_support(False).shape[0]
for site in self._prototype_trace.values()
if site["type"] == "sample"
and site["fn"].has_enumerate_support
and not site["is_observed"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it would be better to loop over `support_sizes`, e.g. `for name, site in self._prototype_trace.items() if name in support_sizes`

)
)
# All support_enumerates should have the same length to be used in the loop
# Each support is padded with zeros to have the same length
self._support_enumerates = np.zeros(
(len(self._support_sizes), max_length_support_enumerates), dtype=int
)
for i, (name, site) in enumerate(self._prototype_trace.items()):
Copy link
Member

@fehiepsi fehiepsi Jul 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great solution! I just have a couple of comments:

  • It might be better to loop over the names in `support_sizes` and get each site via `site = self._prototype_trace[name]`.
  • We use `ravel_pytree` to flatten `support_sizes`, so we might want to keep the same behavior here. I don't have a great solution for this, but maybe something like:
support_enumerates = {}
for name, support_size in self._support_sizes.items():
    site = self._prototype_trace[name]
    enumerate_support = site["fn"].enumerate_support(False)
    padded_enumerate_support = np.pad(enumerate_support, (0, max_length_support_enumerates - enumerate_support.shape[0]))
    padded_enumerate_support = np.broadcast_to(padded_enumerate_support, support_size.shape + (max_length_support_enumerates,))
    support_enumerates[name] = padded_enumerate_support

self._support_enumerates = jax.vmap(lambda x: ravel_pytree(x)[0], in_axes=1, out_axes=1)(support_enumerates)

if (
site["type"] == "sample"
and site["fn"].has_enumerate_support
and not site["is_observed"]
):
self._support_enumerates[
i, : site["fn"].enumerate_support(False).shape[0]
] = site["fn"].enumerate_support(False)
self._support_enumerates = jnp.asarray(
self._support_enumerates, dtype=jnp.int32
)
self._gibbs_sites = [
name
for name, site in self._prototype_trace.items()
Expand Down
2 changes: 2 additions & 0 deletions numpyro/infer/mixed_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from jax import grad, jacfwd, lax, random
from jax.flatten_util import ravel_pytree
import jax
import jax.numpy as jnp

from numpyro.infer.hmc import momentum_generator
Expand Down Expand Up @@ -138,6 +139,7 @@ def update_discrete(
partial(potential_fn, z_hmc=hmc_state.z),
idx,
self._support_sizes_flat[idx],
self._support_enumerates[idx],
)
# Algo 1, line 20: depending on reject or refract, we will update
# the discrete variable and its corresponding kinetic energy. In case of
Expand Down