Skip to content

Commit ab42a3e

Browse files
committed
Fix betainc edge cases and inaccuracies when a is close to zero.
1 parent 47dde87 commit ab42a3e

File tree

3 files changed

+84
-19
lines changed

3 files changed

+84
-19
lines changed

jax/_src/lax/special.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,18 @@ def nth_partial_betainc_numerator(iteration, a, b, x):
194194
iteration_is_one = eq(iteration_bcast, full_like(iteration_bcast, 1))
195195
iteration_minus_one = iteration_bcast - full_like(iteration_bcast, 1)
196196
m = iteration_minus_one // full_like(iteration_minus_one, 2)
197+
m_is_zero = eq(m, full_like(m, 0))
197198
m = convert_element_type(m, dtype)
198199
one = full_like(a, 1)
199200
two = full_like(a, 2.0)
200201
# Partial numerator terms
201-
even_numerator = -(a + m) * (a + b + m) * x / (
202-
(a + two * m) * (a + two * m + one))
202+
203+
# When a is close to zero and m == 0, using zero_numerator avoids
204+
# inaccuracies when FTZ or DAZ is enabled:
205+
zero_numerator = -(a + b) * x / (a + one)
206+
even_numerator = select(m_is_zero, zero_numerator,
207+
-(a + m) * (a + b + m) * x / (
208+
(a + two * m) * (a + two * m + one)))
203209
odd_numerator = m * (b - m) * x / ((a + two * m - one) * (a + two * m))
204210
one_numerator = full_like(x, 1.0)
205211
numerator = select(iteration_is_even, even_numerator, odd_numerator)
@@ -210,12 +216,24 @@ def nth_partial_betainc_denominator(iteration, a, b, x):
210216
return select(eq(iteration_bcast, full_like(iteration_bcast, 0)),
211217
full_like(x, 0), full_like(x, 1))
212218

219+
a_is_zero = bitwise_or(eq(a, full_like(a, 0)), eq(b, full_like(b, float('inf'))))
220+
b_is_zero = bitwise_or(eq(b, full_like(b, 0)), eq(a, full_like(a, float('inf'))))
221+
x_is_zero = eq(x, full_like(x, 0))
222+
x_is_one = eq(x, full_like(x, 1))
223+
x_is_not_zero = bitwise_not(x_is_zero)
224+
x_is_not_one = bitwise_not(x_is_one)
225+
is_nan = bitwise_or(bitwise_or(_isnan(a), _isnan(b)), _isnan(x))
226+
227+
result_is_zero = bitwise_or(bitwise_and(b_is_zero, x_is_not_one), bitwise_and(a_is_zero, x_is_zero))
228+
result_is_one = bitwise_or(bitwise_and(a_is_zero, x_is_not_zero), bitwise_and(b_is_zero, x_is_one))
229+
213230
result_is_nan = bitwise_or(bitwise_or(bitwise_or(
214-
le(a, full_like(a, 0)), le(b, full_like(b, 0))),
231+
lt(a, full_like(a, 0)), lt(b, full_like(b, 0))),
215232
lt(x, full_like(x, 0))), gt(x, full_like(x, 1)))
233+
result_is_nan = bitwise_or(result_is_nan, bitwise_or(bitwise_and(a_is_zero, b_is_zero), is_nan))
216234

217-
# The continued fraction will converge rapidly when x < (a+1)/(a+b+2)
218-
# as per: http://dlmf.nist.gov/8.17.E23
235+
# The continued fraction will converge rapidly when x <
236+
# (a+1)/(a+b+2) as per: http://dlmf.nist.gov/8.17.E23.
219237
#
220238
# Otherwise, we can rewrite using the symmetry relation as per:
221239
# http://dlmf.nist.gov/8.17.E4
@@ -234,10 +252,21 @@ def nth_partial_betainc_denominator(iteration, a, b, x):
234252
inputs=[a, b, x]
235253
)
236254

237-
lbeta_ab = lgamma(a) + lgamma(b) - lgamma(a + b)
238-
result = continued_fraction * exp(log(x) * a + log1p(-x) * b - lbeta_ab) / a
255+
# For very small a and to avoid division by zero, we'll use
256+
# a * gamma(a) = gamma(a + 1) -> 1 as a -> 0+.
257+
very_small = (dtypes.finfo(dtype).tiny * 2).astype(dtype)
258+
lbeta_ab_small_a = lgamma(b) - lgamma(a + b)
259+
lbeta_ab = lgamma(a) + lbeta_ab_small_a
260+
factor = select(lt(a, full_like(a, very_small)),
261+
exp(log1p(-x) * b - lbeta_ab_small_a),
262+
exp(log(x) * a + log1p(-x) * b - lbeta_ab) / a)
263+
result = continued_fraction * factor
264+
result = select(converges_rapidly, result, sub(full_like(result, 1), result))
265+
266+
result = select(result_is_zero, full_like(a, 0), result)
267+
result = select(result_is_one, full_like(a, 1), result)
239268
result = select(result_is_nan, full_like(a, float('nan')), result)
240-
return select(converges_rapidly, result, sub(full_like(result, 1), result))
269+
return result
241270

242271
class IgammaMode(Enum):
243272
VALUE = 1

jax/_src/test_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1522,7 +1522,7 @@ def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker,
15221522
args = args_maker()
15231523
lax_ans = lax_op(*args)
15241524
numpy_ans = numpy_reference_op(*args)
1525-
self.assertAllClose(numpy_ans, lax_ans, check_dtypes=check_dtypes,
1525+
self.assertAllClose(lax_ans, numpy_ans, check_dtypes=check_dtypes,
15261526
atol=atol or tol, rtol=rtol or tol,
15271527
canonicalize_dtypes=canonicalize_dtypes)
15281528

tests/lax_scipy_special_functions_test.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -288,35 +288,71 @@ def testExpiDisableJit(self):
288288
self.assertAllClose(result_jit, result_nojit)
289289

290290
def testGammaIncBoundaryValues(self):
291-
dtype = jax.numpy.zeros(0).dtype # default float dtype.
291+
dtype = jax.dtypes.canonicalize_dtype(float)
292292
nan = float('nan')
293293
inf = float('inf')
294294
if jtu.parse_version(scipy.__version__) >= (1, 16):
295-
samples_slice = slice(None)
295+
a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan]
296+
x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf]
296297
else:
297298
# disable samples that contradict with scipy/scipy#22441
298-
samples_slice = slice(None, -1)
299-
args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan][samples_slice]).astype(dtype),
300-
np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf][samples_slice]).astype(dtype)]
299+
a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1]
300+
x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1]
301+
302+
args_maker = lambda: (np.array(a_samples, dtype=dtype), np.array(x_samples, dtype=dtype))
303+
301304
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
302305
self._CheckAgainstNumpy(lsp_special.gammainc, osp_special.gammainc, args_maker, rtol=rtol)
303306
self._CompileAndCheck(lsp_special.gammainc, args_maker, rtol=rtol)
304307

305308
def testGammaIncCBoundaryValues(self):
306-
dtype = jax.numpy.zeros(0).dtype # default float dtype.
309+
dtype = jax.dtypes.canonicalize_dtype(float)
307310
nan = float('nan')
308311
inf = float('inf')
309312
if jtu.parse_version(scipy.__version__) >= (1, 16):
310-
samples_slice = slice(None)
313+
a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan]
314+
x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf]
311315
else:
312316
# disable samples that contradict with scipy/scipy#22441
313-
samples_slice = slice(None, -1)
314-
args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan][samples_slice]).astype(dtype),
315-
np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf][samples_slice]).astype(dtype)]
317+
a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1]
318+
x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1]
319+
320+
args_maker = lambda: (np.array(a_samples, dtype=dtype), np.array(x_samples, dtype=dtype))
321+
316322
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
317323
self._CheckAgainstNumpy(lsp_special.gammaincc, osp_special.gammaincc, args_maker, rtol=rtol)
318324
self._CompileAndCheck(lsp_special.gammaincc, args_maker, rtol=rtol)
319325

326+
def testBetaIncBoundaryValues(self):
327+
dtype = jax.dtypes.canonicalize_dtype(float)
328+
fi = jax.numpy.finfo(dtype)
329+
nan = float('nan')
330+
inf = float('inf')
331+
tiny = fi.tiny
332+
eps = fi.eps
333+
if jtu.parse_version(scipy.__version__) >= (1, 16):
334+
# TODO(pearu): enable tiny samples when a fix to scipy/scipy#22682
335+
# will be available
336+
a_samples = [nan, -0.5, inf, 0, eps, 1, tiny][:-1]
337+
b_samples = [nan, -0.5, inf, 0, eps, 1, tiny][:-1]
338+
elif jtu.parse_version(scipy.__version__) >= (1, 12):
339+
# disabled samples that contradict with scipy/scipy#22425
340+
a_samples = [nan, -0.5, 0.5]
341+
b_samples = [nan, -0.5, 0.5]
342+
else:
343+
a_samples = [-0.5, 0.5]
344+
b_samples = [-0.5, 0.5]
345+
x_samples = [nan, -0.5, 0, 0.5, 1, 1.5]
346+
347+
a_samples = np.array(a_samples, dtype=dtype)
348+
b_samples = np.array(b_samples, dtype=dtype)
349+
x_samples = np.array(x_samples, dtype=dtype)
350+
351+
args_maker = lambda: np.meshgrid(a_samples, b_samples, x_samples)
352+
353+
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 5e-5
354+
self._CheckAgainstNumpy(osp_special.betainc, lsp_special.betainc, args_maker, rtol=rtol)
355+
self._CompileAndCheck(lsp_special.betainc, args_maker, rtol=rtol)
320356

321357
if __name__ == "__main__":
322358
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)