File tree Expand file tree Collapse file tree 2 files changed +4
-12
lines changed
Expand file tree Collapse file tree 2 files changed +4
-12
lines changed Original file line number Diff line number Diff line change @@ -188,15 +188,9 @@ def one_hot(
188188 out = torch_one_hot (x , num_classes )
189189 except RuntimeError as e :
190190 raise IndexError from e
191- out = xp .astype (out , dtype )
192191 else :
193- out = _funcs .one_hot (
194- x ,
195- num_classes ,
196- dtype = dtype ,
197- xp = xp ,
198- )
199-
192+ out = _funcs .one_hot (x , num_classes , xp = xp )
193+ out = xp .astype (out , dtype )
200194 if axis != - 1 :
201195 out = xp .moveaxis (out , - 1 , axis )
202196 return out
Original file line number Diff line number Diff line change 1616 meta_namespace ,
1717 ndindex ,
1818)
19- from ._utils ._typing import Array , DType
19+ from ._utils ._typing import Array
2020
2121__all__ = [
2222 "apply_where" ,
@@ -380,7 +380,6 @@ def one_hot(
380380 / ,
381381 num_classes : int ,
382382 * ,
383- dtype : DType ,
384383 xp : ModuleType ,
385384) -> Array : # numpydoc ignore=PR01,RT01
386385 """See docstring in `array_api_extra._delegation.py`."""
@@ -398,8 +397,7 @@ def one_hot(
398397 # out = at(out)[xp.arange(x_size), x_flattened].set(1)
399398 # out = xp.reshape(out, (*x.shape, num_classes))
400399 range_num_classes = xp .arange (num_classes , dtype = x .dtype , device = _compat .device (x ))
401- out = x [..., xp .newaxis ] == range_num_classes
402- return xp .astype (out , dtype )
400+ return x [..., xp .newaxis ] == range_num_classes
403401
404402
405403def create_diagonal (
You can’t perform that action at this time.
0 commit comments