Skip to content

Commit 17a5e42

Browse files
committed
Fix JAX implementation of Categorical
1 parent e9b56ae commit 17a5e42

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

pytensor/link/jax/dispatch/random.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,20 @@ def sample_fn(rng, size, dtype, *parameters):
169169

170170

171171
@jax_sample_fn.register(ptr.BernoulliRV)
172+
def jax_sample_fn_bernoulli(op):
173+
"""JAX implementation of `BernoulliRV`."""
174+
175+
# We need a separate dispatch, because there is no dtype argument for Bernoulli in JAX
176+
def sample_fn(rng, size, dtype, p):
177+
rng_key = rng["jax_state"]
178+
rng_key, sampling_key = jax.random.split(rng_key, 2)
179+
sample = jax.random.bernoulli(sampling_key, p, shape=size)
180+
rng["jax_state"] = rng_key
181+
return (rng, sample)
182+
183+
return sample_fn
184+
185+
172186
@jax_sample_fn.register(ptr.CategoricalRV)
173187
def jax_sample_fn_no_dtype(op):
174188
"""Generic JAX implementation of random variables."""

tests/link/jax/test_random.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,8 +595,16 @@ def test_random_categorical():
595595
g = pt.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
596596
g_fn = compile_random_function([], g, mode=jax_mode)
597597
samples = g_fn()
598+
assert samples.shape == (10000, 4)
598599
np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1)
599600

601+
# Test zero probabilities
602+
g = pt.random.categorical([0, 0.5, 0, 0.5], size=(1000,), rng=rng)
603+
g_fn = compile_random_function([], g, mode=jax_mode)
604+
samples = g_fn()
605+
assert samples.shape == (1000,)
606+
assert np.all(samples % 2 == 1)
607+
600608

601609
def test_random_permutation():
602610
array = np.arange(4)

0 commit comments

Comments
 (0)