Skip to content

Commit fe46ba1

Browse files
committed
adding test for mixed hmc sampling of distribution discrete uniform
1 parent f4b9d99 commit fe46ba1

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

test/test_distributions.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3423,3 +3423,41 @@ def test_gaussian_random_walk_linear_recursive_equivalence():
34233423
x2 = dist2.sample(random.PRNGKey(7))
34243424
assert jnp.allclose(x1, x2.squeeze())
34253425
assert jnp.allclose(dist1.log_prob(x1), dist2.log_prob(x2))
3426+
3427+
3428+
def test_discrete_uniform_with_mixedhmc():
3429+
import numpyro
3430+
import numpyro.distributions as dist
3431+
from numpyro.infer import HMC, MCMC, MixedHMC
3432+
3433+
def model_1():
3434+
numpyro.sample("x0", dist.DiscreteUniform(10, 12))
3435+
numpyro.sample("x1", dist.Categorical(np.asarray([0.25, 0.25, 0.25, 0.25])))
3436+
3437+
mixed_hmc_kwargs = [
3438+
{"random_walk": False, "modified": False},
3439+
{"random_walk": True, "modified": False},
3440+
{"random_walk": True, "modified": True},
3441+
{"random_walk": False, "modified": True},
3442+
]
3443+
3444+
num_samples = 1000
3445+
3446+
for kwargs in mixed_hmc_kwargs:
3447+
kernel = HMC(model_1, trajectory_length=1.2)
3448+
kernel = MixedHMC(kernel, num_discrete_updates=20, **kwargs)
3449+
mcmc = MCMC(kernel, num_warmup=100, num_samples=num_samples, progress_bar=False)
3450+
key = jax.random.PRNGKey(0)
3451+
mcmc.run(key)
3452+
samples = mcmc.get_samples()
3453+
3454+
assert jnp.all(
3455+
(samples["x0"] >= 10) & (samples["x0"] <= 12)
3456+
), f"Failed with {kwargs=}"
3457+
assert jnp.all(
3458+
(samples["x1"] >= 0) & (samples["x1"] <= 3)
3459+
), f"Failed with {kwargs=}"
3460+
3461+
3462+
if __name__ == "__main__":
3463+
test_discrete_uniform_with_mixedhmc()

0 commit comments

Comments
 (0)