Skip to content

Commit ebcae0d

Browse files
Merge pull request jax-ml#26980 from carlosgmartin:categorical_replace
PiperOrigin-RevId: 737720590
2 parents be5d13a + 3f59fa6 commit ebcae0d

File tree

3 files changed

+80
-15
lines changed

3 files changed

+80
-15
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
2222
{func}`jax.lax.dynamic_update_slice` and related functions. The default is
2323
true, matching the current behavior. If set to false, JAX does not need to
2424
emit code clamping negative indices, which improves code size.
25+
* Added a `replace` option to {func}`jax.random.categorical` to enable sampling
26+
without replacement.
2527

2628
## jax 0.5.2 (Mar 4, 2025)
2729

jax/_src/random.py

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,12 +1548,18 @@ def _gumbel(key, shape, dtype, mode) -> Array:
15481548
_uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.)))
15491549

15501550

1551-
def categorical(key: ArrayLike,
1552-
logits: RealArray,
1553-
axis: int = -1,
1554-
shape: Shape | None = None) -> Array:
1551+
def categorical(
1552+
key: ArrayLike,
1553+
logits: RealArray,
1554+
axis: int = -1,
1555+
shape: Shape | None = None,
1556+
replace: bool = True,
1557+
) -> Array:
15551558
"""Sample random values from categorical distributions.
15561559
1560+
Sampling with replacement uses the Gumbel max trick. Sampling without replacement uses
1561+
the Gumbel top-k trick. See [1] for reference.
1562+
15571563
Args:
15581564
key: a PRNG key used as the random key.
15591565
logits: Unnormalized log probabilities of the categorical distribution(s) to sample from,
@@ -1562,32 +1568,57 @@ def categorical(key: ArrayLike,
15621568
shape: Optional, a tuple of nonnegative integers representing the result shape.
15631569
Must be broadcast-compatible with ``np.delete(logits.shape, axis)``.
15641570
The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``.
1571+
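replace: If True (default), perform sampling with replacement. If False, perform
1572+
sampling without replacement; the number of samples (the product of the leading dimensions of ``shape``) cannot exceed the number of categories.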
15651573
15661574
Returns:
15671575
A random array with int dtype and shape given by ``shape`` if ``shape``
15681576
is not None, or else ``np.delete(logits.shape, axis)``.
1577+
1578+
References:
1579+
.. [1] Wouter Kool, Herke van Hoof, Max Welling. "Stochastic Beams and Where to Find
1580+
Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement".
1581+
Proceedings of the 36th International Conference on Machine Learning, PMLR
1582+
97:3499-3508, 2019. https://proceedings.mlr.press/v97/kool19a.html.
15691583
"""
15701584
key, _ = _check_prng_key("categorical", key)
15711585
check_arraylike("categorical", logits)
15721586
logits_arr = jnp.asarray(logits)
1573-
1574-
if axis >= 0:
1575-
axis -= len(logits_arr.shape)
1576-
15771587
batch_shape = tuple(np.delete(logits_arr.shape, axis))
15781588
if shape is None:
15791589
shape = batch_shape
15801590
else:
15811591
shape = core.canonicalize_shape(shape)
15821592
_check_shape("categorical", shape, batch_shape)
1583-
15841593
shape_prefix = shape[:len(shape)-len(batch_shape)]
1585-
logits_shape = list(shape[len(shape) - len(batch_shape):])
1586-
logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
1587-
return jnp.argmax(
1588-
gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
1589-
lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
1590-
axis=axis)
1594+
1595+
if replace:
1596+
if axis >= 0:
1597+
axis -= len(logits_arr.shape)
1598+
1599+
logits_shape = list(shape[len(shape) - len(batch_shape):])
1600+
logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis])
1601+
return jnp.argmax(
1602+
gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) +
1603+
lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))),
1604+
axis=axis)
1605+
else:
1606+
logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype)
1607+
k = math.prod(shape_prefix)
1608+
if k > logits_arr.shape[axis]:
1609+
raise ValueError(
1610+
f"Number of samples without replacement ({k}) cannot exceed number of "
1611+
f"categories ({logits_arr.shape[axis]})."
1612+
)
1613+
1614+
_, indices = lax.top_k(jnp.moveaxis(logits_arr, axis, -1), k)
1615+
assert indices.shape == batch_shape + (k,)
1616+
assert shape == shape_prefix + batch_shape
1617+
1618+
dimensions = (indices.ndim - 1, *range(indices.ndim - 1))
1619+
indices = lax.reshape(indices, shape, dimensions)
1620+
assert indices.shape == shape
1621+
return indices
15911622

15921623

15931624
def laplace(key: ArrayLike,

tests/random_lax_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,38 @@ def testCategorical(self, p, axis, dtype, sample_shape):
365365
pmf = lambda x: np.where(x < len(p), p[np.minimum(len(p) - 1, x)], 0.0)
366366
self._CheckChiSquared(samples, pmf=pmf)
367367

368+
@jtu.sample_product(
369+
logits_shape=[(7,), (8, 9), (10, 11, 12)],
370+
prefix_shape=[(2,), (3, 4), (5, 6)],
371+
)
372+
def testCategoricalWithoutReplacement(self, logits_shape, prefix_shape):
373+
key = random.key(0)
374+
375+
key, subkey = random.split(key)
376+
logits = random.normal(subkey, logits_shape)
377+
378+
key, subkey = random.split(key)
379+
axis = random.randint(subkey, (), -len(logits_shape), len(logits_shape))
380+
381+
dists_shape = tuple(np.delete(logits_shape, axis))
382+
n_categories = logits_shape[axis]
383+
shape = prefix_shape + dists_shape
384+
prefix_size = math.prod(prefix_shape)
385+
386+
if n_categories < prefix_size:
387+
with self.assertRaisesRegex(ValueError, "Number of samples without replacement"):
388+
random.categorical(key, logits, axis=axis, shape=shape, replace=False)
389+
390+
else:
391+
output = random.categorical(key, logits, axis=axis, shape=shape, replace=False)
392+
self.assertEqual(output.shape, shape)
393+
assert (0 <= output).all()
394+
assert (output < n_categories).all()
395+
flat = output.reshape((prefix_size, math.prod(dists_shape)))
396+
counts = jax.vmap(partial(jnp.bincount, length=n_categories), 1)(flat)
397+
assert (counts <= 1).all()
398+
399+
368400
def testBernoulliShape(self):
369401
key = self.make_key(0)
370402
with jax.numpy_rank_promotion('allow'):

0 commit comments

Comments
 (0)