Skip to content

Commit 4bf80dc

Browse files
committed
Simplify more
1 parent 5c0877d commit 4bf80dc

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -383,18 +383,18 @@ def one_hot(
383383
xp: ModuleType,
384384
) -> Array: # numpydoc ignore=PR01,RT01
385385
"""See docstring in `array_api_extra._delegation.py`."""
386-
x_size = _compat.size(x)
387-
if x_size is None: # pragma: no cover
388-
# This cannot be tested because there is no way to create an array with abstract
389-
# size today. However, it is blocked for the sake of type-checking and
390-
# future-proofing since x.size is allowed to be None according to the
391-
# specification.
392-
msg = "x must have a concrete size."
393-
raise TypeError(msg)
394386
# TODO: Benchmark whether this is faster on the NumPy backend:
387+
# x_size = _compat.size(x)
388+
# if x_size is None: # pragma: no cover
389+
# # This cannot be tested because there is no way to create an array with abstract
390+
# # size today. However, it is blocked for the sake of type-checking and
391+
# # future-proofing since x.size is allowed to be None according to the
392+
# # specification.
393+
# msg = "x must have a concrete size."
394+
# raise TypeError(msg)
395395
# x_flattened = xp.reshape(x, (-1,))
396396
# out = xp.zeros((x.size, num_classes), dtype=dtype, device=_compat.device(x))
397-
# out = at(out)[xp.arange(x_size), x_flattened].set(1)
397+
# out = at(out)[xp.arange(x.size), x_flattened].set(1)
398398
# out = xp.reshape(out, (*x.shape, num_classes))
399399
range_num_classes = xp.arange(num_classes, dtype=x.dtype, device=_compat.device(x))
400400
return x[..., xp.newaxis] == range_num_classes

0 commit comments

Comments
 (0)