-
Notifications
You must be signed in to change notification settings - Fork 145
Implement faster Multinomial JAX dispatch #1316
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
af9803c
b551daf
f373d99
9e6ac0b
4172752
aef0d4b
81f0b12
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
from functools import singledispatch | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
from numpy.random import Generator | ||
from numpy.random.bit_generator import ( # type: ignore[attr-defined] | ||
|
@@ -394,21 +395,47 @@ | |
|
||
@jax_sample_fn.register(ptr.MultinomialRV) | ||
def jax_sample_fn_multinomial(op, node): | ||
if not numpyro_available: | ||
raise NotImplementedError( | ||
f"No JAX implementation for the given distribution: {op.name}. " | ||
"Implementation is available if NumPyro is installed." | ||
) | ||
|
||
from numpyro.distributions.util import multinomial | ||
|
||
def sample_fn(rng_key, size, dtype, n, p): | ||
sample = multinomial(key=rng_key, n=n, p=p, shape=size) | ||
sample = _jax_multinomial(key=rng_key, n=n, p=p, size=size) | ||
return sample | ||
|
||
return sample_fn | ||
|
||
|
||
def _jax_multinomial(n, p, size=None, key=None): | ||
if size is not None: | ||
broadcast_shape_n = jax.lax.broadcast_shapes(jnp.shape(n), size) | ||
|
||
n = jnp.broadcast_to(n, broadcast_shape_n) | ||
|
||
broadcast_shape_p = jax.lax.broadcast_shapes(jnp.shape(p)[:-1], size) | ||
p = jnp.broadcast_to(p, broadcast_shape_p + jnp.shape(p)[-1:]) | ||
|
||
else: | ||
broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1]) | ||
n = jnp.broadcast_to(n, broadcast_shape) | ||
p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:]) | ||
|
||
binom_p = jnp.moveaxis(p, -1, 0)[:-1, ...] | ||
sampling_rng = jax.random.split(key, binom_p.shape[0]) | ||
ricardoV94 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
def _binomial_sample_fn(carry, p_rng): | ||
s, rho = carry | ||
p, rng = p_rng | ||
samples = jax.random.binomial(rng, s, p / rho) | ||
s = s - samples | ||
rho = rho - p | ||
return ((s, rho), samples) | ||
|
||
(remain, _), samples = jax.lax.scan( | ||
_binomial_sample_fn, | ||
(n.astype(np.float64), jnp.ones(binom_p.shape[1:])), | ||
(binom_p, sampling_rng), | ||
) | ||
return jnp.concatenate( | ||
[jnp.moveaxis(samples, 0, -1), jnp.expand_dims(remain, -1)], axis=-1 | ||
) | ||
ricardoV94 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
|
||
@jax_sample_fn.register(ptr.VonMisesRV) | ||
def jax_sample_fn_vonmises(op, node): | ||
if not numpyro_available: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can inline this in the dispatch for MultinomialRV. Also key=None is not really valid but we don't need default parameters anyway