@@ -3430,26 +3430,30 @@ def test_discrete_uniform_with_mixedhmc():
34303430 import numpyro .distributions as dist
34313431 from numpyro .infer import HMC , MCMC , MixedHMC
34323432
3433- def model_1 ():
3434- numpyro .sample ("x0" , dist .DiscreteUniform (10 , 12 ))
3435- numpyro .sample ("x1" , dist .Categorical (np .asarray ([0.25 , 0.25 , 0.25 , 0.25 ])))
3436-
3433+ def sample_mixedhmc (model_fn , num_samples , ** kwargs ):
3434+ kernel = HMC (model_fn , trajectory_length = 1.2 )
3435+ kernel = MixedHMC (kernel , num_discrete_updates = 20 , ** kwargs )
3436+ mcmc = MCMC (kernel , num_warmup = 100 , num_samples = num_samples , progress_bar = False )
3437+ key = jax .random .PRNGKey (0 )
3438+ mcmc .run (key )
3439+ samples = mcmc .get_samples ()
3440+ return samples
3441+
3442+ num_samples = 1000
34373443 mixed_hmc_kwargs = [
34383444 {"random_walk" : False , "modified" : False },
34393445 {"random_walk" : True , "modified" : False },
34403446 {"random_walk" : True , "modified" : True },
34413447 {"random_walk" : False , "modified" : True },
34423448 ]
34433449
3444- num_samples = 1000
3445-
3450+ # Case 1: one discrete uniform with one categorical
3451+ def model_1 ():
3452+ numpyro .sample ("x0" , dist .DiscreteUniform (10 , 12 ))
3453+ numpyro .sample ("x1" , dist .Categorical (np .asarray ([0.25 , 0.25 , 0.25 , 0.25 ])))
3454+
34463455 for kwargs in mixed_hmc_kwargs :
3447- kernel = HMC (model_1 , trajectory_length = 1.2 )
3448- kernel = MixedHMC (kernel , num_discrete_updates = 20 , ** kwargs )
3449- mcmc = MCMC (kernel , num_warmup = 100 , num_samples = num_samples , progress_bar = False )
3450- key = jax .random .PRNGKey (0 )
3451- mcmc .run (key )
3452- samples = mcmc .get_samples ()
3456+ samples = sample_mixedhmc (model_1 , num_samples , ** kwargs )
34533457
34543458 assert jnp .all (
34553459 (samples ["x0" ] >= 10 ) & (samples ["x0" ] <= 12 )
@@ -3458,6 +3462,53 @@ def model_1():
34583462 (samples ["x1" ] >= 0 ) & (samples ["x1" ] <= 3 )
34593463 ), f"Failed with { kwargs = } "
34603464
3465+ def model_2 ():
3466+ numpyro .sample ("x0" , dist .Categorical (0.25 * jnp .ones ((4 ,))))
3467+ numpyro .sample ("x1" , dist .Categorical (0.1 * jnp .ones ((10 ,))))
3468+
3469+ # Case 2: 2 categorical with different support lengths
3470+ for kwargs in mixed_hmc_kwargs :
3471+ samples = sample_mixedhmc (model_2 , num_samples , ** kwargs )
3472+
3473+ assert jnp .all (
3474+ (samples ["x0" ] >= 0 ) & (samples ["x0" ] <= 3 )
3475+ ), f"Failed with { kwargs = } "
3476+ assert jnp .all (
3477+ (samples ["x1" ] >= 0 ) & (samples ["x1" ] <= 9 )
3478+ ), f"Failed with { kwargs = } "
3479+
3480+ def model_3 ():
3481+ numpyro .sample ("x0" , dist .Categorical (0.25 * jnp .ones ((3 , 4 ))))
3482+ numpyro .sample ("x1" , dist .Categorical (0.1 * jnp .ones ((3 , 10 ))))
3483+
3484+ # Case 3: 2 categorical with different support lengths and batched by 3
3485+ for kwargs in mixed_hmc_kwargs :
3486+ samples = sample_mixedhmc (model_3 , num_samples , ** kwargs )
3487+
3488+ assert jnp .all (
3489+ (samples ["x0" ] >= 0 ) & (samples ["x0" ] <= 3 )
3490+ ), f"Failed with { kwargs = } "
3491+ assert jnp .all (
3492+ (samples ["x1" ] >= 0 ) & (samples ["x1" ] <= 9 )
3493+ ), f"Failed with { kwargs = } "
3494+
3495+ def model_4 ():
3496+ dist0 = dist .Categorical (0.25 * jnp .ones ((3 , 4 )))
3497+ numpyro .sample ("x0" , dist0 )
3498+ dist1 = dist .DiscreteUniform (10 * jnp .ones ((3 ,)), 19 * jnp .ones ((3 ,)))
3499+ numpyro .sample ("x1" , dist1 )
3500+
3501+ # Case 4: 1 categorical with different support lengths and batched by 3
3502+ for kwargs in mixed_hmc_kwargs :
3503+ samples = sample_mixedhmc (model_4 , num_samples , ** kwargs )
3504+
3505+ assert jnp .all (
3506+ (samples ["x0" ] >= 0 ) & (samples ["x0" ] <= 3 )
3507+ ), f"Failed with { kwargs = } "
3508+ assert jnp .all (
3509+ (samples ["x1" ] >= 10 ) & (samples ["x1" ] <= 20 )
3510+ ), f"Failed with { kwargs = } "
3511+
34613512
34623513if __name__ == "__main__" :
34633514 test_discrete_uniform_with_mixedhmc ()
0 commit comments