@@ -383,18 +383,18 @@ def one_hot(
383383 xp : ModuleType ,
384384) -> Array : # numpydoc ignore=PR01,RT01
385385 """See docstring in `array_api_extra._delegation.py`."""
386- x_size = _compat .size (x )
387- if x_size is None : # pragma: no cover
388- # This cannot be tested because there is no way to create an array with abstract
389- # size today. However, it is blocked for the sake of type-checking and
390- # future-proofing since x.size is allowed to be None according to the
391- # specification.
392- msg = "x must have a concrete size."
393- raise TypeError (msg )
394386 # TODO: Benchmark whether this is faster on the NumPy backend:
387+ # x_size = _compat.size(x)
388+ # if x_size is None: # pragma: no cover
389+ # # This cannot be tested because there is no way to create an array with abstract
390+ # # size today. However, it is blocked for the sake of type-checking and
391+ # # future-proofing since x.size is allowed to be None according to the
392+ # # specification.
393+ # msg = "x must have a concrete size."
394+ # raise TypeError(msg)
395395 # x_flattened = xp.reshape(x, (-1,))
396396 # out = xp.zeros((x.size, num_classes), dtype=dtype, device=_compat.device(x))
397- # out = at(out)[xp.arange(x_size ), x_flattened].set(1)
397+ # out = at(out)[xp.arange(x.size ), x_flattened].set(1)
398398 # out = xp.reshape(out, (*x.shape, num_classes))
399399 range_num_classes = xp .arange (num_classes , dtype = x .dtype , device = _compat .device (x ))
400400 return x [..., xp .newaxis ] == range_num_classes
0 commit comments