Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions pytensor/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import singledispatch

import jax
import jax.numpy as jnp
import numpy as np
from numpy.random import Generator
from numpy.random.bit_generator import ( # type: ignore[attr-defined]
Expand Down Expand Up @@ -394,21 +395,47 @@

@jax_sample_fn.register(ptr.MultinomialRV)
def jax_sample_fn_multinomial(op, node):
if not numpyro_available:
raise NotImplementedError(
f"No JAX implementation for the given distribution: {op.name}. "
"Implementation is available if NumPyro is installed."
)

from numpyro.distributions.util import multinomial

def sample_fn(rng_key, size, dtype, n, p):
sample = multinomial(key=rng_key, n=n, p=p, shape=size)
sample = _jax_multinomial(key=rng_key, n=n, p=p, size=size)
return sample

return sample_fn


def _jax_multinomial(n, p, size=None, key=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can inline this in the dispatch for MultinomialRV. Also key=None is not really valid but we don't need default parameters anyway

if size is not None:
broadcast_shape_n = jax.lax.broadcast_shapes(jnp.shape(n), size)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't need to broadcast size with n, that's not allowed by the RandomVariable semantics. If size is provided n must be broadcastable to size, but not the other way. You can's say n=ones((5, 1)), size=(1, 3)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, my bad

n = jnp.broadcast_to(n, broadcast_shape_n)

broadcast_shape_p = jax.lax.broadcast_shapes(jnp.shape(p)[:-1], size)
p = jnp.broadcast_to(p, broadcast_shape_p + jnp.shape(p)[-1:])

else:
broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
n = jnp.broadcast_to(n, broadcast_shape)
p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])

Check warning on line 416 in pytensor/link/jax/dispatch/random.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/jax/dispatch/random.py#L414-L416

Added lines #L414 - L416 were not covered by tests

binom_p = jnp.moveaxis(p, -1, 0)[:-1, ...]
sampling_rng = jax.random.split(key, binom_p.shape[0])

def _binomial_sample_fn(carry, p_rng):
s, rho = carry
p, rng = p_rng
samples = jax.random.binomial(rng, s, p / rho)
s = s - samples
rho = rho - p
return ((s, rho), samples)

(remain, _), samples = jax.lax.scan(
_binomial_sample_fn,
(n.astype(np.float64), jnp.ones(binom_p.shape[1:])),
(binom_p, sampling_rng),
)
return jnp.concatenate(
[jnp.moveaxis(samples, 0, -1), jnp.expand_dims(remain, -1)], axis=-1
)


@jax_sample_fn.register(ptr.VonMisesRV)
def jax_sample_fn_vonmises(op, node):
if not numpyro_available:
Expand Down
20 changes: 16 additions & 4 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,21 +703,33 @@ def test_beta_binomial():
)


@pytest.mark.skipif(
not numpyro_available, reason="Multinomial dispatch requires numpyro"
)
def test_multinomial():
rng = shared(np.random.default_rng(123))

# test with 'size' argument and n.shape == p.shape[:-1]
n = np.array([10, 40])
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng)
size = (10_000, 2)

g = pt.random.multinomial(n, p, size=size, rng=rng)
g_fn = compile_random_function([], g, mode="JAX")
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
np.testing.assert_allclose(
samples.std(axis=0), np.sqrt(n[..., None] * p * (1 - p)), rtol=0.1
)

# test with no 'size' argument and n.shape != p.shape[:-1]
n = np.broadcast_to(np.array([10, 40]), size)

g = pt.random.multinomial(n, p, rng=rng)
g_fn = compile_random_function([], g, mode="JAX")
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n[0, :, None] * p, rtol=0.1)
np.testing.assert_allclose(
samples.std(axis=0), np.sqrt(n[0, :, None] * p * (1 - p)), rtol=0.1
)


@pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro")
def test_vonmises_mu_outside_circle():
Expand Down