Skip to content

Commit 884aec9

Browse files
authored
BUG: Allow np.percentile to operate on float16 data (numpy#29105)
* BUG: Allow np.percentile to operate on float16 data * add an extra regression test * add an extra regression test * remove unused default value * add release note * review comments: part1 * review comments: part 2 * review comments: part 3
1 parent f39abd4 commit 884aec9

File tree

3 files changed

+72
-20
lines changed

3 files changed

+72
-20
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
* The accuracy of ``np.quantile`` and ``np.percentile`` for 16- and 32-bit floating point input data has been improved.

numpy/lib/_function_base_impl.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4219,9 +4219,7 @@ def percentile(a,
42194219
if a.dtype.kind == "c":
42204220
raise TypeError("a must be an array of real numbers")
42214221

4222-
# Use dtype of array if possible (e.g., if q is a python int or float)
4223-
# by making the divisor have the dtype of the data array.
4224-
q = np.true_divide(q, a.dtype.type(100) if a.dtype.kind == "f" else 100, out=...)
4222+
q = np.true_divide(q, 100, out=...)
42254223
if not _quantile_is_valid(q):
42264224
raise ValueError("Percentiles must be in the range [0, 100]")
42274225

@@ -4469,11 +4467,7 @@ def quantile(a,
44694467
if a.dtype.kind == "c":
44704468
raise TypeError("a must be an array of real numbers")
44714469

4472-
# Use dtype of array if possible (e.g., if q is a python int or float).
4473-
if isinstance(q, (int, float)) and a.dtype.kind == "f":
4474-
q = np.asanyarray(q, dtype=a.dtype)
4475-
else:
4476-
q = np.asanyarray(q)
4470+
q = np.asanyarray(q)
44774471

44784472
if not _quantile_is_valid(q):
44794473
raise ValueError("Quantiles must be in the range [0, 1]")
@@ -4549,7 +4543,7 @@ def _compute_virtual_index(n, quantiles, alpha: float, beta: float):
45494543
) - 1
45504544

45514545

4552-
def _get_gamma(virtual_indexes, previous_indexes, method):
4546+
def _get_gamma(virtual_indexes, previous_indexes, method, dtype):
45534547
"""
45544548
Compute gamma (a.k.a 'm' or 'weight') for the linear interpolation
45554549
of quantiles.
@@ -4570,7 +4564,7 @@ def _get_gamma(virtual_indexes, previous_indexes, method):
45704564
gamma = method["fix_gamma"](gamma, virtual_indexes)
45714565
# Ensure both that we have an array, and that we keep the dtype
45724566
# (which may have been matched to the input array).
4573-
return np.asanyarray(gamma, dtype=virtual_indexes.dtype)
4567+
return np.asanyarray(gamma, dtype=dtype)
45744568

45754569

45764570
def _lerp(a, b, t, out=None):
@@ -4788,7 +4782,16 @@ def _quantile(
47884782
previous = arr[previous_indexes]
47894783
next = arr[next_indexes]
47904784
# --- Linear interpolation
4791-
gamma = _get_gamma(virtual_indexes, previous_indexes, method_props)
4785+
if arr.dtype.kind in "iu":
4786+
gtype = None
4787+
elif arr.dtype.kind == "f":
4788+
# make sure the return value matches the input array type
4789+
gtype = arr.dtype
4790+
else:
4791+
gtype = virtual_indexes.dtype
4792+
4793+
gamma = _get_gamma(virtual_indexes, previous_indexes,
4794+
method_props, gtype)
47924795
result_shape = virtual_indexes.shape + (1,) * (arr.ndim - 1)
47934796
gamma = gamma.reshape(result_shape)
47944797
result = _lerp(previous,

numpy/lib/tests/test_function_base.py

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3277,6 +3277,16 @@ def test_period(self):
32773277
assert_almost_equal(np.interp(x, xp, fp, period=360), y)
32783278

32793279

3280+
quantile_methods = [
3281+
'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
3282+
'interpolated_inverted_cdf', 'hazen', 'weibull', 'linear',
3283+
'median_unbiased', 'normal_unbiased', 'nearest', 'lower', 'higher',
3284+
'midpoint']
3285+
3286+
3287+
methods_supporting_weights = ["inverted_cdf"]
3288+
3289+
32803290
class TestPercentile:
32813291

32823292
def test_basic(self):
@@ -3870,15 +3880,38 @@ def test_nat_basic(self, dtype, pos):
38703880
res = np.percentile(a, 30, axis=0)
38713881
assert_array_equal(np.isnat(res), [False, True, False])
38723882

3873-
3874-
quantile_methods = [
3875-
'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
3876-
'interpolated_inverted_cdf', 'hazen', 'weibull', 'linear',
3877-
'median_unbiased', 'normal_unbiased', 'nearest', 'lower', 'higher',
3878-
'midpoint']
3879-
3880-
3881-
methods_supporting_weights = ["inverted_cdf"]
3883+
@pytest.mark.parametrize("qtype", [np.float16, np.float32])
3884+
@pytest.mark.parametrize("method", quantile_methods)
3885+
def test_percentile_gh_29003(self, qtype, method):
3886+
# test that with float16 or float32 input we do not get overflow
3887+
zero = qtype(0)
3888+
one = qtype(1)
3889+
a = np.zeros(65521, qtype)
3890+
a[:20_000] = one
3891+
z = np.percentile(a, 50, method=method)
3892+
assert z == zero
3893+
assert z.dtype == a.dtype
3894+
z = np.percentile(a, 99, method=method)
3895+
assert z == one
3896+
assert z.dtype == a.dtype
3897+
3898+
def test_percentile_gh_29003_Fraction(self):
3899+
zero = Fraction(0)
3900+
one = Fraction(1)
3901+
a = np.array([zero] * 65521)
3902+
a[:20_000] = one
3903+
z = np.percentile(a, 50)
3904+
assert z == zero
3905+
z = np.percentile(a, Fraction(50))
3906+
assert z == zero
3907+
assert np.array(z).dtype == a.dtype
3908+
3909+
z = np.percentile(a, 99)
3910+
assert z == one
3911+
# test that with only Fraction input the return type is a Fraction
3912+
z = np.percentile(a, Fraction(99))
3913+
assert z == one
3914+
assert np.array(z).dtype == a.dtype
38823915

38833916

38843917
class TestQuantile:
@@ -4244,6 +4277,21 @@ def test_closest_observation(self):
42444277
assert_equal(4, np.quantile(arr[0:9], q, method=m))
42454278
assert_equal(5, np.quantile(arr, q, method=m))
42464279

4280+
def test_quantile_gh_29003_Fraction(self):
4281+
r = np.quantile([1, 2], q=Fraction(1))
4282+
assert r == Fraction(2)
4283+
assert isinstance(r, Fraction)
4284+
4285+
r = np.quantile([1, 2], q=Fraction(.5))
4286+
assert r == Fraction(3, 2)
4287+
assert isinstance(r, Fraction)
4288+
4289+
def test_float16_gh_29003(self):
4290+
a = np.arange(50_001, dtype=np.float16)
4291+
q = .999
4292+
value = np.quantile(a, q)
4293+
assert value == q * 50_000
4294+
assert value.dtype == np.float16
42474295

42484296
class TestLerp:
42494297
@hypothesis.given(t0=st.floats(allow_nan=False, allow_infinity=False,

0 commit comments

Comments
 (0)