@@ -2109,11 +2109,12 @@ def assert_really_equal(x, y, rtol=None):
2109
2109
Sharper assertion function that is stricter about matching types, not just values
2110
2110
2111
2111
This is useful/necessary in some cases:
2112
- * handled by xp_assert_* functions
2113
2112
* dtypes for arrays that have the same _values_ (e.g. element 1.0 vs 1)
2114
2113
* distinguishing complex from real NaN
2114
+ * result types for scalars
2115
2115
2116
2116
We still want to be able to allow a relative tolerance for the values though.
2117
+ The main logic comparison logic is handled by the xp_assert_* functions.
2117
2118
"""
2118
2119
def assert_func (x , y ):
2119
2120
xp_assert_equal (x , y ) if rtol is None else xp_assert_close (x , y , rtol = rtol )
@@ -2350,6 +2351,24 @@ def _nest_me(x, k=1):
2350
2351
assert_func (special .factorialk (n , 3 , exact = exact ),
2351
2352
np .array (exp_nucleus [3 ], ndmin = level ))
2352
2353
2354
+ @pytest .mark .parametrize ("dtype" , [np .uint8 , np .uint16 , np .uint32 , np .uint64 ])
2355
+ @pytest .mark .parametrize ("exact,extend" ,
2356
+ [(True , "zero" ), (False , "zero" ), (False , "complex" )])
2357
+ def test_factorialx_uint (self , exact , extend , dtype ):
2358
+ # ensure that uint types work correctly as inputs
2359
+ kw = {"exact" : exact , "extend" : extend }
2360
+ assert_func = assert_array_equal if exact else assert_allclose
2361
+ def _check (n ):
2362
+ n_ref = n .astype (np .int64 ) if isinstance (n , np .ndarray ) else np .int64 (n )
2363
+ assert_func (special .factorial (n , ** kw ), special .factorial (n_ref , ** kw ))
2364
+ assert_func (special .factorial2 (n , ** kw ), special .factorial2 (n_ref , ** kw ))
2365
+ assert_func (special .factorialk (n , k = 3 , ** kw ),
2366
+ special .factorialk (n_ref , k = 3 , ** kw ))
2367
+ _check (dtype (0 ))
2368
+ _check (dtype (1 ))
2369
+ _check (np .array (0 , dtype = dtype ))
2370
+ _check (np .array ([0 , 1 ], dtype = dtype ))
2371
+
2353
2372
# note that n=170 is the last integer such that factorial(n) fits float64
2354
2373
@pytest .mark .parametrize ('n' , range (30 , 180 , 10 ))
2355
2374
def test_factorial_accuracy (self , n ):
0 commit comments