Skip to content

Commit e0d450b

Browse files
Resolve numerical instability in entropy of GeometricLogits. (#1852)
1 parent b19a83d commit e0d450b

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

numpyro/distributions/discrete.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -936,8 +936,11 @@ def variance(self):
936936
return (1.0 / self.probs - 1.0) / self.probs
937937

938938
def entropy(self):
939-
nexp = jnp.exp(-self.logits)
940-
return nexp * self.logits + jnp.log1p(nexp) * (1 + nexp)
939+
logq = -jax.nn.softplus(self.logits)
940+
logp = -jax.nn.softplus(-self.logits)
941+
p = jax.scipy.special.expit(self.logits)
942+
p_clip = jnp.clip(p, min=jnp.finfo(p).tiny)
943+
return -(1 - p) * logq / p_clip - logp
941944

942945

943946
def Geometric(probs=None, logits=None, *, validate_args=None):

0 commit comments

Comments
 (0)