Skip to content

Commit 0607a92

Browse files
authored
Avoid infinite loop in vmapped Binomial with p=0 (#1462)
* Avoid infinite loop in vmapped Binomial with p=0 * > to >= to match with the other conditionals.
1 parent fe01f02 commit 0607a92

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

numpyro/distributions/util.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ def stirling_approx_tail(k):
6262
)
6363

6464

65+
_binomial_mu_thresh = 10
66+
67+
6568
def _binomial_btrs(key, p, n):
6669
"""
6770
Based on the transformed rejection sampling algorithm (BTRS) from the
@@ -103,13 +106,19 @@ def accept_fn(k, u, v):
103106
k, key, u, v = val
104107
early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
105108
early_reject = (k < 0) | (k > n)
106-
return lax.cond(
109+
# when vmapped _binomial_dispatch will convert the cond condition into
110+
# a HLO select that will execute both branches. This is a workaround
111+
# that avoids the resulting infinite loop when p=0. This should also
112+
# improve performance in less catastrophic cases.
113+
cond_exclude_small_mu = p * n >= _binomial_mu_thresh
114+
cond_main = lax.cond(
107115
early_accept | early_reject,
108116
(),
109117
lambda _: ~early_accept,
110118
(k, u, v),
111119
lambda x: ~accept_fn(*x),
112120
)
121+
return cond_exclude_small_mu & cond_main
113122

114123
tr_params = _get_tr_params(n, p)
115124
ret = lax.while_loop(
@@ -129,7 +138,11 @@ def _binom_inv_body_fn(val):
129138

130139
def _binom_inv_cond_fn(val):
131140
i, _, geom_acc = val
132-
return geom_acc <= n
141+
# see the note on cond_exclude_small_mu in _binomial_btrs
142+
# this cond_exclude_large_mu is unnecessary for correctness but will
143+
# still improve performance.
144+
cond_exclude_large_mu = p * n < _binomial_mu_thresh
145+
return cond_exclude_large_mu & (geom_acc <= n)
133146

134147
log1_p = jnp.log1p(-p)
135148
ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0))
@@ -142,7 +155,7 @@ def dispatch(key, p, n):
142155
pq = jnp.where(is_le_mid, p, 1 - p)
143156
mu = n * pq
144157
k = lax.cond(
145-
mu < 10,
158+
mu < _binomial_mu_thresh,
146159
(key, pq, n),
147160
lambda x: _binomial_inversion(*x),
148161
(key, pq, n),

test/test_distributions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2498,3 +2498,13 @@ def test_kl_dirichlet_dirichlet(shape):
24982498
x = p.sample(random.PRNGKey(0), (10_000,)).copy()
24992499
expected = jnp.mean((p.log_prob(x) - q.log_prob(x)), 0)
25002500
assert_allclose(actual, expected, rtol=0.05)
2501+
2502+
2503+
def test_vmapped_binomial_p0():
2504+
# test that vmapped binomial with p = 0 does not have an infinite loop
2505+
def sample_binomial_withp0(key):
2506+
n = 2 * (random.uniform(key) > 0.5)
2507+
_, key = random.split(key)
2508+
return dist.Binomial(total_count=n, probs=0).sample(key)
2509+
2510+
jax.vmap(sample_binomial_withp0)(random.split(random.PRNGKey(0), 1))

0 commit comments

Comments
 (0)