@@ -368,23 +368,24 @@ def _isscalar(a):
     if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
         max = None
 
+    dev = device(x)
     if out is None:
-        out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape),
-                                 copy=True, device=device(x))
+        out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev)
+        out[()] = x
+
     if min is not None:
-        if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min):
-            # Avoid loss of precision due to torch defaulting to float32
-            min = wrapped_xp.asarray(min, dtype=xp.float64)
-        a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape)
+        a = wrapped_xp.asarray(min, dtype=x.dtype, device=dev)
+        a = xp.broadcast_to(a, result_shape)
         ia = (out < a) | xp.isnan(a)
         # torch requires an explicit cast here
-        out[ia] = wrapped_xp.astype(a[ia], out.dtype)
+        out[ia] = a[ia]
+
     if max is not None:
-        if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max):
-            max = wrapped_xp.asarray(max, dtype=xp.float64)
-        b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape)
+        b = wrapped_xp.asarray(max, dtype=x.dtype, device=dev)
+        b = xp.broadcast_to(b, result_shape)
         ib = (out > b) | xp.isnan(b)
-        out[ib] = wrapped_xp.astype(b[ib], out.dtype)
+        out[ib] = b[ib]
+
     # Return a scalar for 0-D
     return out[()]
 
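As a rough illustration (not part of the commit), the snippet below mimics the updated code path with plain NumPy: the scalar bounds are cast to the input's dtype before comparison, so the output keeps x's dtype without the torch-specific float64 workaround or the explicit astype on assignment. All names in the sketch (x, lo, hi) are made-up examples.

# Rough illustration only: mimics the updated clip path with plain NumPy.
# x, lo, hi are illustrative names, not part of the diff above.
import numpy as np

x = np.array([1, 5, 9], dtype=np.int32)
lo, hi = 2.7, 7.2               # scalar bounds with a different dtype than x

result_shape = x.shape
out = np.empty(result_shape, dtype=x.dtype)
out[()] = x                     # copy x into out while keeping x's dtype

a = np.broadcast_to(np.asarray(lo, dtype=x.dtype), result_shape)
ia = (out < a) | np.isnan(a)
out[ia] = a[ia]                 # bound already has out's dtype, no cast needed

b = np.broadcast_to(np.asarray(hi, dtype=x.dtype), result_shape)
ib = (out > b) | np.isnan(b)
out[ib] = b[ib]

print(out, out.dtype)           # [2 5 7] int32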