Skip to content

Commit f7cac01

Browse files
committed
Add delegation for quantile
This makes quantile available when the version of Scipy is not new enough to support array API inputs.
1 parent 453edda commit f7cac01

File tree

5 files changed

+304
-2
lines changed

5 files changed

+304
-2
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
nunique
2020
one_hot
2121
pad
22+
quantile
2223
setdiff1d
2324
sinc
2425
```

src/array_api_extra/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._delegation import isclose, one_hot, pad
3+
from ._delegation import isclose, one_hot, pad, quantile
44
from ._lib._at import at
55
from ._lib._funcs import (
66
apply_where,
@@ -36,6 +36,7 @@
3636
"nunique",
3737
"one_hot",
3838
"pad",
39+
"quantile",
3940
"setdiff1d",
4041
"sinc",
4142
]

src/array_api_extra/_delegation.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from ._lib._utils._helpers import asarrays
1919
from ._lib._utils._typing import Array, DType
2020

21-
__all__ = ["isclose", "one_hot", "pad"]
21+
__all__ = ["isclose", "one_hot", "pad", "quantile"]
2222

2323

2424
def isclose(
@@ -247,3 +247,70 @@ def pad(
247247
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
248248

249249
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
250+
251+
252+
def quantile(
253+
x: Array,
254+
q: Array | float,
255+
/,
256+
*,
257+
axis: int | None = None,
258+
keepdims: bool = False,
259+
method: str = "linear",
260+
xp: ModuleType | None = None,
261+
) -> Array:
262+
"""
263+
Compute the q-th quantile(s) of the data along the specified axis.
264+
265+
Parameters
266+
----------
267+
x : array of real numbers
268+
Data array.
269+
q : array of float
270+
Probability or sequence of probabilities of the quantiles to compute.
271+
Values must be between 0 and 1 (inclusive). Must have length 1 along
272+
`axis` unless ``keepdims=True``.
273+
method : str, default: 'linear'
274+
The method to use for estimating the quantile. The available options are:
275+
'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
276+
'interpolated_inverted_cdf', 'hazen', 'weibull', 'linear' (default),
277+
'median_unbiased', 'normal_unbiased', 'harrell-davis'.
278+
axis : int or None, default: None
279+
Axis along which the quantiles are computed. ``None`` ravels both `x`
280+
and `q` before performing the calculation.
281+
keepdims : bool, optional
282+
If this is set to True, the axes which are reduced are left in the
283+
result as dimensions with size one. With this option, the result will
284+
broadcast correctly against the original array `x`.
285+
xp : array_namespace, optional
286+
The standard-compatible namespace for `x` and `q`. Default: infer.
287+
288+
Returns
289+
-------
290+
array
291+
An array with the quantiles of the data.
292+
293+
Examples
294+
--------
295+
>>> import array_api_strict as xp
296+
>>> import array_api_extra as xpx
297+
>>> x = xp.asarray([[10, 8, 7, 5, 4], [0, 1, 2, 3, 5]])
298+
>>> xpx.quantile(x, 0.5, axis=-1)
299+
Array([7., 2.], dtype=array_api_strict.float64)
300+
>>> xpx.quantile(x, [0.25, 0.75], axis=-1)
301+
Array([[5., 8.],
302+
[1., 3.]], dtype=array_api_strict.float64)
303+
"""
304+
xp = array_namespace(x, q) if xp is None else xp
305+
306+
try:
307+
import scipy
308+
from packaging import version
309+
# The quantile function in scipy 1.17 supports array API directly, no need to delegate
310+
if version.parse(scipy.__version__) >= version.parse("1.17"):
311+
from scipy.stats import quantile as scipy_quantile
312+
return scipy_quantile(x, p=q, axis=axis, keepdims=keepdims, method=method)
313+
except (ImportError, AttributeError):
314+
pass
315+
316+
return _funcs.quantile(x, q, axis=axis, keepdims=keepdims, method=method, xp=xp)

src/array_api_extra/_lib/_funcs.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
"kron",
2929
"nunique",
3030
"pad",
31+
"quantile",
3132
"setdiff1d",
3233
"sinc",
3334
]
@@ -988,3 +989,166 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
988989
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
989990
)
990991
return xp.sin(y) / y
992+
993+
994+
def quantile(
995+
x: Array,
996+
q: Array | float,
997+
/,
998+
*,
999+
axis: int | None = None,
1000+
keepdims: bool = False,
1001+
method: str = "linear",
1002+
xp: ModuleType | None = None,
1003+
) -> Array:
1004+
"""See docstring in `array_api_extra._delegation.py`."""
1005+
if xp is None:
1006+
xp = array_namespace(x, q)
1007+
1008+
# Convert q to array if it's a scalar
1009+
q_is_scalar = isinstance(q, (int, float))
1010+
if q_is_scalar:
1011+
q = xp.asarray(q, dtype=xp.float64, device=_compat.device(x))
1012+
1013+
# Validate inputs
1014+
if not xp.isdtype(x.dtype, ("integral", "real floating")):
1015+
raise ValueError("`x` must have real dtype.")
1016+
if not xp.isdtype(q.dtype, "real floating"):
1017+
raise ValueError("`q` must have real floating dtype.")
1018+
1019+
# Promote to common dtype
1020+
x = xp.astype(x, xp.float64)
1021+
q = xp.astype(q, xp.float64)
1022+
q = xp.asarray(q, device=_compat.device(x))
1023+
1024+
dtype = x.dtype
1025+
axis_none = axis is None
1026+
ndim = max(x.ndim, q.ndim)
1027+
1028+
if axis_none:
1029+
x = xp.reshape(x, (-1,))
1030+
q = xp.reshape(q, (-1,))
1031+
axis = 0
1032+
elif not isinstance(axis, int):
1033+
raise ValueError("`axis` must be an integer or None.")
1034+
elif axis >= ndim or axis < -ndim:
1035+
raise ValueError("`axis` is not compatible with the shapes of the inputs.")
1036+
else:
1037+
axis = int(axis)
1038+
1039+
# Validate method
1040+
methods = {
1041+
'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
1042+
'hazen', 'interpolated_inverted_cdf', 'linear', 'median_unbiased',
1043+
'normal_unbiased', 'weibull', 'harrell-davis'
1044+
}
1045+
if method not in methods:
1046+
raise ValueError(f"`method` must be one of {methods}")
1047+
1048+
# Handle keepdims parameter
1049+
if keepdims not in {None, True, False}:
1050+
raise ValueError("If specified, `keepdims` must be True or False.")
1051+
1052+
# Handle empty arrays
1053+
if x.shape[axis] == 0:
1054+
shape = list(x.shape)
1055+
shape[axis] = 1
1056+
x = xp.full(shape, xp.nan, dtype=dtype, device=_compat.device(x))
1057+
1058+
# Sort the data
1059+
y = xp.sort(x, axis=axis)
1060+
1061+
# Move axis to the end for easier processing
1062+
y = xp.moveaxis(y, axis, -1)
1063+
if not (q_is_scalar or q.ndim == 0):
1064+
q = xp.moveaxis(q, axis, -1)
1065+
1066+
# Get the number of elements along the axis
1067+
n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y))
1068+
1069+
# Apply quantile calculation based on method
1070+
if method in {'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
1071+
'hazen', 'interpolated_inverted_cdf', 'linear', 'median_unbiased',
1072+
'normal_unbiased', 'weibull'}:
1073+
res = _quantile_hf(y, q, n, method, xp)
1074+
elif method == 'harrell-davis':
1075+
res = _quantile_hd(y, q, n, xp)
1076+
else:
1077+
raise ValueError(f"Unknown method: {method}")
1078+
1079+
# Handle NaN output for invalid q values
1080+
p_mask = (q > 1) | (q < 0) | xp.isnan(q)
1081+
if xp.any(p_mask):
1082+
res = xp.asarray(res, copy=True)
1083+
res = at(res, p_mask).set(xp.nan)
1084+
1085+
# Reshape per axis/keepdims
1086+
if axis_none and keepdims:
1087+
shape = (1,) * (ndim - 1) + res.shape
1088+
res = xp.reshape(res, shape)
1089+
axis = -1
1090+
1091+
# Move axis back to original position
1092+
res = xp.moveaxis(res, -1, axis)
1093+
1094+
# Handle keepdims
1095+
if not keepdims and res.shape[axis] == 1:
1096+
res = xp.squeeze(res, axis=axis)
1097+
1098+
# For scalar q, ensure we return a scalar result
1099+
if q_is_scalar:
1100+
if hasattr(res, 'shape') and res.shape != ():
1101+
res = res[()]
1102+
1103+
return res
1104+
1105+
1106+
def _quantile_hf(y: Array, p: Array, n: Array, method: str, xp: ModuleType) -> Array:
1107+
"""Helper function for Hyndman-Fan quantile methods."""
1108+
ms = {
1109+
'inverted_cdf': 0,
1110+
'averaged_inverted_cdf': 0,
1111+
'closest_observation': -0.5,
1112+
'interpolated_inverted_cdf': 0,
1113+
'hazen': 0.5,
1114+
'weibull': p,
1115+
'linear': 1 - p,
1116+
'median_unbiased': p/3 + 1/3,
1117+
'normal_unbiased': p/4 + 3/8
1118+
}
1119+
m = ms[method]
1120+
1121+
jg = p * n + m - 1
1122+
j = xp.astype(jg // 1, xp.int64) # Convert to integer
1123+
g = jg % 1
1124+
1125+
if method == 'inverted_cdf':
1126+
g = xp.astype((g > 0), jg.dtype)
1127+
elif method == 'averaged_inverted_cdf':
1128+
g = (1 + xp.astype((g > 0), jg.dtype)) / 2
1129+
elif method == 'closest_observation':
1130+
g = (1 - xp.astype((g == 0) & (j % 2 == 1), jg.dtype))
1131+
if method in {'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation'}:
1132+
g = xp.asarray(g)
1133+
g = at(g, jg < 0).set(0)
1134+
g = at(g, j < 0).set(0)
1135+
j = xp.clip(j, 0, n - 1)
1136+
jp1 = xp.clip(j + 1, 0, n - 1)
1137+
1138+
# Broadcast indices to match y shape except for the last axis
1139+
if y.ndim > 1:
1140+
# Create broadcast shape for indices
1141+
broadcast_shape = list(y.shape[:-1]) + [1]
1142+
j = xp.broadcast_to(j, broadcast_shape)
1143+
jp1 = xp.broadcast_to(jp1, broadcast_shape)
1144+
g = xp.broadcast_to(g, broadcast_shape)
1145+
1146+
return ((1 - g) * xp.take_along_axis(y, j, axis=-1) +
1147+
g * xp.take_along_axis(y, jp1, axis=-1))
1148+
1149+
1150+
def _quantile_hd(y: Array, p: Array, n: Array, xp: ModuleType) -> Array:
1151+
"""Helper function for Harrell-Davis quantile method."""
1152+
# For now, implement a simplified version that falls back to linear method
1153+
# since betainc is not available in the array API standard
1154+
return _quantile_hf(y, p, n, "linear", xp)

tests/test_funcs.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
nunique,
2525
one_hot,
2626
pad,
27+
quantile,
2728
setdiff1d,
2829
sinc,
2930
)
@@ -43,6 +44,7 @@
4344
lazy_xp_function(nunique)
4445
lazy_xp_function(one_hot)
4546
lazy_xp_function(pad)
47+
lazy_xp_function(quantile)
4648
# FIXME calls in1d which calls xp.unique_values without size
4749
lazy_xp_function(setdiff1d, jax_jit=False)
4850
lazy_xp_function(sinc)
@@ -1162,3 +1164,70 @@ def test_device(self, xp: ModuleType, device: Device):
11621164

11631165
def test_xp(self, xp: ModuleType):
11641166
xp_assert_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))
1167+
1168+
1169+
class TestQuantile:
1170+
def test_basic(self, xp: ModuleType):
1171+
x = xp.asarray([1, 2, 3, 4, 5])
1172+
actual = quantile(x, 0.5)
1173+
expect = xp.asarray(3.0)
1174+
xp_assert_close(actual, expect)
1175+
1176+
def test_multiple_quantiles(self, xp: ModuleType):
1177+
x = xp.asarray([1, 2, 3, 4, 5])
1178+
actual = quantile(x, xp.asarray([0.25, 0.5, 0.75]))
1179+
expect = xp.asarray([2.0, 3.0, 4.0])
1180+
xp_assert_close(actual, expect)
1181+
1182+
def test_2d_axis(self, xp: ModuleType):
1183+
x = xp.asarray([[1, 2, 3], [4, 5, 6]])
1184+
actual = quantile(x, 0.5, axis=0)
1185+
expect = xp.asarray([2.5, 3.5, 4.5])
1186+
xp_assert_close(actual, expect)
1187+
1188+
def test_2d_axis_keepdims(self, xp: ModuleType):
1189+
x = xp.asarray([[1, 2, 3], [4, 5, 6]])
1190+
actual = quantile(x, 0.5, axis=0, keepdims=True)
1191+
expect = xp.asarray([[2.5, 3.5, 4.5]])
1192+
xp_assert_close(actual, expect)
1193+
1194+
def test_methods(self, xp: ModuleType):
1195+
x = xp.asarray([1, 2, 3, 4, 5])
1196+
methods = ['linear', 'hazen', 'weibull']
1197+
for method in methods:
1198+
actual = quantile(x, 0.5, method=method)
1199+
# All methods should give reasonable results
1200+
assert 2.5 <= float(actual) <= 3.5
1201+
1202+
def test_edge_cases(self, xp: ModuleType):
1203+
x = xp.asarray([1, 2, 3, 4, 5])
1204+
# q = 0 should give minimum
1205+
actual = quantile(x, 0.0)
1206+
expect = xp.asarray(1.0)
1207+
xp_assert_close(actual, expect)
1208+
1209+
# q = 1 should give maximum
1210+
actual = quantile(x, 1.0)
1211+
expect = xp.asarray(5.0)
1212+
xp_assert_close(actual, expect)
1213+
1214+
def test_invalid_q(self, xp: ModuleType):
1215+
x = xp.asarray([1, 2, 3, 4, 5])
1216+
# q > 1 should return NaN
1217+
actual = quantile(x, 1.5)
1218+
assert xp.isnan(actual)
1219+
1220+
# q < 0 should return NaN
1221+
actual = quantile(x, -0.5)
1222+
assert xp.isnan(actual)
1223+
1224+
def test_device(self, xp: ModuleType, device: Device):
1225+
x = xp.asarray([1, 2, 3, 4, 5], device=device)
1226+
actual = quantile(x, 0.5)
1227+
assert get_device(actual) == device
1228+
1229+
def test_xp(self, xp: ModuleType):
1230+
x = xp.asarray([1, 2, 3, 4, 5])
1231+
actual = quantile(x, 0.5, xp=xp)
1232+
expect = xp.asarray(3.0)
1233+
xp_assert_close(actual, expect)

0 commit comments

Comments
 (0)