Skip to content

Commit ea3f90f

Browse files
authored
BUG: fix incorrect values in factorial for 0 with uint dtype (scipy#22168)
1 parent 7f53e6e commit ea3f90f

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

scipy/special/_basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2972,8 +2972,9 @@ def _factorialx_approx_core(n, k, extend):
29722972
with warnings.catch_warnings():
29732973
# do not warn about 0 * inf, nan / nan etc.; the results are correct
29742974
warnings.simplefilter("ignore", RuntimeWarning)
2975-
result = np.power(k, (n - 1) / k, dtype=p_dtype) * _gamma1p(n / k)
2976-
result *= rgamma(1 / k + 1)
2975+
# don't use `(n-1)/k` in np.power; underflows if 0 is of a uintX type
2976+
result = np.power(k, n / k, dtype=p_dtype) * _gamma1p(n / k)
2977+
result *= rgamma(1 / k + 1) / np.power(k, 1 / k, dtype=p_dtype)
29772978
if isinstance(n, np.ndarray):
29782979
# ensure we keep array-ness for 0-dim inputs; already n/k above loses it
29792980
result = np.array(result)

scipy/special/tests/test_basic.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2109,11 +2109,12 @@ def assert_really_equal(x, y, rtol=None):
21092109
Sharper assertion function that is stricter about matching types, not just values
21102110
21112111
This is useful/necessary in some cases:
2112-
* handled by xp_assert_* functions
21132112
* dtypes for arrays that have the same _values_ (e.g. element 1.0 vs 1)
21142113
* distinguishing complex from real NaN
2114+
* result types for scalars
21152115
21162116
We still want to be able to allow a relative tolerance for the values though.
2117+
The main logic comparison logic is handled by the xp_assert_* functions.
21172118
"""
21182119
def assert_func(x, y):
21192120
xp_assert_equal(x, y) if rtol is None else xp_assert_close(x, y, rtol=rtol)
@@ -2350,6 +2351,24 @@ def _nest_me(x, k=1):
23502351
assert_func(special.factorialk(n, 3, exact=exact),
23512352
np.array(exp_nucleus[3], ndmin=level))
23522353

2354+
@pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.uint32, np.uint64])
2355+
@pytest.mark.parametrize("exact,extend",
2356+
[(True, "zero"), (False, "zero"), (False, "complex")])
2357+
def test_factorialx_uint(self, exact, extend, dtype):
2358+
# ensure that uint types work correctly as inputs
2359+
kw = {"exact": exact, "extend": extend}
2360+
assert_func = assert_array_equal if exact else assert_allclose
2361+
def _check(n):
2362+
n_ref = n.astype(np.int64) if isinstance(n, np.ndarray) else np.int64(n)
2363+
assert_func(special.factorial(n, **kw), special.factorial(n_ref, **kw))
2364+
assert_func(special.factorial2(n, **kw), special.factorial2(n_ref, **kw))
2365+
assert_func(special.factorialk(n, k=3, **kw),
2366+
special.factorialk(n_ref, k=3, **kw))
2367+
_check(dtype(0))
2368+
_check(dtype(1))
2369+
_check(np.array(0, dtype=dtype))
2370+
_check(np.array([0, 1], dtype=dtype))
2371+
23532372
# note that n=170 is the last integer such that factorial(n) fits float64
23542373
@pytest.mark.parametrize('n', range(30, 180, 10))
23552374
def test_factorial_accuracy(self, n):

0 commit comments

Comments
 (0)