Skip to content

Commit 0791b1f

Browse files
authored
Refactor log_prob method in _MixtureBase class to handle negative infinity values in sum_log_probs (#1874)
1 parent 8e9313f commit 0791b1f

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

numpyro/distributions/mixtures.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,10 @@ def sample(self, key, sample_shape=()):
149149
def log_prob(self, value, intermediates=None):
150150
del intermediates
151151
sum_log_probs = self.component_log_probs(value)
152-
return jax.nn.logsumexp(sum_log_probs, axis=-1)
152+
safe_sum_log_probs = jnp.where(
153+
jnp.isneginf(sum_log_probs), -jnp.inf, sum_log_probs
154+
)
155+
return jax.nn.logsumexp(safe_sum_log_probs, axis=-1)
153156

154157

155158
class MixtureSameFamily(_MixtureBase):

0 commit comments

Comments
 (0)