File tree Expand file tree Collapse file tree 1 file changed +35
-0
lines changed
Expand file tree Collapse file tree 1 file changed +35
-0
lines changed Original file line number Diff line number Diff line change @@ -392,6 +392,41 @@ def one_hot(
392392 axis : int = - 1 ,
393393 xp : ModuleType | None = None ,
394394) -> Array :
395+ """
396+ One-hot encode the given indices.
397+
398+ Each index in the input ``x`` is encoded as a vector of zeros of length
399+ ``num_classes`` with the element at the given index set to one.
400+
401+ Parameters
402+ ----------
403+ x : array
404+ An array with integral dtype having shape ``batch_dims``.
405+ num_classes : int
406+ Number of classes in the one-hot dimension.
407+ axis : int or tuple of ints, optional
408+ Position(s) in the expanded axes where the new axis is placed.
409+ xp : array_namespace, optional
410+ The standard-compatible namespace for `x`. Default: infer.
411+
412+ Returns
413+ -------
414+ array
415+ An array having the same shape as `x` except for a new axis at the position
416+ given by `axis` having size `num_classes`.
417+
418+ The dtype of the return value is the default float dtype (usually float64).
419+
420+ If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
421+ an exception, or may even cause a bad state. `x` is not checked.
422+
423+ Examples
424+ --------
425+ >>> xp.one_hot(jnp.array([1, 2, 0]), 3)
426+ Array([[0., 1., 0.],
427+ [0., 0., 1.],
428+ [1., 0., 0.]], dtype=float64)
429+ """
395430 if xp is None :
396431 xp = array_namespace (x )
397432 x_size = x .size
You can’t perform that action at this time.
0 commit comments