Skip to content

Commit f5aca91

Browse files
authored
Fix binominal distribution (#1860)
* Comments added why the binominal_dispatch function will run into an infinite loop * Fix for the infinite loop problem of binominal_dispatch. * Update of the fix to a more concise version. * Linting fix * Changed to the proposed solution, to correct log1_p value correctly
1 parent bb7767e commit f5aca91

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

numpyro/distributions/util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ def _binom_inv_cond_fn(val):
149149
return cond_exclude_large_mu & (geom_acc <= n)
150150

151151
log1_p = jnp.log1p(-p)
152+
# Make sure p=0 is never taken into account as a fix for possible zeros in p.
153+
log1_p = jnp.where(log1_p == 0, -jnp.finfo(log1_p.dtype).tiny, log1_p)
152154
ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0))
153155
return ret[0]
154156

0 commit comments

Comments
 (0)