Skip to content

Commit 96c0129

Browse files
dfmGoogle-ML-Automation
authored andcommitted
Fix false positive debug_nans error caused by NaNs that are properly handled in jax.scipy.stats.gamma
As reported in jax-ml#24939, even though the implementation of `jax.scipy.stats.gamma.logpdf` handles invalid inputs (e.g. `x < loc`) by returning `-inf`, the existing implementation incorrectly triggers the NaN checks introduced by JAX's debug NaNs mode. This change updates the implementation to no longer produce internal NaNs. Fixes jax-ml#24939 PiperOrigin-RevId: 698833589
1 parent e707ede commit 96c0129

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

jax/_src/scipy/stats/gamma.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,13 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
5151
- :func:`jax.scipy.stats.gamma.logsf`
5252
"""
5353
x, a, loc, scale = promote_args_inexact("gamma.logpdf", x, a, loc, scale)
54+
ok = lax.ge(x, loc)
5455
one = _lax_const(x, 1)
55-
y = lax.div(lax.sub(x, loc), scale)
56+
y = jnp.where(ok, lax.div(lax.sub(x, loc), scale), one)
5657
log_linear_term = lax.sub(xlogy(lax.sub(a, one), y), y)
5758
shape_terms = lax.add(gammaln(a), lax.log(scale))
5859
log_probs = lax.sub(log_linear_term, shape_terms)
59-
return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs)
60+
return jnp.where(ok, log_probs, -jnp.inf)
6061

6162

6263
def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:

tests/scipy_stats_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,13 @@ def testGammaLogPdfZero(self):
543543
self.assertAllClose(
544544
osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6)
545545

546+
def testGammaDebugNans(self):
547+
# Regression test for https://github.com/jax-ml/jax/issues/24939
548+
with jax.debug_nans(True):
549+
self.assertAllClose(
550+
osp_stats.gamma.pdf(0.0, 1.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0, 1.0)
551+
)
552+
546553
@genNamedParametersNArgs(4)
547554
def testGammaLogCdf(self, shapes, dtypes):
548555
rng = jtu.rand_positive(self.rng())

0 commit comments

Comments
 (0)