11from functools import singledispatch
22
33import jax
4+ import jax .numpy as jnp
45import numpy as np
56from numpy .random import Generator
67from numpy .random .bit_generator import ( # type: ignore[attr-defined]
@@ -429,19 +430,11 @@ def sample_fn(rng, size, dtype, n, p):
429430
430431@jax_sample_fn .register (ptr .MultinomialRV )
431432def jax_sample_fn_multinomial (op , node ):
432- if not numpyro_available :
433- raise NotImplementedError (
434- f"No JAX implementation for the given distribution: { op .name } . "
435- "Implementation is available if NumPyro is installed."
436- )
437-
438- from numpyro .distributions .util import multinomial
439-
440433 def sample_fn (rng , size , dtype , n , p ):
441434 rng_key = rng ["jax_state" ]
442435 rng_key , sampling_key = jax .random .split (rng_key , 2 )
443436
444- sample = multinomial (key = sampling_key , n = n , p = p , shape = size )
437+ sample = _jax_multinomial (key = sampling_key , n = n , p = p , shape = size )
445438
446439 rng ["jax_state" ] = rng_key
447440
@@ -450,6 +443,40 @@ def sample_fn(rng, size, dtype, n, p):
450443 return sample_fn
451444
452445
446+ def _jax_multinomial (n , p , shape = None , key = None ):
447+ if jnp .shape (n ) != jnp .shape (p )[:- 1 ]:
448+ broadcast_shape = jax .lax .broadcast_shapes (jnp .shape (n ), jnp .shape (p )[:- 1 ])
449+ n = jnp .broadcast_to (n , broadcast_shape )
450+ p = jnp .broadcast_to (p , broadcast_shape + jnp .shape (p )[- 1 :])
451+ if shape is not None :
452+ broadcast_shape = jax .lax .broadcast_shapes (jnp .shape (n ), shape )
453+ n = jnp .broadcast_to (n , broadcast_shape )
454+ else :
455+ shape = shape or p .shape [:- 1 ]
456+
457+ p = p / jnp .sum (p , axis = - 1 , keepdims = True )
458+ binom_p = jnp .moveaxis (p , - 1 , 0 )[:- 1 , ...]
459+
460+ sampling_rng = jax .random .split (key , binom_p .shape [0 ])
461+
462+ def _binomial_sample_fn (carry , p_rng ):
463+ s , rho = carry
464+ p , rng = p_rng
465+ samples = jax .random .binomial (rng , s , p / rho , shape )
466+ s = s - samples
467+ rho = rho - p
468+ return ((s , rho ), samples )
469+
470+ (remain , _ ), samples = jax .lax .scan (
471+ _binomial_sample_fn ,
472+ (n .astype ("float" ), jnp .ones (binom_p .shape [1 :])),
473+ (binom_p , sampling_rng ),
474+ )
475+ return jnp .concatenate (
476+ [jnp .moveaxis (samples , 0 , - 1 ), jnp .expand_dims (remain , - 1 )], axis = - 1
477+ )
478+
479+
453480@jax_sample_fn .register (ptr .VonMisesRV )
454481def jax_sample_fn_vonmises (op , node ):
455482 if not numpyro_available :
0 commit comments