Skip to content

Commit b486a95

Browse files
author
jax authors
committed
Merge pull request #21507 from renecotyfanboy:main
PiperOrigin-RevId: 641429523
2 parents 6c822c0 + 751d59c commit b486a95

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

jax/_src/scipy/special.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2375,7 +2375,6 @@ def poch(z: ArrayLike, m: ArrayLike) -> Array:
23752375
Notes:
23762376
The JAX version supports only real-valued inputs.
23772377
"""
2378-
# Factorial definition when m is close to an integer, otherwise gamma definition.
23792378
z, m = promote_args_inexact("poch", z, m)
23802379

23812380
return jnp.where(m == 0., jnp.array(1, dtype=z.dtype), gamma(z + m) / gamma(z))
@@ -2412,6 +2411,8 @@ def _hyp1f1_serie(a, b, x):
24122411
https://doi.org/10.48550/arXiv.1407.7786
24132412
"""
24142413

2414+
precision = jnp.finfo(x.dtype).eps
2415+
24152416
def body(state):
24162417
serie, k, term = state
24172418
serie += term
@@ -2423,7 +2424,7 @@ def body(state):
24232424
def cond(state):
24242425
serie, k, term = state
24252426

2426-
return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-8)
2427+
return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision)
24272428

24282429
init = 1, 1, a / b * x
24292430

@@ -2437,6 +2438,8 @@ def _hyp1f1_asymptotic(a, b, x):
24372438
https://doi.org/10.48550/arXiv.1407.7786
24382439
"""
24392440

2441+
precision = jnp.finfo(x.dtype).eps
2442+
24402443
def body(state):
24412444
serie, k, term = state
24422445
serie += term
@@ -2448,7 +2451,7 @@ def body(state):
24482451
def cond(state):
24492452
serie, k, term = state
24502453

2451-
return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-8)
2454+
return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision)
24522455

24532456
init = 1, 1, (b - a) * (1 - a) / x
24542457
serie = lax.while_loop(cond, body, init)[0]
@@ -2464,6 +2467,8 @@ def _hyp1f1_a_derivative(a, b, x):
24642467
https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/01/
24652468
"""
24662469

2470+
precision = jnp.finfo(x.dtype).eps
2471+
24672472
def body(state):
24682473
serie, k, term = state
24692474
serie += term * (digamma(a + k) - digamma(a))
@@ -2475,7 +2480,7 @@ def body(state):
24752480
def cond(state):
24762481
serie, k, term = state
24772482

2478-
return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-15)
2483+
return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision)
24792484

24802485
init = 0, 1, a / b * x
24812486

@@ -2490,6 +2495,8 @@ def _hyp1f1_b_derivative(a, b, x):
24902495
https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/02/
24912496
"""
24922497

2498+
precision = jnp.finfo(x.dtype).eps
2499+
24932500
def body(state):
24942501
serie, k, term = state
24952502
serie += term * (digamma(b) - digamma(b + k))
@@ -2501,7 +2508,7 @@ def body(state):
25012508
def cond(state):
25022509
serie, k, term = state
25032510

2504-
return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-15)
2511+
return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision)
25052512

25062513
init = 0, 1, a / b * x
25072514

0 commit comments

Comments
 (0)