Skip to content
7 changes: 7 additions & 0 deletions numpyro/infer/hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,13 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
and site["fn"].has_enumerate_support
and not site["is_observed"]
}
self._support_enumerates = {
name: site["fn"].enumerate_support(False)
for name, site in self._prototype_trace.items()
if site["type"] == "sample"
and site["fn"].has_enumerate_support
and not site["is_observed"]
}
self._gibbs_sites = [
name
for name, site in self._prototype_trace.items()
Expand Down
6 changes: 6 additions & 0 deletions numpyro/infer/mixed_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from jax import grad, jacfwd, lax, random
from jax.flatten_util import ravel_pytree
import jax
import jax.numpy as jnp

from numpyro.infer.hmc import momentum_generator
Expand Down Expand Up @@ -301,6 +302,11 @@ def body_fn(i, vals):
adapt_state=adapt_state,
)

z_discrete = jax.tree.map(
lambda idx, support: support[idx],
z_discrete,
self._support_enumerates,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doing this might return in-support values but I worry that the algorithms are wrong. To compute potential energy correctly in the algorithm, we need to work with in-support values. I think you can pass support_enumerates into self._discrete_proposal_fn and change the proposal logic there.

    proposal = random.randint(rng_proposal, (), minval=0, maxval=support_size)
    # z_new_flat = z_discrete_flat.at[idx].set(proposal)
    z_new_flat = z_discrete_flat.at[idx].set(support_enumerate[proposal])

or for modified rw proposal

    i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
    # proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
    # proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
    proposal_index = jnp.where(support_size[i] == z_discrete_flat[idx], support_size - 1, i)
    proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, support_size[proposal_index])
    z_new_flat = z_discrete_flat.at[idx].set(proposal)

or at discrete gibbs proposal

    proposal_index = jnp.where(support_enumerate[i] == z_init_flat[idx], support_size - 1, i)
    z_new_flat = z_init_flat.at[idx].set(support_enumerate[proposal_index])

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, thank you for the feedback. I will try this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fehiepsi how do you debug in numpyro? I tried jax.debug. but nothing happens.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I use print most of the time. When actual values are needed, I sometimes use jax.disable_jit()

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fehiepsi I have issues with passing enumerate supports and traced values as the support arrays can have different sizes. I was thinking maybe to just pass the "lower bound of the support" as offset and combined with support_sizes it should make the trick. Are there discrete variables where the support is not a simple discrete range with step 1 between values?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for modified_rw_proposal I think you used support_size in place of support_enumerate, shouldn't it be:

    i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
    # proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
    # proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
    proposal_index = jnp.where(support_enumerate[i] == z_discrete_flat[idx], support_size - 1, i)
    proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, support_enumerate[proposal_index])
    z_new_flat = z_discrete_flat.at[idx].set(proposal)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks! your solutions are super cool! I haven't thought of different support sizes previously.

z = {**z_discrete, **hmc_state.z}
return MixedHMCState(z, hmc_state, rng_key, accept_prob)

Expand Down