Skip to content

Commit 3d313e2

Browse files
committed
Fix clip when the input is uint64
NumPy type promotes this to float64, which is not what we want.
1 parent 6b0079b commit 3d313e2

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

array_api_strict/_elementwise_functions.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -300,14 +300,35 @@ def clip(
300300
if min is not None and max is not None and np.any(min > max):
301301
raise ValueError("min must be less than or equal to max")
302302

303-
result = np.clip(x._array, min, max)
304303
# Note: NumPy applies type promotion, but the standard specifies the
305-
# return dtype should be the same as x
306-
if result.dtype != x.dtype._np_dtype:
307-
# TODO: I'm not completely sure this always gives the correct thing
308-
# for integer dtypes. See https://github.com/numpy/numpy/issues/24976
309-
result = result.astype(x.dtype._np_dtype)
310-
return Array._new(result)
304+
# return dtype should be the same as x We do this instead of just
305+
# downcasting the result of np.clip() to handle some corner cases better
306+
# (e.g., avoiding uint64 -> float64 promotion).
307+
308+
# Note: cases where min or max overflow (integer) or round (float) in the
309+
# wrong direction when downcasting to x.dtype are unspecified. This code
310+
# just does whatever NumPy does when it downcasts in the assignment, but
311+
# other behavior could be preferred, especially for integers. For example,
312+
# this code produces:
313+
314+
# >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None)
315+
# -128
316+
317+
# but an answer of 0 might be preferred. See
318+
# https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.
319+
min_shape = () if min is None else min.shape
320+
max_shape = () if max is None else max.shape
321+
result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape)
322+
out = asarray(np.broadcast_to(x._array, result_shape), copy=True)._array
323+
if min is not None:
324+
a = np.broadcast_to(min, result_shape)
325+
ia = (out < a) | np.isnan(a)
326+
out[ia] = a[ia]
327+
if max is not None:
328+
b = np.broadcast_to(max, result_shape)
329+
ib = (out > b) | np.isnan(b)
330+
out[ib] = b[ib]
331+
return Array._new(out)
311332

312333
def conj(x: Array, /) -> Array:
313334
"""

0 commit comments

Comments
 (0)