Skip to content

Commit 3ef6727

Browse files
committed
Raise exception for invalid q values
1 parent 440106f commit 3ef6727

File tree

3 files changed

+23
-17
lines changed

3 files changed

+23
-17
lines changed

src/array_api_extra/_delegation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def quantile(
331331

332332
# The quantile function in scipy 1.16 supports array API directly, no need
333333
# to delegate
334-
if version.parse(scipy.__version__) >= version.parse("1.16"): # pyright: ignore[reportUnknownArgumentType]
334+
if version.parse(scipy.__version__) >= version.parse("1.17"): # pyright: ignore[reportUnknownArgumentType]
335335
from scipy.stats import ( # type: ignore[import-untyped]
336336
quantile as scipy_quantile,
337337
)

src/array_api_extra/_lib/_quantile.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,12 @@ def quantile(
7575

7676
n = xp.asarray(y.shape[-1], dtype=dtype, device=_compat.device(y))
7777

78-
res = _quantile_hf(y, q_arr, n, method, xp)
78+
# Validate that q values are in the range [0, 1]
79+
if xp.any((q_arr < 0) | (q_arr > 1)):
80+
msg = "`q` must contain values between 0 and 1 inclusive."
81+
raise ValueError(msg)
7982

80-
# Handle NaN output for invalid q values
81-
p_mask = (q_arr > 1) | (q_arr < 0) | xp.isnan(q_arr)
82-
if xp.any(p_mask):
83-
res = xp.asarray(res, copy=True)
84-
res = at(res, p_mask).set(xp.nan)
83+
res = _quantile_hf(y, q_arr, n, method, xp)
8584

8685
# Reshape per axis/keepdims
8786
if axis_none and keepdims:
@@ -97,9 +96,10 @@ def quantile(
9796
res = xp.squeeze(res, axis=axis)
9897

9998
# For scalar q, ensure we return a scalar result
100-
if q_is_scalar and hasattr(res, "shape") and res.shape != ():
101-
res = res[()]
102-
99+
# if q_is_scalar and hasattr(res, "shape") and res.shape != ():
100+
# res = res[()]
101+
if res.ndim == 0:
102+
return res[()]
103103
return res
104104

105105

@@ -121,7 +121,10 @@ def _quantile_hf(
121121
m = ms[method]
122122

123123
jg = p * n + m - 1
124-
j = xp.astype(jg // 1, xp.int64) # Convert to integer
124+
# Convert both to integers, the type of j and n must be the same
125+
# for us to be able to `xp.clip` them.
126+
j = xp.astype(jg // 1, xp.int64)
127+
n = xp.astype(n, xp.int64)
125128
g = jg % 1
126129

127130
if method == "inverted_cdf":

tests/test_funcs.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,13 +1213,16 @@ def test_edge_cases(self, xp: ModuleType):
12131213

12141214
def test_invalid_q(self, xp: ModuleType):
12151215
x = xp.asarray([1, 2, 3, 4, 5])
1216-
# q > 1 should return NaN
1217-
actual = quantile(x, 1.5)
1218-
assert xp.isnan(actual)
1216+
# q > 1 should raise
1217+
with pytest.raises(
1218+
ValueError, match="`q` must contain values between 0 and 1 inclusive"
1219+
):
1220+
quantile(x, 1.5)
12191221

1220-
# q < 0 should return NaN
1221-
actual = quantile(x, -0.5)
1222-
assert xp.isnan(actual)
1222+
with pytest.raises(
1223+
ValueError, match="`q` must contain values between 0 and 1 inclusive"
1224+
):
1225+
quantile(x, -0.5)
12231226

12241227
def test_device(self, xp: ModuleType, device: Device):
12251228
x = xp.asarray([1, 2, 3, 4, 5], device=device)

0 commit comments

Comments
 (0)