We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
GeometricLogits
1 parent b19a83d commit e0d450bCopy full SHA for e0d450b
numpyro/distributions/discrete.py
@@ -936,8 +936,11 @@ def variance(self):
936
return (1.0 / self.probs - 1.0) / self.probs
937
938
def entropy(self):
939
- nexp = jnp.exp(-self.logits)
940
- return nexp * self.logits + jnp.log1p(nexp) * (1 + nexp)
+ logq = -jax.nn.softplus(self.logits)
+ logp = -jax.nn.softplus(-self.logits)
941
+ p = jax.scipy.special.expit(self.logits)
942
+ p_clip = jnp.clip(p, min=jnp.finfo(p).tiny)
943
+ return -(1 - p) * logq / p_clip - logp
944
945
946
def Geometric(probs=None, logits=None, *, validate_args=None):
0 commit comments