diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py
index 7d7358a5d..537f21504 100644
--- a/numpyro/distributions/discrete.py
+++ b/numpyro/distributions/discrete.py
@@ -469,9 +469,9 @@ def enumerate_support(self, expand=True):
             raise NotImplementedError(
                 "Inhomogeneous `high` not supported by `enumerate_support`."
             )
-        values = (self.low + jnp.arange(np.amax(self.high - self.low) + 1)).reshape(
-            (-1,) + (1,) * len(self.batch_shape)
-        )
+        low = jnp.reshape(self.low, -1)[0]
+        high = jnp.reshape(self.high, -1)[0]
+        values = jnp.arange(low, high + 1).reshape((-1,) + (1,) * len(self.batch_shape))
         if expand:
             values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
         return values
diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py
index f6b95389b..5a9c20beb 100644
--- a/numpyro/infer/hmc_gibbs.py
+++ b/numpyro/infer/hmc_gibbs.py
@@ -7,6 +7,7 @@
 
 import numpy as np
 
+import jax
 from jax import device_put, grad, jacfwd, random, value_and_grad
 from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
@@ -192,12 +193,24 @@ def __getstate__(self):
 
 
 def _discrete_gibbs_proposal_body_fn(
-    z_init_flat, unravel_fn, pe_init, potential_fn, idx, i, val
+    z_init_flat,
+    unravel_fn,
+    pe_init,
+    potential_fn,
+    idx,
+    i,
+    val,
+    support_size,
+    support_enumerate,
 ):
     rng_key, z, pe, log_weight_sum = val
     rng_key, rng_transition = random.split(rng_key)
-    proposal = jnp.where(i >= z_init_flat[idx], i + 1, i)
-    z_new_flat = z_init_flat.at[idx].set(proposal)
+    # propose the i-th support value, skipping the current one: when it equals
+    # the current value, propose the last support value instead
+    proposal_index = jnp.where(
+        support_enumerate[i] == z_init_flat[idx], support_size - 1, i
+    )
+    z_new_flat = z_init_flat.at[idx].set(support_enumerate[proposal_index])
     z_new = unravel_fn(z_new_flat)
     pe_new = potential_fn(z_new)
     log_weight_new = pe_init - pe_new
@@ -216,7 +229,10 @@ def _discrete_gibbs_proposal_body_fn(
     return rng_key, z, pe, log_weight_sum
 
 
-def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
+def _discrete_gibbs_proposal(
+    rng_key, z_discrete, pe, potential_fn, idx, support_size, support_enumerate
+):
     # idx: current index of `z_discrete_flat` to update
     # support_size: support size of z_discrete at the index idx
+    # support_enumerate: enumerated support values of z_discrete at the index idx
     z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
@@ -234,6 +250,8 @@ def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support
         pe,
         potential_fn,
         idx,
+        support_size=support_size,
+        support_enumerate=support_enumerate,
     )
     init_val = (rng_key, z_discrete, pe, jnp.array(0.0))
     rng_key, z_new, pe_new, _ = fori_loop(0, support_size - 1, body_fn, init_val)
@@ -242,7 +260,14 @@
 
 
 def _discrete_modified_gibbs_proposal(
-    rng_key, z_discrete, pe, potential_fn, idx, support_size, stay_prob=0.0
+    rng_key,
+    z_discrete,
+    pe,
+    potential_fn,
+    idx,
+    support_size,
+    support_enumerate,
+    stay_prob=0.0,
 ):
     assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1
     z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
@@ -253,6 +278,8 @@ def _discrete_modified_gibbs_proposal(
         pe,
         potential_fn,
         idx,
+        support_size=support_size,
+        support_enumerate=support_enumerate,
     )
     # like gibbs_step but here, weight of the current value is 0
     init_val = (rng_key, z_discrete, pe, jnp.array(-jnp.inf))
@@ -276,12 +303,14 @@
     return rng_key, z_new, pe_new, log_accept_ratio
 
 
-def _discrete_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size):
+def _discrete_rw_proposal(
+    rng_key, z_discrete, pe, potential_fn, idx, support_size, support_enumerate
+):
     rng_key, rng_proposal = random.split(rng_key, 2)
     z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
 
     proposal = random.randint(rng_proposal, (), minval=0, maxval=support_size)
-    z_new_flat = z_discrete_flat.at[idx].set(proposal)
+    z_new_flat = z_discrete_flat.at[idx].set(support_enumerate[proposal])
     z_new = unravel_fn(z_new_flat)
     pe_new = potential_fn(z_new)
     log_accept_ratio = pe - pe_new
@@ -289,15 +318,26 @@
 
 
 def _discrete_modified_rw_proposal(
-    rng_key, z_discrete, pe, potential_fn, idx, support_size, stay_prob=0.0
+    rng_key,
+    z_discrete,
+    pe,
+    potential_fn,
+    idx,
+    support_size,
+    support_enumerate,
+    stay_prob=0.0,
 ):
     assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1
     rng_key, rng_proposal, rng_stay = random.split(rng_key, 3)
     z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
 
     i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
-    proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
-    proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
+    proposal_index = jnp.where(
+        support_enumerate[i] == z_discrete_flat[idx], support_size - 1, i
+    )
+    proposal = jnp.where(
+        random.bernoulli(rng_stay, stay_prob), z_discrete_flat[idx], support_enumerate[proposal_index]
+    )
     z_new_flat = z_discrete_flat.at[idx].set(proposal)
     z_new = unravel_fn(z_new_flat)
     pe_new = potential_fn(z_new)
@@ -434,6 +474,33 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
             and site["fn"].has_enumerate_support
             and not site["is_observed"]
         }
+
+        # All support enumerations must have the same length to be usable in
+        # the same loop, so each site's support is padded with zeros up to the
+        # maximum length; ravel is used for consistency with `support_sizes`.
+
+        max_length_support_enumerates = max(
+            int(np.max(size)) for size in self._support_sizes.values()
+        )
+
+        support_enumerates = {}
+        for name in self._support_sizes:
+            site = self._prototype_trace[name]
+            enumerate_support = site["fn"].enumerate_support(True).T
+            # only the last dimension, which corresponds to the support size, is padded
+            pad_width = [(0, 0) for _ in range(len(enumerate_support.shape) - 1)] + [
+                (0, max_length_support_enumerates - enumerate_support.shape[-1])
+            ]
+            padded_enumerate_support = np.pad(enumerate_support, pad_width)
+
+            support_enumerates[name] = padded_enumerate_support
+
+        # vmap over the padded support axis (the last axis of every site) so the
+        # result lines up with the flattened latent values in `support_sizes_flat`
+        self._support_enumerates = jax.vmap(
+            lambda x: ravel_pytree(x)[0], in_axes=-1, out_axes=1
+        )(support_enumerates)
+
         self._gibbs_sites = [
             name
             for name, site in self._prototype_trace.items()
diff --git a/numpyro/infer/mixed_hmc.py b/numpyro/infer/mixed_hmc.py
index 3e3d2ae59..e163fb75e 100644
--- a/numpyro/infer/mixed_hmc.py
+++ b/numpyro/infer/mixed_hmc.py
@@ -138,6 +138,7 @@ def update_discrete(
                 partial(potential_fn, z_hmc=hmc_state.z),
                 idx,
                 self._support_sizes_flat[idx],
+                self._support_enumerates[idx],
             )
             # Algo 1, line 20: depending on reject or refract, we will update
             # the discrete variable and its corresponding kinetic energy. In case of
diff --git a/test/test_distributions.py b/test/test_distributions.py
index e10fd7248..c1f2e5f5b 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -3423,3 +3423,92 @@ def test_gaussian_random_walk_linear_recursive_equivalence():
     x2 = dist2.sample(random.PRNGKey(7))
     assert jnp.allclose(x1, x2.squeeze())
     assert jnp.allclose(dist1.log_prob(x1), dist2.log_prob(x2))
+
+
+def test_discrete_uniform_with_mixedhmc():
+    import numpyro
+    import numpyro.distributions as dist
+    from numpyro.infer import HMC, MCMC, MixedHMC
+
+    def sample_mixedhmc(model_fn, num_samples, **kwargs):
+        kernel = HMC(model_fn, trajectory_length=1.2)
+        kernel = MixedHMC(kernel, num_discrete_updates=20, **kwargs)
+        mcmc = MCMC(kernel, num_warmup=100, num_samples=num_samples, progress_bar=False)
+        key = jax.random.PRNGKey(0)
+        mcmc.run(key)
+        samples = mcmc.get_samples()
+        return samples
+
+    num_samples = 1000
+    mixed_hmc_kwargs = [
+        {"random_walk": False, "modified": False},
+        {"random_walk": True, "modified": False},
+        {"random_walk": True, "modified": True},
+        {"random_walk": False, "modified": True},
+    ]
+
+    # Case 1: one discrete uniform and one categorical
+    def model_1():
+        numpyro.sample("x0", dist.DiscreteUniform(10, 12))
+        numpyro.sample("x1", dist.Categorical(np.asarray([0.25, 0.25, 0.25, 0.25])))
+
+    for kwargs in mixed_hmc_kwargs:
+        samples = sample_mixedhmc(model_1, num_samples, **kwargs)
+
+        assert jnp.all(
+            (samples["x0"] >= 10) & (samples["x0"] <= 12)
+        ), f"Failed with {kwargs=}"
+        assert jnp.all(
+            (samples["x1"] >= 0) & (samples["x1"] <= 3)
+        ), f"Failed with {kwargs=}"
+
+    # Case 2: two categoricals with different support sizes
+    def model_2():
+        numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((4,))))
+        numpyro.sample("x1", dist.Categorical(0.1 * jnp.ones((10,))))
+
+    for kwargs in mixed_hmc_kwargs:
+        samples = sample_mixedhmc(model_2, num_samples, **kwargs)
+
+        assert jnp.all(
+            (samples["x0"] >= 0) & (samples["x0"] <= 3)
+        ), f"Failed with {kwargs=}"
+        assert jnp.all(
+            (samples["x1"] >= 0) & (samples["x1"] <= 9)
+        ), f"Failed with {kwargs=}"
+
+    # Case 3: two categoricals with different support sizes, each batched by 3
+    def model_3():
+        numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((3, 4))))
+        numpyro.sample("x1", dist.Categorical(0.1 * jnp.ones((3, 10))))
+
+    for kwargs in mixed_hmc_kwargs:
+        samples = sample_mixedhmc(model_3, num_samples, **kwargs)
+
+        assert jnp.all(
+            (samples["x0"] >= 0) & (samples["x0"] <= 3)
+        ), f"Failed with {kwargs=}"
+        assert jnp.all(
+            (samples["x1"] >= 0) & (samples["x1"] <= 9)
+        ), f"Failed with {kwargs=}"
+
+    # Case 4: one categorical and one discrete uniform, each batched by 3
+    def model_4():
+        dist0 = dist.Categorical(0.25 * jnp.ones((3, 4)))
+        numpyro.sample("x0", dist0)
+        dist1 = dist.DiscreteUniform(10 * jnp.ones((3,)), 19 * jnp.ones((3,)))
+        numpyro.sample("x1", dist1)
+
+    for kwargs in mixed_hmc_kwargs:
+        samples = sample_mixedhmc(model_4, num_samples, **kwargs)
+
+        assert jnp.all(
+            (samples["x0"] >= 0) & (samples["x0"] <= 3)
+        ), f"Failed with {kwargs=}"
+        assert jnp.all(
+            (samples["x1"] >= 10) & (samples["x1"] <= 19)
+        ), f"Failed with {kwargs=}"
+
+
+if __name__ == "__main__":
+    test_discrete_uniform_with_mixedhmc()
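Note on the padding scheme in the `HMCGibbs.init` hunk above: the snippet below is a minimal, self-contained sketch of how per-site enumerated supports are padded to a common length and raveled into a single `(flat_size, max_support_length)` array. The site names and shapes here are invented for illustration; this is not the library code itself.

```python
import numpy as np
import jax
from jax.flatten_util import ravel_pytree

# Hypothetical enumerated supports, shaped batch_shape + (support_size,), as
# produced by `enumerate_support(True).T` for each discrete site.
support_enumerates = {
    "x0": np.tile(np.arange(4), (3, 1)),       # Categorical(4), batched by 3 -> (3, 4)
    "x1": np.tile(np.arange(10, 13), (3, 1)),  # DiscreteUniform(10, 12), batched by 3 -> (3, 3)
}

# Pad every support (last axis) with zeros up to the common maximum length so
# that a single loop bound can cover all sites.
max_len = max(arr.shape[-1] for arr in support_enumerates.values())
padded = {
    name: np.pad(arr, [(0, 0)] * (arr.ndim - 1) + [(0, max_len - arr.shape[-1])])
    for name, arr in support_enumerates.items()
}

# Ravel each slice along the padded support axis; rows of the result then line
# up with the flattened latent positions (cf. `support_sizes_flat`).
flat_enumerates = jax.vmap(lambda x: ravel_pytree(x)[0], in_axes=-1, out_axes=1)(padded)
print(flat_enumerates.shape)  # (6, 4): 3 + 3 flat positions, 4 padded support values
```

Padding with zeros is safe here because each flat position is only ever indexed up to its true support size (the `fori_loop` bound and the `randint` maxval), so the zero padding entries are never proposed.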