Skip to content

Commit 5c0877d

Browse files
committed
Simplify
1 parent deb64a0 commit 5c0877d

File tree

2 files changed

+4
-12
lines changed

2 files changed

+4
-12
lines changed

src/array_api_extra/_delegation.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -188,15 +188,9 @@ def one_hot(
188188
out = torch_one_hot(x, num_classes)
189189
except RuntimeError as e:
190190
raise IndexError from e
191-
out = xp.astype(out, dtype)
192191
else:
193-
out = _funcs.one_hot(
194-
x,
195-
num_classes,
196-
dtype=dtype,
197-
xp=xp,
198-
)
199-
192+
out = _funcs.one_hot(x, num_classes, xp=xp)
193+
out = xp.astype(out, dtype)
200194
if axis != -1:
201195
out = xp.moveaxis(out, -1, axis)
202196
return out

src/array_api_extra/_lib/_funcs.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
meta_namespace,
1717
ndindex,
1818
)
19-
from ._utils._typing import Array, DType
19+
from ._utils._typing import Array
2020

2121
__all__ = [
2222
"apply_where",
@@ -380,7 +380,6 @@ def one_hot(
380380
/,
381381
num_classes: int,
382382
*,
383-
dtype: DType,
384383
xp: ModuleType,
385384
) -> Array: # numpydoc ignore=PR01,RT01
386385
"""See docstring in `array_api_extra._delegation.py`."""
@@ -398,8 +397,7 @@ def one_hot(
398397
# out = at(out)[xp.arange(x_size), x_flattened].set(1)
399398
# out = xp.reshape(out, (*x.shape, num_classes))
400399
range_num_classes = xp.arange(num_classes, dtype=x.dtype, device=_compat.device(x))
401-
out = x[..., xp.newaxis] == range_num_classes
402-
return xp.astype(out, dtype)
400+
return x[..., xp.newaxis] == range_num_classes
403401

404402

405403
def create_diagonal(

0 commit comments

Comments
 (0)