@@ -3423,3 +3423,41 @@ def test_gaussian_random_walk_linear_recursive_equivalence():
34233423 x2 = dist2 .sample (random .PRNGKey (7 ))
34243424 assert jnp .allclose (x1 , x2 .squeeze ())
34253425 assert jnp .allclose (dist1 .log_prob (x1 ), dist2 .log_prob (x2 ))
3426+
3427+
3428+ def test_discrete_uniform_with_mixedhmc ():
3429+ import numpyro
3430+ import numpyro .distributions as dist
3431+ from numpyro .infer import HMC , MCMC , MixedHMC
3432+
3433+ def model_1 ():
3434+ numpyro .sample ("x0" , dist .DiscreteUniform (10 , 12 ))
3435+ numpyro .sample ("x1" , dist .Categorical (np .asarray ([0.25 , 0.25 , 0.25 , 0.25 ])))
3436+
3437+ mixed_hmc_kwargs = [
3438+ {"random_walk" : False , "modified" : False },
3439+ {"random_walk" : True , "modified" : False },
3440+ {"random_walk" : True , "modified" : True },
3441+ {"random_walk" : False , "modified" : True },
3442+ ]
3443+
3444+ num_samples = 1000
3445+
3446+ for kwargs in mixed_hmc_kwargs :
3447+ kernel = HMC (model_1 , trajectory_length = 1.2 )
3448+ kernel = MixedHMC (kernel , num_discrete_updates = 20 , ** kwargs )
3449+ mcmc = MCMC (kernel , num_warmup = 100 , num_samples = num_samples , progress_bar = False )
3450+ key = jax .random .PRNGKey (0 )
3451+ mcmc .run (key )
3452+ samples = mcmc .get_samples ()
3453+
3454+ assert jnp .all (
3455+ (samples ["x0" ] >= 10 ) & (samples ["x0" ] <= 12 )
3456+ ), f"Failed with { kwargs = } "
3457+ assert jnp .all (
3458+ (samples ["x1" ] >= 0 ) & (samples ["x1" ] <= 3 )
3459+ ), f"Failed with { kwargs = } "
3460+
3461+
3462+ if __name__ == "__main__" :
3463+ test_discrete_uniform_with_mixedhmc ()
0 commit comments