@@ -1548,12 +1548,18 @@ def _gumbel(key, shape, dtype, mode) -> Array:
15481548 _uniform (key , shape , dtype , minval = jnp .finfo (dtype ).tiny , maxval = 1. )))
15491549
15501550
1551- def categorical (key : ArrayLike ,
1552- logits : RealArray ,
1553- axis : int = - 1 ,
1554- shape : Shape | None = None ) -> Array :
1551+ def categorical (
1552+ key : ArrayLike ,
1553+ logits : RealArray ,
1554+ axis : int = - 1 ,
1555+ shape : Shape | None = None ,
1556+ replace : bool = True ,
1557+ ) -> Array :
15551558 """Sample random values from categorical distributions.
15561559
1560+ Sampling with replacement uses the Gumbel max trick. Sampling without replacement uses
1561+ the Gumbel top-k trick. See [1] for reference.
1562+
15571563 Args:
15581564 key: a PRNG key used as the random key.
15591565 logits: Unnormalized log probabilities of the categorical distribution(s) to sample from,
@@ -1562,32 +1568,57 @@ def categorical(key: ArrayLike,
15621568 shape: Optional, a tuple of nonnegative integers representing the result shape.
15631569 Must be broadcast-compatible with ``np.delete(logits.shape, axis)``.
15641570 The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``.
1571+ replace: If True, perform sampling without replacement. Default (False) is to
1572+ perform sampling with replacement.
15651573
15661574 Returns:
15671575 A random array with int dtype and shape given by ``shape`` if ``shape``
15681576 is not None, or else ``np.delete(logits.shape, axis)``.
1577+
1578+ References:
1579+ .. [1] Wouter Kool, Herke van Hoof, Max Welling. "Stochastic Beams and Where to Find
1580+ Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement".
1581+ Proceedings of the 36th International Conference on Machine Learning, PMLR
1582+ 97:3499-3508, 2019. https://proceedings.mlr.press/v97/kool19a.html.
15691583 """
15701584 key , _ = _check_prng_key ("categorical" , key )
15711585 check_arraylike ("categorical" , logits )
15721586 logits_arr = jnp .asarray (logits )
1573-
1574- if axis >= 0 :
1575- axis -= len (logits_arr .shape )
1576-
15771587 batch_shape = tuple (np .delete (logits_arr .shape , axis ))
15781588 if shape is None :
15791589 shape = batch_shape
15801590 else :
15811591 shape = core .canonicalize_shape (shape )
15821592 _check_shape ("categorical" , shape , batch_shape )
1583-
15841593 shape_prefix = shape [:len (shape )- len (batch_shape )]
1585- logits_shape = list (shape [len (shape ) - len (batch_shape ):])
1586- logits_shape .insert (axis % len (logits_arr .shape ), logits_arr .shape [axis ])
1587- return jnp .argmax (
1588- gumbel (key , (* shape_prefix , * logits_shape ), logits_arr .dtype ) +
1589- lax .expand_dims (logits_arr , tuple (range (len (shape_prefix )))),
1590- axis = axis )
1594+
1595+ if replace :
1596+ if axis >= 0 :
1597+ axis -= len (logits_arr .shape )
1598+
1599+ logits_shape = list (shape [len (shape ) - len (batch_shape ):])
1600+ logits_shape .insert (axis % len (logits_arr .shape ), logits_arr .shape [axis ])
1601+ return jnp .argmax (
1602+ gumbel (key , (* shape_prefix , * logits_shape ), logits_arr .dtype ) +
1603+ lax .expand_dims (logits_arr , tuple (range (len (shape_prefix )))),
1604+ axis = axis )
1605+ else :
1606+ logits_arr += gumbel (key , logits_arr .shape , logits_arr .dtype )
1607+ k = math .prod (shape_prefix )
1608+ if k > logits_arr .shape [axis ]:
1609+ raise ValueError (
1610+ f"Number of samples without replacement ({ k } ) cannot exceed number of "
1611+ f"categories ({ logits_arr .shape [axis ]} )."
1612+ )
1613+
1614+ _ , indices = lax .top_k (jnp .moveaxis (logits_arr , axis , - 1 ), k )
1615+ assert indices .shape == batch_shape + (k ,)
1616+ assert shape == shape_prefix + batch_shape
1617+
1618+ dimensions = (indices .ndim - 1 , * range (indices .ndim - 1 ))
1619+ indices = lax .reshape (indices , shape , dimensions )
1620+ assert indices .shape == shape
1621+ return indices
15911622
15921623
15931624def laplace (key : ArrayLike ,
0 commit comments