Skip to content

Commit 9577f11

Browse files
committed
Formatting
1 parent f7cac01 commit 9577f11

File tree

3 files changed

+84
-66
lines changed

3 files changed

+84
-66
lines changed

src/array_api_extra/_delegation.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -270,18 +270,18 @@ def quantile(
270270
Probability or sequence of probabilities of the quantiles to compute.
271271
Values must be between 0 and 1 (inclusive). Must have length 1 along
272272
`axis` unless ``keepdims=True``.
273-
method : str, default: 'linear'
274-
The method to use for estimating the quantile. The available options are:
275-
'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
276-
'interpolated_inverted_cdf', 'hazen', 'weibull', 'linear' (default),
277-
'median_unbiased', 'normal_unbiased', 'harrell-davis'.
278273
axis : int or None, default: None
279274
Axis along which the quantiles are computed. ``None`` ravels both `x`
280275
and `q` before performing the calculation.
281276
keepdims : bool, optional
282277
If this is set to True, the axes which are reduced are left in the
283278
result as dimensions with size one. With this option, the result will
284279
broadcast correctly against the original array `x`.
280+
method : str, default: 'linear'
281+
The method to use for estimating the quantile. The available options are:
282+
'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
283+
'interpolated_inverted_cdf', 'hazen', 'weibull', 'linear' (default),
284+
'median_unbiased', 'normal_unbiased', 'harrell-davis'.
285285
xp : array_namespace, optional
286286
The standard-compatible namespace for `x` and `q`. Default: infer.
287287
@@ -306,9 +306,12 @@ def quantile(
306306
try:
307307
import scipy
308308
from packaging import version
309-
# The quantile function in scipy 1.17 supports array API directly, no need to delegate
309+
310+
# The quantile function in scipy 1.17 supports array API directly, no need
311+
# to delegate
310312
if version.parse(scipy.__version__) >= version.parse("1.17"):
311313
from scipy.stats import quantile as scipy_quantile
314+
312315
return scipy_quantile(x, p=q, axis=axis, keepdims=keepdims, method=method)
313316
except (ImportError, AttributeError):
314317
pass

src/array_api_extra/_lib/_funcs.py

Lines changed: 72 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,147 +1004,162 @@ def quantile(
10041004
"""See docstring in `array_api_extra._delegation.py`."""
10051005
if xp is None:
10061006
xp = array_namespace(x, q)
1007-
1007+
10081008
# Convert q to array if it's a scalar
1009-
q_is_scalar = isinstance(q, (int, float))
1009+
q_is_scalar = isinstance(q, int | float)
10101010
if q_is_scalar:
10111011
q = xp.asarray(q, dtype=xp.float64, device=_compat.device(x))
1012-
1012+
10131013
# Validate inputs
10141014
if not xp.isdtype(x.dtype, ("integral", "real floating")):
1015-
raise ValueError("`x` must have real dtype.")
1015+
raise ValueError("`x` must have real dtype.") # noqa: EM101
10161016
if not xp.isdtype(q.dtype, "real floating"):
1017-
raise ValueError("`q` must have real floating dtype.")
1018-
1017+
raise ValueError("`q` must have real floating dtype.") # noqa: EM101
1018+
10191019
# Promote to common dtype
10201020
x = xp.astype(x, xp.float64)
10211021
q = xp.astype(q, xp.float64)
10221022
q = xp.asarray(q, device=_compat.device(x))
1023-
1023+
10241024
dtype = x.dtype
10251025
axis_none = axis is None
10261026
ndim = max(x.ndim, q.ndim)
1027-
1027+
10281028
if axis_none:
10291029
x = xp.reshape(x, (-1,))
10301030
q = xp.reshape(q, (-1,))
10311031
axis = 0
10321032
elif not isinstance(axis, int):
1033-
raise ValueError("`axis` must be an integer or None.")
1033+
raise ValueError("`axis` must be an integer or None.") # noqa: EM101
10341034
elif axis >= ndim or axis < -ndim:
1035-
raise ValueError("`axis` is not compatible with the shapes of the inputs.")
1035+
raise ValueError("`axis` is not compatible with the shapes of the inputs.") # noqa: EM101
10361036
else:
10371037
axis = int(axis)
1038-
1038+
10391039
# Validate method
10401040
methods = {
1041-
'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
1042-
'hazen', 'interpolated_inverted_cdf', 'linear', 'median_unbiased',
1043-
'normal_unbiased', 'weibull', 'harrell-davis'
1041+
"inverted_cdf",
1042+
"averaged_inverted_cdf",
1043+
"closest_observation",
1044+
"hazen",
1045+
"interpolated_inverted_cdf",
1046+
"linear",
1047+
"median_unbiased",
1048+
"normal_unbiased",
1049+
"weibull",
1050+
"harrell-davis",
10441051
}
10451052
if method not in methods:
1046-
raise ValueError(f"`method` must be one of {methods}")
1047-
1053+
raise ValueError(f"`method` must be one of {methods}") # noqa: EM102
1054+
10481055
# Handle keepdims parameter
10491056
if keepdims not in {None, True, False}:
1050-
raise ValueError("If specified, `keepdims` must be True or False.")
1051-
1057+
raise ValueError("If specified, `keepdims` must be True or False.") # noqa: EM101
1058+
10521059
# Handle empty arrays
10531060
if x.shape[axis] == 0:
10541061
shape = list(x.shape)
10551062
shape[axis] = 1
10561063
x = xp.full(shape, xp.nan, dtype=dtype, device=_compat.device(x))
1057-
1064+
10581065
# Sort the data
10591066
y = xp.sort(x, axis=axis)
1060-
1067+
10611068
# Move axis to the end for easier processing
10621069
y = xp.moveaxis(y, axis, -1)
10631070
if not (q_is_scalar or q.ndim == 0):
10641071
q = xp.moveaxis(q, axis, -1)
1065-
1072+
10661073
# Get the number of elements along the axis
10671074
n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y))
1068-
1075+
10691076
# Apply quantile calculation based on method
1070-
if method in {'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
1071-
'hazen', 'interpolated_inverted_cdf', 'linear', 'median_unbiased',
1072-
'normal_unbiased', 'weibull'}:
1077+
if method in {
1078+
"inverted_cdf",
1079+
"averaged_inverted_cdf",
1080+
"closest_observation",
1081+
"hazen",
1082+
"interpolated_inverted_cdf",
1083+
"linear",
1084+
"median_unbiased",
1085+
"normal_unbiased",
1086+
"weibull",
1087+
}:
10731088
res = _quantile_hf(y, q, n, method, xp)
1074-
elif method == 'harrell-davis':
1089+
elif method == "harrell-davis":
10751090
res = _quantile_hd(y, q, n, xp)
10761091
else:
1077-
raise ValueError(f"Unknown method: {method}")
1078-
1092+
raise ValueError(f"Unknown method: {method}") # noqa: EM102
1093+
10791094
# Handle NaN output for invalid q values
10801095
p_mask = (q > 1) | (q < 0) | xp.isnan(q)
10811096
if xp.any(p_mask):
10821097
res = xp.asarray(res, copy=True)
10831098
res = at(res, p_mask).set(xp.nan)
1084-
1099+
10851100
# Reshape per axis/keepdims
10861101
if axis_none and keepdims:
10871102
shape = (1,) * (ndim - 1) + res.shape
10881103
res = xp.reshape(res, shape)
10891104
axis = -1
1090-
1105+
10911106
# Move axis back to original position
10921107
res = xp.moveaxis(res, -1, axis)
1093-
1108+
10941109
# Handle keepdims
10951110
if not keepdims and res.shape[axis] == 1:
10961111
res = xp.squeeze(res, axis=axis)
1097-
1112+
10981113
# For scalar q, ensure we return a scalar result
1099-
if q_is_scalar:
1100-
if hasattr(res, 'shape') and res.shape != ():
1101-
res = res[()]
1102-
1114+
if q_is_scalar and hasattr(res, "shape") and res.shape != ():
1115+
res = res[()]
1116+
11031117
return res
11041118

11051119

11061120
def _quantile_hf(y: Array, p: Array, n: Array, method: str, xp: ModuleType) -> Array:
11071121
"""Helper function for Hyndman-Fan quantile methods."""
11081122
ms = {
1109-
'inverted_cdf': 0,
1110-
'averaged_inverted_cdf': 0,
1111-
'closest_observation': -0.5,
1112-
'interpolated_inverted_cdf': 0,
1113-
'hazen': 0.5,
1114-
'weibull': p,
1115-
'linear': 1 - p,
1116-
'median_unbiased': p/3 + 1/3,
1117-
'normal_unbiased': p/4 + 3/8
1123+
"inverted_cdf": 0,
1124+
"averaged_inverted_cdf": 0,
1125+
"closest_observation": -0.5,
1126+
"interpolated_inverted_cdf": 0,
1127+
"hazen": 0.5,
1128+
"weibull": p,
1129+
"linear": 1 - p,
1130+
"median_unbiased": p / 3 + 1 / 3,
1131+
"normal_unbiased": p / 4 + 3 / 8,
11181132
}
11191133
m = ms[method]
1120-
1134+
11211135
jg = p * n + m - 1
11221136
j = xp.astype(jg // 1, xp.int64) # Convert to integer
11231137
g = jg % 1
1124-
1125-
if method == 'inverted_cdf':
1138+
1139+
if method == "inverted_cdf":
11261140
g = xp.astype((g > 0), jg.dtype)
1127-
elif method == 'averaged_inverted_cdf':
1141+
elif method == "averaged_inverted_cdf":
11281142
g = (1 + xp.astype((g > 0), jg.dtype)) / 2
1129-
elif method == 'closest_observation':
1130-
g = (1 - xp.astype((g == 0) & (j % 2 == 1), jg.dtype))
1131-
if method in {'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation'}:
1143+
elif method == "closest_observation":
1144+
g = 1 - xp.astype((g == 0) & (j % 2 == 1), jg.dtype)
1145+
if method in {"inverted_cdf", "averaged_inverted_cdf", "closest_observation"}:
11321146
g = xp.asarray(g)
11331147
g = at(g, jg < 0).set(0)
11341148
g = at(g, j < 0).set(0)
11351149
j = xp.clip(j, 0, n - 1)
11361150
jp1 = xp.clip(j + 1, 0, n - 1)
1137-
1151+
11381152
# Broadcast indices to match y shape except for the last axis
11391153
if y.ndim > 1:
11401154
# Create broadcast shape for indices
1141-
broadcast_shape = list(y.shape[:-1]) + [1]
1155+
broadcast_shape = list(y.shape[:-1]) + [1] # noqa: RUF005
11421156
j = xp.broadcast_to(j, broadcast_shape)
11431157
jp1 = xp.broadcast_to(jp1, broadcast_shape)
11441158
g = xp.broadcast_to(g, broadcast_shape)
1145-
1146-
return ((1 - g) * xp.take_along_axis(y, j, axis=-1) +
1147-
g * xp.take_along_axis(y, jp1, axis=-1))
1159+
1160+
return (1 - g) * xp.take_along_axis(y, j, axis=-1) + g * xp.take_along_axis(
1161+
y, jp1, axis=-1
1162+
)
11481163

11491164

11501165
def _quantile_hd(y: Array, p: Array, n: Array, xp: ModuleType) -> Array:

tests/test_funcs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,7 +1193,7 @@ def test_2d_axis_keepdims(self, xp: ModuleType):
11931193

11941194
def test_methods(self, xp: ModuleType):
11951195
x = xp.asarray([1, 2, 3, 4, 5])
1196-
methods = ['linear', 'hazen', 'weibull']
1196+
methods = ["linear", "hazen", "weibull"]
11971197
for method in methods:
11981198
actual = quantile(x, 0.5, method=method)
11991199
# All methods should give reasonable results
@@ -1205,7 +1205,7 @@ def test_edge_cases(self, xp: ModuleType):
12051205
actual = quantile(x, 0.0)
12061206
expect = xp.asarray(1.0)
12071207
xp_assert_close(actual, expect)
1208-
1208+
12091209
# q = 1 should give maximum
12101210
actual = quantile(x, 1.0)
12111211
expect = xp.asarray(5.0)
@@ -1216,7 +1216,7 @@ def test_invalid_q(self, xp: ModuleType):
12161216
# q > 1 should return NaN
12171217
actual = quantile(x, 1.5)
12181218
assert xp.isnan(actual)
1219-
1219+
12201220
# q < 0 should return NaN
12211221
actual = quantile(x, -0.5)
12221222
assert xp.isnan(actual)

0 commit comments

Comments
 (0)