@@ -300,14 +300,35 @@ def clip(
300
300
if min is not None and max is not None and np .any (min > max ):
301
301
raise ValueError ("min must be less than or equal to max" )
302
302
303
- result = np .clip (x ._array , min , max )
304
303
# Note: NumPy applies type promotion, but the standard specifies the
305
- # return dtype should be the same as x
306
- if result .dtype != x .dtype ._np_dtype :
307
- # TODO: I'm not completely sure this always gives the correct thing
308
- # for integer dtypes. See https://github.com/numpy/numpy/issues/24976
309
- result = result .astype (x .dtype ._np_dtype )
310
- return Array ._new (result )
304
+ # return dtype should be the same as x We do this instead of just
305
+ # downcasting the result of np.clip() to handle some corner cases better
306
+ # (e.g., avoiding uint64 -> float64 promotion).
307
+
308
+ # Note: cases where min or max overflow (integer) or round (float) in the
309
+ # wrong direction when downcasting to x.dtype are unspecified. This code
310
+ # just does whatever NumPy does when it downcasts in the assignment, but
311
+ # other behavior could be preferred, especially for integers. For example,
312
+ # this code produces:
313
+
314
+ # >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None)
315
+ # -128
316
+
317
+ # but an answer of 0 might be preferred. See
318
+ # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.
319
+ min_shape = () if min is None else min .shape
320
+ max_shape = () if max is None else max .shape
321
+ result_shape = np .broadcast_shapes (x .shape , min_shape , max_shape )
322
+ out = asarray (np .broadcast_to (x ._array , result_shape ), copy = True )._array
323
+ if min is not None :
324
+ a = np .broadcast_to (min , result_shape )
325
+ ia = (out < a ) | np .isnan (a )
326
+ out [ia ] = a [ia ]
327
+ if max is not None :
328
+ b = np .broadcast_to (max , result_shape )
329
+ ib = (out > b ) | np .isnan (b )
330
+ out [ib ] = b [ib ]
331
+ return Array ._new (out )
311
332
312
333
def conj (x : Array , / ) -> Array :
313
334
"""
0 commit comments