File tree Expand file tree Collapse file tree 2 files changed +10
-11
lines changed
Expand file tree Collapse file tree 2 files changed +10
-11
lines changed Original file line number Diff line number Diff line change @@ -162,10 +162,6 @@ def one_hot(
162162 # Validate inputs.
163163 if xp is None :
164164 xp = array_namespace (x )
165- x_size = x .size
166- if x_size is None :
167- msg = "x must have a concrete size."
168- raise TypeError (msg )
169165 if not xp .isdtype (x .dtype , "integral" ):
170166 msg = "x must have an integral dtype."
171167 raise TypeError (msg )
@@ -191,7 +187,6 @@ def one_hot(
191187 out = _funcs .one_hot (
192188 x ,
193189 num_classes ,
194- x_size = x_size ,
195190 dtype = dtype ,
196191 xp = xp ,
197192 supports_fancy_indexing = is_numpy_namespace (xp ),
Original file line number Diff line number Diff line change @@ -386,20 +386,24 @@ def one_hot(
386386 * ,
387387 supports_fancy_indexing : bool = False ,
388388 supports_array_indexing : bool = False ,
389- x_size : int ,
390389 dtype : DType ,
391390 xp : ModuleType ,
392391) -> Array : # numpydoc ignore=PR01,RT01
393392 """See docstring in `array_api_extra._delegation.py`."""
393+ x_size = x .size
394+ if x_size is None : # pragma: no cover
395+ msg = "x must have a concrete size."
396+ raise TypeError (msg )
394397 out = xp .zeros ((x .size , num_classes ), dtype = dtype )
395398 x_flattened = xp .reshape (x , (- 1 ,))
396399 if supports_fancy_indexing :
397400 out = at (out )[xp .arange (x_size ), x_flattened ].set (1 )
398- for i in range (x_size ):
399- x_i = x_flattened [i ]
400- if not supports_array_indexing :
401- x_i = int (x_i )
402- out = at (out )[i , x_i ].set (1 )
401+ else :
402+ for i in range (x_size ):
403+ x_i = x_flattened [i ]
404+ if not supports_array_indexing :
405+ x_i = int (x_i )
406+ out = at (out )[i , x_i ].set (1 )
403407 if x .ndim != 1 :
404408 out = xp .reshape (out , (* x .shape , num_classes ))
405409 return out
You can’t perform that action at this time.
0 commit comments