Skip to content

Commit 17147bb

Browse files
committed
hmc_gibbs updated to work with different support sizes and batching, tests are passing when using changes from PR pyro-ppl#1859
1 parent fe46ba1 commit 17147bb

File tree

2 files changed

+73
-24
lines changed

2 files changed

+73
-24
lines changed

numpyro/infer/hmc_gibbs.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -476,26 +476,24 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
476476
# Each support is padded with zeros to have the same length
477477
# ravel is used to maintain a consistant behaviour with `support_sizes`
478478

479-
max_length_support_enumerates = max(
480-
size for size in self._support_sizes.values()
479+
max_length_support_enumerates = np.max(
480+
[size for size in self._support_sizes.values()]
481481
)
482482

483483
support_enumerates = {}
484484
for name, support_size in self._support_sizes.items():
485485
site = self._prototype_trace[name]
486-
enumerate_support = site["fn"].enumerate_support(False)
487-
padded_enumerate_support = np.pad(
488-
enumerate_support,
489-
(0, max_length_support_enumerates - enumerate_support.shape[0]),
490-
)
491-
padded_enumerate_support = np.broadcast_to(
492-
padded_enumerate_support,
493-
support_size.shape + (max_length_support_enumerates,),
494-
)
486+
enumerate_support = site["fn"].enumerate_support(True).T
487+
# Only the last dimension that corresponds to support size is padded
488+
pad_width = [(0, 0) for _ in range(len(enumerate_support.shape) - 1)] + [
489+
(0, max_length_support_enumerates - enumerate_support.shape[-1])
490+
]
491+
padded_enumerate_support = np.pad(enumerate_support, pad_width)
492+
495493
support_enumerates[name] = padded_enumerate_support
496494

497495
self._support_enumerates = jax.vmap(
498-
lambda x: ravel_pytree(x)[0], in_axes=0, out_axes=1
496+
lambda x: ravel_pytree(x)[0], in_axes=len(support_size.shape), out_axes=1
499497
)(support_enumerates)
500498

501499
self._gibbs_sites = [

test/test_distributions.py

Lines changed: 63 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3430,26 +3430,30 @@ def test_discrete_uniform_with_mixedhmc():
34303430
import numpyro.distributions as dist
34313431
from numpyro.infer import HMC, MCMC, MixedHMC
34323432

3433-
def model_1():
3434-
numpyro.sample("x0", dist.DiscreteUniform(10, 12))
3435-
numpyro.sample("x1", dist.Categorical(np.asarray([0.25, 0.25, 0.25, 0.25])))
3436-
3433+
def sample_mixedhmc(model_fn, num_samples, **kwargs):
3434+
kernel = HMC(model_fn, trajectory_length=1.2)
3435+
kernel = MixedHMC(kernel, num_discrete_updates=20, **kwargs)
3436+
mcmc = MCMC(kernel, num_warmup=100, num_samples=num_samples, progress_bar=False)
3437+
key = jax.random.PRNGKey(0)
3438+
mcmc.run(key)
3439+
samples = mcmc.get_samples()
3440+
return samples
3441+
3442+
num_samples = 1000
34373443
mixed_hmc_kwargs = [
34383444
{"random_walk": False, "modified": False},
34393445
{"random_walk": True, "modified": False},
34403446
{"random_walk": True, "modified": True},
34413447
{"random_walk": False, "modified": True},
34423448
]
34433449

3444-
num_samples = 1000
3445-
3450+
# Case 1: one discrete uniform with one categorical
3451+
def model_1():
3452+
numpyro.sample("x0", dist.DiscreteUniform(10, 12))
3453+
numpyro.sample("x1", dist.Categorical(np.asarray([0.25, 0.25, 0.25, 0.25])))
3454+
34463455
for kwargs in mixed_hmc_kwargs:
3447-
kernel = HMC(model_1, trajectory_length=1.2)
3448-
kernel = MixedHMC(kernel, num_discrete_updates=20, **kwargs)
3449-
mcmc = MCMC(kernel, num_warmup=100, num_samples=num_samples, progress_bar=False)
3450-
key = jax.random.PRNGKey(0)
3451-
mcmc.run(key)
3452-
samples = mcmc.get_samples()
3456+
samples = sample_mixedhmc(model_1, num_samples, **kwargs)
34533457

34543458
assert jnp.all(
34553459
(samples["x0"] >= 10) & (samples["x0"] <= 12)
@@ -3458,6 +3462,53 @@ def model_1():
34583462
(samples["x1"] >= 0) & (samples["x1"] <= 3)
34593463
), f"Failed with {kwargs=}"
34603464

3465+
def model_2():
3466+
numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((4,))))
3467+
numpyro.sample("x1", dist.Categorical(0.1 * jnp.ones((10,))))
3468+
3469+
# Case 2: 2 categorical with different support lengths
3470+
for kwargs in mixed_hmc_kwargs:
3471+
samples = sample_mixedhmc(model_2, num_samples, **kwargs)
3472+
3473+
assert jnp.all(
3474+
(samples["x0"] >= 0) & (samples["x0"] <= 3)
3475+
), f"Failed with {kwargs=}"
3476+
assert jnp.all(
3477+
(samples["x1"] >= 0) & (samples["x1"] <= 9)
3478+
), f"Failed with {kwargs=}"
3479+
3480+
def model_3():
3481+
numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((3, 4))))
3482+
numpyro.sample("x1", dist.Categorical(0.1 * jnp.ones((3, 10))))
3483+
3484+
# Case 3: 2 categorical with different support lengths and batched by 3
3485+
for kwargs in mixed_hmc_kwargs:
3486+
samples = sample_mixedhmc(model_3, num_samples, **kwargs)
3487+
3488+
assert jnp.all(
3489+
(samples["x0"] >= 0) & (samples["x0"] <= 3)
3490+
), f"Failed with {kwargs=}"
3491+
assert jnp.all(
3492+
(samples["x1"] >= 0) & (samples["x1"] <= 9)
3493+
), f"Failed with {kwargs=}"
3494+
3495+
def model_4():
3496+
dist0 = dist.Categorical(0.25 * jnp.ones((3, 4)))
3497+
numpyro.sample("x0", dist0)
3498+
dist1 = dist.DiscreteUniform(10 * jnp.ones((3,)), 19 * jnp.ones((3,)))
3499+
numpyro.sample("x1", dist1)
3500+
3501+
# Case 4: 1 categorical with different support lengths and batched by 3
3502+
for kwargs in mixed_hmc_kwargs:
3503+
samples = sample_mixedhmc(model_4, num_samples, **kwargs)
3504+
3505+
assert jnp.all(
3506+
(samples["x0"] >= 0) & (samples["x0"] <= 3)
3507+
), f"Failed with {kwargs=}"
3508+
assert jnp.all(
3509+
(samples["x1"] >= 10) & (samples["x1"] <= 20)
3510+
), f"Failed with {kwargs=}"
3511+
34613512

34623513
if __name__ == "__main__":
34633514
test_discrete_uniform_with_mixedhmc()

0 commit comments

Comments
 (0)