File tree Expand file tree Collapse file tree 2 files changed +4
-1
lines changed
Expand file tree Collapse file tree 2 files changed +4
-1
lines changed Original file line number Diff line number Diff line change @@ -135,6 +135,9 @@ def one_hot(
135135 An array with integral dtype having shape ``batch_dims``.
136136 num_classes : int
137137 Number of classes in the one-hot dimension.
138+ dtype : DType, optional
139+ The dtype of the return value. Defaults to the default float dtype (usually
140+ float64).
138141 axis : int or tuple of ints, optional
139142 Position(s) in the expanded axes where the new axis is placed.
140143 xp : array_namespace, optional
@@ -146,7 +149,6 @@ def one_hot(
146149 An array having the same shape as `x` except for a new axis at the position
147150 given by `axis` having size `num_classes`.
148151
149- The dtype of the return value is the default float dtype (usually float64).
150152
151153 If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
152154 an exception, or may even cause a bad state. `x` is not checked.
Original file line number Diff line number Diff line change @@ -389,6 +389,7 @@ def one_hot(
389389 dtype : DType ,
390390 xp : ModuleType ,
391391) -> Array :
392+ """Helper for _delegation.one_hot."""
392393 out = xp .zeros ((x .size , num_classes ), dtype = dtype )
393394 x_flattened = xp .reshape (x , (- 1 ,))
394395 if is_numpy_namespace (xp ):
You can’t perform that action at this time.
0 commit comments