Skip to content

Commit af9803c

Browse files
author
Etienne Duchesne
committed
Implement multinomial JAX dispatch directly in jax
Replace the call to numpyro.distributions.util.multinomial by a custom function
1 parent 95ce102 commit af9803c

File tree

1 file changed

+36
-9
lines changed

1 file changed

+36
-9
lines changed

pytensor/link/jax/dispatch/random.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from functools import singledispatch
22

33
import jax
4+
import jax.numpy as jnp
45
import numpy as np
56
from numpy.random import Generator
67
from numpy.random.bit_generator import ( # type: ignore[attr-defined]
@@ -429,19 +430,11 @@ def sample_fn(rng, size, dtype, n, p):
429430

430431
@jax_sample_fn.register(ptr.MultinomialRV)
431432
def jax_sample_fn_multinomial(op, node):
432-
if not numpyro_available:
433-
raise NotImplementedError(
434-
f"No JAX implementation for the given distribution: {op.name}. "
435-
"Implementation is available if NumPyro is installed."
436-
)
437-
438-
from numpyro.distributions.util import multinomial
439-
440433
def sample_fn(rng, size, dtype, n, p):
441434
rng_key = rng["jax_state"]
442435
rng_key, sampling_key = jax.random.split(rng_key, 2)
443436

444-
sample = multinomial(key=sampling_key, n=n, p=p, shape=size)
437+
sample = _jax_multinomial(key=sampling_key, n=n, p=p, shape=size)
445438

446439
rng["jax_state"] = rng_key
447440

@@ -450,6 +443,40 @@ def sample_fn(rng, size, dtype, n, p):
450443
return sample_fn
451444

452445

446+
def _jax_multinomial(n, p, shape=None, key=None):
447+
if jnp.shape(n) != jnp.shape(p)[:-1]:
448+
broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
449+
n = jnp.broadcast_to(n, broadcast_shape)
450+
p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
451+
if shape is not None:
452+
broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), shape)
453+
n = jnp.broadcast_to(n, broadcast_shape)
454+
else:
455+
shape = shape or p.shape[:-1]
456+
457+
p = p / jnp.sum(p, axis=-1, keepdims=True)
458+
binom_p = jnp.moveaxis(p, -1, 0)[:-1, ...]
459+
460+
sampling_rng = jax.random.split(key, binom_p.shape[0])
461+
462+
def _binomial_sample_fn(carry, p_rng):
463+
s, rho = carry
464+
p, rng = p_rng
465+
samples = jax.random.binomial(rng, s, p / rho, shape)
466+
s = s - samples
467+
rho = rho - p
468+
return ((s, rho), samples)
469+
470+
(remain, _), samples = jax.lax.scan(
471+
_binomial_sample_fn,
472+
(n.astype("float"), jnp.ones(binom_p.shape[1:])),
473+
(binom_p, sampling_rng),
474+
)
475+
return jnp.concatenate(
476+
[jnp.moveaxis(samples, 0, -1), jnp.expand_dims(remain, -1)], axis=-1
477+
)
478+
479+
453480
@jax_sample_fn.register(ptr.VonMisesRV)
454481
def jax_sample_fn_vonmises(op, node):
455482
if not numpyro_available:

0 commit comments

Comments
 (0)