Skip to content

Commit c3501e8

Browse files
committed
Move quantile implementation to new file
_funcs.py was getting too long.
1 parent 37acd5b commit c3501e8

File tree

3 files changed

+158
-150
lines changed

3 files changed

+158
-150
lines changed

src/array_api_extra/_delegation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Literal
66

77
from ._lib import _funcs
8+
from ._lib._quantile import quantile as _quantile
89
from ._lib._utils._compat import (
910
array_namespace,
1011
is_cupy_namespace,
@@ -332,4 +333,4 @@ def quantile(
332333
except (ImportError, AttributeError):
333334
pass
334335

335-
return _funcs.quantile(x, q, axis=axis, keepdims=keepdims, method=method, xp=xp)
336+
return _quantile(x, q, axis=axis, keepdims=keepdims, method=method, xp=xp)

src/array_api_extra/_lib/_funcs.py

Lines changed: 0 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
"kron",
2929
"nunique",
3030
"pad",
31-
"quantile",
3231
"setdiff1d",
3332
"sinc",
3433
]
@@ -989,151 +988,3 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
989988
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
990989
)
991990
return xp.sin(y) / y
992-
993-
994-
def quantile(
995-
x: Array,
996-
q: Array | float,
997-
/,
998-
*,
999-
axis: int | None = None,
1000-
keepdims: bool = False,
1001-
method: str = "linear",
1002-
xp: ModuleType | None = None,
1003-
) -> Array: # numpydoc ignore=PR01,RT01
1004-
"""See docstring in `array_api_extra._delegation.py`."""
1005-
if xp is None:
1006-
xp = array_namespace(x, q)
1007-
1008-
q_is_scalar = isinstance(q, int | float)
1009-
if q_is_scalar:
1010-
q = xp.asarray(q, dtype=xp.float64, device=_compat.device(x))
1011-
1012-
if not xp.isdtype(x.dtype, ("integral", "real floating")):
1013-
raise ValueError("`x` must have real dtype.") # noqa: EM101
1014-
if not xp.isdtype(q.dtype, "real floating"):
1015-
raise ValueError("`q` must have real floating dtype.") # noqa: EM101
1016-
1017-
# Promote to common dtype
1018-
x = xp.astype(x, xp.float64)
1019-
q = xp.astype(q, xp.float64)
1020-
q = xp.asarray(q, device=_compat.device(x))
1021-
1022-
dtype = x.dtype
1023-
axis_none = axis is None
1024-
ndim = max(x.ndim, q.ndim)
1025-
1026-
if axis_none:
1027-
x = xp.reshape(x, (-1,))
1028-
q = xp.reshape(q, (-1,))
1029-
axis = 0
1030-
elif not isinstance(axis, int):
1031-
raise ValueError("`axis` must be an integer or None.") # noqa: EM101
1032-
elif axis >= ndim or axis < -ndim:
1033-
raise ValueError("`axis` is not compatible with the shapes of the inputs.") # noqa: EM101
1034-
else:
1035-
axis = int(axis)
1036-
1037-
if keepdims not in {None, True, False}:
1038-
raise ValueError("If specified, `keepdims` must be True or False.") # noqa: EM101
1039-
1040-
if x.shape[axis] == 0:
1041-
shape = list(x.shape)
1042-
shape[axis] = 1
1043-
x = xp.full(shape, xp.nan, dtype=dtype, device=_compat.device(x))
1044-
1045-
y = xp.sort(x, axis=axis)
1046-
1047-
# Move axis to the end for easier processing
1048-
y = xp.moveaxis(y, axis, -1)
1049-
if not (q_is_scalar or q.ndim == 0):
1050-
q = xp.moveaxis(q, axis, -1)
1051-
1052-
n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y))
1053-
1054-
if method in { # pylint: disable=duplicate-code
1055-
"inverted_cdf",
1056-
"averaged_inverted_cdf",
1057-
"closest_observation",
1058-
"hazen",
1059-
"interpolated_inverted_cdf",
1060-
"linear",
1061-
"median_unbiased",
1062-
"normal_unbiased",
1063-
"weibull",
1064-
}:
1065-
res = _quantile_hf(y, q, n, method, xp)
1066-
else:
1067-
raise ValueError(f"Unknown method: {method}") # noqa: EM102
1068-
1069-
# Handle NaN output for invalid q values
1070-
p_mask = (q > 1) | (q < 0) | xp.isnan(q)
1071-
if xp.any(p_mask):
1072-
res = xp.asarray(res, copy=True)
1073-
res = at(res, p_mask).set(xp.nan)
1074-
1075-
# Reshape per axis/keepdims
1076-
if axis_none and keepdims:
1077-
shape = (1,) * (ndim - 1) + res.shape
1078-
res = xp.reshape(res, shape)
1079-
axis = -1
1080-
1081-
# Move axis back to original position
1082-
res = xp.moveaxis(res, -1, axis)
1083-
1084-
# Handle keepdims
1085-
if not keepdims and res.shape[axis] == 1:
1086-
res = xp.squeeze(res, axis=axis)
1087-
1088-
# For scalar q, ensure we return a scalar result
1089-
if q_is_scalar and hasattr(res, "shape") and res.shape != ():
1090-
res = res[()]
1091-
1092-
return res
1093-
1094-
1095-
def _quantile_hf(
1096-
y: Array, p: Array, n: Array, method: str, xp: ModuleType
1097-
) -> Array: # numpydoc ignore=PR01,RT01
1098-
"""Helper function for Hyndman-Fan quantile method."""
1099-
ms = {
1100-
"inverted_cdf": 0,
1101-
"averaged_inverted_cdf": 0,
1102-
"closest_observation": -0.5,
1103-
"interpolated_inverted_cdf": 0,
1104-
"hazen": 0.5,
1105-
"weibull": p,
1106-
"linear": 1 - p,
1107-
"median_unbiased": p / 3 + 1 / 3,
1108-
"normal_unbiased": p / 4 + 3 / 8,
1109-
}
1110-
m = ms[method]
1111-
1112-
jg = p * n + m - 1
1113-
j = xp.astype(jg // 1, xp.int64) # Convert to integer
1114-
g = jg % 1
1115-
1116-
if method == "inverted_cdf":
1117-
g = xp.astype((g > 0), jg.dtype)
1118-
elif method == "averaged_inverted_cdf":
1119-
g = (1 + xp.astype((g > 0), jg.dtype)) / 2
1120-
elif method == "closest_observation":
1121-
g = 1 - xp.astype((g == 0) & (j % 2 == 1), jg.dtype)
1122-
if method in {"inverted_cdf", "averaged_inverted_cdf", "closest_observation"}:
1123-
g = xp.asarray(g)
1124-
g = at(g, jg < 0).set(0)
1125-
g = at(g, j < 0).set(0)
1126-
j = xp.clip(j, 0, n - 1)
1127-
jp1 = xp.clip(j + 1, 0, n - 1)
1128-
1129-
# Broadcast indices to match y shape except for the last axis
1130-
if y.ndim > 1:
1131-
# Create broadcast shape for indices
1132-
broadcast_shape = list(y.shape[:-1]) + [1] # noqa: RUF005
1133-
j = xp.broadcast_to(j, broadcast_shape)
1134-
jp1 = xp.broadcast_to(jp1, broadcast_shape)
1135-
g = xp.broadcast_to(g, broadcast_shape)
1136-
1137-
return (1 - g) * xp.take_along_axis(y, j, axis=-1) + g * xp.take_along_axis(
1138-
y, jp1, axis=-1
1139-
)
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
"""Quantile implementation."""
2+
3+
from types import ModuleType
4+
5+
from ._at import at
6+
from ._utils import _compat
7+
from ._utils._compat import array_namespace
8+
from ._utils._typing import Array
9+
10+
11+
def quantile(
12+
x: Array,
13+
q: Array | float,
14+
/,
15+
*,
16+
axis: int | None = None,
17+
keepdims: bool = False,
18+
method: str = "linear",
19+
xp: ModuleType | None = None,
20+
) -> Array: # numpydoc ignore=PR01,RT01
21+
"""See docstring in `array_api_extra._delegation.py`."""
22+
if xp is None:
23+
xp = array_namespace(x, q)
24+
25+
q_is_scalar = isinstance(q, int | float)
26+
if q_is_scalar:
27+
q = xp.asarray(q, dtype=xp.float64, device=_compat.device(x))
28+
29+
if not xp.isdtype(x.dtype, ("integral", "real floating")):
30+
raise ValueError("`x` must have real dtype.") # noqa: EM101
31+
if not xp.isdtype(q.dtype, "real floating"):
32+
raise ValueError("`q` must have real floating dtype.") # noqa: EM101
33+
34+
# Promote to common dtype
35+
x = xp.astype(x, xp.float64)
36+
q = xp.astype(q, xp.float64)
37+
q = xp.asarray(q, device=_compat.device(x))
38+
39+
dtype = x.dtype
40+
axis_none = axis is None
41+
ndim = max(x.ndim, q.ndim)
42+
43+
if axis_none:
44+
x = xp.reshape(x, (-1,))
45+
q = xp.reshape(q, (-1,))
46+
axis = 0
47+
elif not isinstance(axis, int):
48+
raise ValueError("`axis` must be an integer or None.") # noqa: EM101
49+
elif axis >= ndim or axis < -ndim:
50+
raise ValueError("`axis` is not compatible with the shapes of the inputs.") # noqa: EM101
51+
else:
52+
axis = int(axis)
53+
54+
if keepdims not in {None, True, False}:
55+
raise ValueError("If specified, `keepdims` must be True or False.") # noqa: EM101
56+
57+
if x.shape[axis] == 0:
58+
shape = list(x.shape)
59+
shape[axis] = 1
60+
x = xp.full(shape, xp.nan, dtype=dtype, device=_compat.device(x))
61+
62+
y = xp.sort(x, axis=axis)
63+
64+
# Move axis to the end for easier processing
65+
y = xp.moveaxis(y, axis, -1)
66+
if not (q_is_scalar or q.ndim == 0):
67+
q = xp.moveaxis(q, axis, -1)
68+
69+
n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y))
70+
71+
if method in { # pylint: disable=duplicate-code
72+
"inverted_cdf",
73+
"averaged_inverted_cdf",
74+
"closest_observation",
75+
"hazen",
76+
"interpolated_inverted_cdf",
77+
"linear",
78+
"median_unbiased",
79+
"normal_unbiased",
80+
"weibull",
81+
}:
82+
res = _quantile_hf(y, q, n, method, xp)
83+
else:
84+
raise ValueError(f"Unknown method: {method}") # noqa: EM102
85+
86+
# Handle NaN output for invalid q values
87+
p_mask = (q > 1) | (q < 0) | xp.isnan(q)
88+
if xp.any(p_mask):
89+
res = xp.asarray(res, copy=True)
90+
res = at(res, p_mask).set(xp.nan)
91+
92+
# Reshape per axis/keepdims
93+
if axis_none and keepdims:
94+
shape = (1,) * (ndim - 1) + res.shape
95+
res = xp.reshape(res, shape)
96+
axis = -1
97+
98+
# Move axis back to original position
99+
res = xp.moveaxis(res, -1, axis)
100+
101+
# Handle keepdims
102+
if not keepdims and res.shape[axis] == 1:
103+
res = xp.squeeze(res, axis=axis)
104+
105+
# For scalar q, ensure we return a scalar result
106+
if q_is_scalar and hasattr(res, "shape") and res.shape != ():
107+
res = res[()]
108+
109+
return res
110+
111+
112+
def _quantile_hf(
113+
y: Array, p: Array, n: Array, method: str, xp: ModuleType
114+
) -> Array: # numpydoc ignore=PR01,RT01
115+
"""Helper function for Hyndman-Fan quantile method."""
116+
ms = {
117+
"inverted_cdf": 0,
118+
"averaged_inverted_cdf": 0,
119+
"closest_observation": -0.5,
120+
"interpolated_inverted_cdf": 0,
121+
"hazen": 0.5,
122+
"weibull": p,
123+
"linear": 1 - p,
124+
"median_unbiased": p / 3 + 1 / 3,
125+
"normal_unbiased": p / 4 + 3 / 8,
126+
}
127+
m = ms[method]
128+
129+
jg = p * n + m - 1
130+
j = xp.astype(jg // 1, xp.int64) # Convert to integer
131+
g = jg % 1
132+
133+
if method == "inverted_cdf":
134+
g = xp.astype((g > 0), jg.dtype)
135+
elif method == "averaged_inverted_cdf":
136+
g = (1 + xp.astype((g > 0), jg.dtype)) / 2
137+
elif method == "closest_observation":
138+
g = 1 - xp.astype((g == 0) & (j % 2 == 1), jg.dtype)
139+
if method in {"inverted_cdf", "averaged_inverted_cdf", "closest_observation"}:
140+
g = xp.asarray(g)
141+
g = at(g, jg < 0).set(0)
142+
g = at(g, j < 0).set(0)
143+
j = xp.clip(j, 0, n - 1)
144+
jp1 = xp.clip(j + 1, 0, n - 1)
145+
146+
# Broadcast indices to match y shape except for the last axis
147+
if y.ndim > 1:
148+
# Create broadcast shape for indices
149+
broadcast_shape = list(y.shape[:-1]) + [1] # noqa: RUF005
150+
j = xp.broadcast_to(j, broadcast_shape)
151+
jp1 = xp.broadcast_to(jp1, broadcast_shape)
152+
g = xp.broadcast_to(g, broadcast_shape)
153+
154+
return (1 - g) * xp.take_along_axis(y, j, axis=-1) + g * xp.take_along_axis(
155+
y, jp1, axis=-1
156+
)

0 commit comments

Comments
 (0)