@@ -2375,7 +2375,6 @@ def poch(z: ArrayLike, m: ArrayLike) -> Array:
2375
2375
Notes:
2376
2376
The JAX version supports only real-valued inputs.
2377
2377
"""
2378
- # Factorial definition when m is close to an integer, otherwise gamma definition.
2379
2378
z , m = promote_args_inexact ("poch" , z , m )
2380
2379
2381
2380
return jnp .where (m == 0. , jnp .array (1 , dtype = z .dtype ), gamma (z + m ) / gamma (z ))
@@ -2412,6 +2411,8 @@ def _hyp1f1_serie(a, b, x):
2412
2411
https://doi.org/10.48550/arXiv.1407.7786
2413
2412
"""
2414
2413
2414
+ precision = jnp .finfo (x .dtype ).eps
2415
+
2415
2416
def body (state ):
2416
2417
serie , k , term = state
2417
2418
serie += term
@@ -2423,7 +2424,7 @@ def body(state):
2423
2424
def cond (state ):
2424
2425
serie , k , term = state
2425
2426
2426
- return (k < 250 ) & (lax .abs (term ) / lax .abs (serie ) > 1e-8 )
2427
+ return (k < 250 ) & (lax .abs (term ) / lax .abs (serie ) > precision )
2427
2428
2428
2429
init = 1 , 1 , a / b * x
2429
2430
@@ -2437,6 +2438,8 @@ def _hyp1f1_asymptotic(a, b, x):
2437
2438
https://doi.org/10.48550/arXiv.1407.7786
2438
2439
"""
2439
2440
2441
+ precision = jnp .finfo (x .dtype ).eps
2442
+
2440
2443
def body (state ):
2441
2444
serie , k , term = state
2442
2445
serie += term
@@ -2448,7 +2451,7 @@ def body(state):
2448
2451
def cond (state ):
2449
2452
serie , k , term = state
2450
2453
2451
- return (k < 250 ) & (lax .abs (term ) / lax .abs (serie ) > 1e-8 )
2454
+ return (k < 250 ) & (lax .abs (term ) / lax .abs (serie ) > precision )
2452
2455
2453
2456
init = 1 , 1 , (b - a ) * (1 - a ) / x
2454
2457
serie = lax .while_loop (cond , body , init )[0 ]
@@ -2464,6 +2467,8 @@ def _hyp1f1_a_derivative(a, b, x):
2464
2467
https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/01/
2465
2468
"""
2466
2469
2470
+ precision = jnp .finfo (x .dtype ).eps
2471
+
2467
2472
def body (state ):
2468
2473
serie , k , term = state
2469
2474
serie += term * (digamma (a + k ) - digamma (a ))
@@ -2475,7 +2480,7 @@ def body(state):
2475
2480
def cond (state ):
2476
2481
serie , k , term = state
2477
2482
2478
- return (k < 250 ) & (lax .abs (term ) / lax .abs (serie ) > 1e-15 )
2483
+ return (k < 250 ) & (lax .abs (term ) / lax .abs (serie ) > precision )
2479
2484
2480
2485
init = 0 , 1 , a / b * x
2481
2486
@@ -2490,6 +2495,8 @@ def _hyp1f1_b_derivative(a, b, x):
2490
2495
https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/02/
2491
2496
"""
2492
2497
2498
+ precision = jnp .finfo (x .dtype ).eps
2499
+
2493
2500
def body (state ):
2494
2501
serie , k , term = state
2495
2502
serie += term * (digamma (b ) - digamma (b + k ))
@@ -2501,7 +2508,7 @@ def body(state):
2501
2508
def cond (state ):
2502
2509
serie , k , term = state
2503
2510
2504
- return (k < 250 ) & (lax .abs (term ) / lax .abs (serie ) > 1e-15 )
2511
+ return (k < 250 ) & (lax .abs (term ) / lax .abs (serie ) > precision )
2505
2512
2506
2513
init = 0 , 1 , a / b * x
2507
2514
0 commit comments