Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 38 additions & 9 deletions pytensor/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import singledispatch

import jax
import jax.numpy as jnp
import numpy as np
from numpy.random import Generator
from numpy.random.bit_generator import ( # type: ignore[attr-defined]
Expand Down Expand Up @@ -429,19 +430,11 @@ def sample_fn(rng, size, dtype, n, p):

@jax_sample_fn.register(ptr.MultinomialRV)
def jax_sample_fn_multinomial(op, node):
if not numpyro_available:
raise NotImplementedError(
f"No JAX implementation for the given distribution: {op.name}. "
"Implementation is available if NumPyro is installed."
)

from numpyro.distributions.util import multinomial

def sample_fn(rng, size, dtype, n, p):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)

sample = multinomial(key=sampling_key, n=n, p=p, shape=size)
sample = _jax_multinomial(key=sampling_key, n=n, p=p, shape=size)

rng["jax_state"] = rng_key

Expand All @@ -450,6 +443,42 @@ def sample_fn(rng, size, dtype, n, p):
return sample_fn


def _jax_multinomial(n, p, shape=None, key=None):
if jnp.shape(n) != jnp.shape(p)[:-1]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Theres' some redundancy here. If size (not shape btw) is provided we should just broadcast n to size and p to size + p.shape[-1]. Only if size is not provided should we broadcast n and p[...: -1] together (basically finding the implicit size)

Copy link
Contributor Author

@educhesne educhesne Mar 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, besides there is actually no need to broadcast p at all; jax.random.binomial broadcasts it for us.
About the name shape: I kept the signature of the numpyro function but I can change it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you might be right about p although I guess it's more readable if you broadcast explicitly

broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
n = jnp.broadcast_to(n, broadcast_shape)
p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])

if shape is not None:
broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), shape)
n = jnp.broadcast_to(n, broadcast_shape)

else:
shape = p.shape[:-1]

p = p / jnp.sum(p, axis=-1, keepdims=True)
binom_p = jnp.moveaxis(p, -1, 0)[:-1, ...]

sampling_rng = jax.random.split(key, binom_p.shape[0])

def _binomial_sample_fn(carry, p_rng):
s, rho = carry
p, rng = p_rng
samples = jax.random.binomial(rng, s, p / rho, shape)
s = s - samples
rho = rho - p
return ((s, rho), samples)

(remain, _), samples = jax.lax.scan(
_binomial_sample_fn,
(n.astype(np.float64), jnp.ones(binom_p.shape[1:])),
(binom_p, sampling_rng),
)
return jnp.concatenate(
[jnp.moveaxis(samples, 0, -1), jnp.expand_dims(remain, -1)], axis=-1
)


@jax_sample_fn.register(ptr.VonMisesRV)
def jax_sample_fn_vonmises(op, node):
if not numpyro_available:
Expand Down
20 changes: 16 additions & 4 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,21 +703,33 @@ def test_beta_binomial():
)


@pytest.mark.skipif(
not numpyro_available, reason="Multinomial dispatch requires numpyro"
)
def test_multinomial():
rng = shared(np.random.default_rng(123))

# test with 'size' argument and n.shape == p.shape[:-1]
n = np.array([10, 40])
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng)
size = (10_000, 2)

g = pt.random.multinomial(n, p, size=size, rng=rng)
g_fn = compile_random_function([], g, mode="JAX")
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
np.testing.assert_allclose(
samples.std(axis=0), np.sqrt(n[..., None] * p * (1 - p)), rtol=0.1
)

# test with no 'size' argument and n.shape != p.shape[:-1]
n = np.broadcast_to(np.array([10, 40]), size)

g = pt.random.multinomial(n, p, rng=rng)
g_fn = compile_random_function([], g, mode="JAX")
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n[0, :, None] * p, rtol=0.1)
np.testing.assert_allclose(
samples.std(axis=0), np.sqrt(n[0, :, None] * p * (1 - p)), rtol=0.1
)


@pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro")
def test_vonmises_mu_outside_circle():
Expand Down