@@ -62,6 +62,9 @@ def stirling_approx_tail(k):
6262 )
6363
6464
65+ _binomial_mu_thresh = 10
66+
67+
6568def _binomial_btrs (key , p , n ):
6669 """
6770 Based on the transformed rejection sampling algorithm (BTRS) from the
@@ -103,13 +106,19 @@ def accept_fn(k, u, v):
103106 k , key , u , v = val
104107 early_accept = (jnp .abs (u ) <= tr_params .u_r ) & (v <= tr_params .v_r )
105108 early_reject = (k < 0 ) | (k > n )
106- return lax .cond (
109+ # when vmapped _binomial_dispatch will convert the cond condition into
110+ # a HLO select that will execute both branches. This is a workaround
111+ # that avoids the resulting infinite loop when p=0. This should also
112+ # improve performance in less catastrophic cases.
113+ cond_exclude_small_mu = p * n >= _binomial_mu_thresh
114+ cond_main = lax .cond (
107115 early_accept | early_reject ,
108116 (),
109117 lambda _ : ~ early_accept ,
110118 (k , u , v ),
111119 lambda x : ~ accept_fn (* x ),
112120 )
121+ return cond_exclude_small_mu & cond_main
113122
114123 tr_params = _get_tr_params (n , p )
115124 ret = lax .while_loop (
@@ -129,7 +138,11 @@ def _binom_inv_body_fn(val):
129138
130139 def _binom_inv_cond_fn (val ):
131140 i , _ , geom_acc = val
132- return geom_acc <= n
141+ # see the note on cond_exclude_small_mu in _binomial_btrs
142+ # this cond_exclude_large_mu is unnecessary for correctness but will
143+ # still improve performance.
144+ cond_exclude_large_mu = p * n < _binomial_mu_thresh
145+ return cond_exclude_large_mu & (geom_acc <= n )
133146
134147 log1_p = jnp .log1p (- p )
135148 ret = lax .while_loop (_binom_inv_cond_fn , _binom_inv_body_fn , (- 1 , key , 0.0 ))
@@ -142,7 +155,7 @@ def dispatch(key, p, n):
142155 pq = jnp .where (is_le_mid , p , 1 - p )
143156 mu = n * pq
144157 k = lax .cond (
145- mu < 10 ,
158+ mu < _binomial_mu_thresh ,
146159 (key , pq , n ),
147160 lambda x : _binomial_inversion (* x ),
148161 (key , pq , n ),
0 commit comments