Skip to content

Commit 007a61f

Browse files
committed
Tweak
1 parent 13a5507 commit 007a61f

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

src/array_api_extra/_lib/_quantile.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,9 @@ def quantile(
9191
# Move axis back to original position
9292
res = xp.moveaxis(res, -1, axis)
9393

94-
# Handle keepdims
9594
if not keepdims and res.shape[axis] == 1:
9695
res = xp.squeeze(res, axis=axis)
9796

98-
# For scalar q, ensure we return a scalar result
99-
# if q_is_scalar and hasattr(res, "shape") and res.shape != ():
100-
# res = res[()]
10197
if res.ndim == 0:
10298
return res[()]
10399
return res
@@ -148,6 +144,7 @@ def _quantile_hf(
148144
jp1 = xp.broadcast_to(jp1, broadcast_shape)
149145
g = xp.broadcast_to(g, broadcast_shape)
150146

151-
return (1 - g) * xp.take_along_axis(y, j, axis=-1) + g * xp.take_along_axis(
147+
res = (1 - g) * xp.take_along_axis(y, j, axis=-1) + g * xp.take_along_axis(
152148
y, jp1, axis=-1
153149
)
150+
return res

0 commit comments

Comments
 (0)