|
29 | 29 | from jax import lax |
30 | 30 | from jax._src import config |
31 | 31 | from jax._src import core |
| 32 | +from jax._src import deprecations |
32 | 33 | from jax._src import dtypes |
33 | 34 | from jax._src import util |
34 | 35 | from jax._src.core import AxisName |
@@ -645,34 +646,33 @@ def standardize(x: ArrayLike, |
645 | 646 |
|
646 | 647 | # TODO(slebedev): Change the type of `x` to `ArrayLike`. |
647 | 648 | @partial(jax.jit, static_argnames=("num_classes", "dtype", "axis")) |
648 | | -def _one_hot(x: Any, num_classes: int, *, |
| 649 | +def _one_hot(x: Array, num_classes: int, *, |
649 | 650 | dtype: Any, axis: int | AxisName) -> Array: |
650 | 651 | num_classes = core.concrete_dim_or_error( |
651 | 652 | num_classes, |
652 | 653 | "The error arose in jax.nn.one_hot argument `num_classes`.") |
653 | 654 | dtype = dtypes.canonicalize_dtype(dtype) |
654 | | - x_arr = jnp.asarray(x) |
655 | 655 | try: |
656 | | - output_pos_axis = util.canonicalize_axis(axis, x_arr.ndim + 1) |
| 656 | + output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1) |
657 | 657 | except TypeError: |
658 | 658 | axis_size = lax.psum(1, axis) |
659 | 659 | if num_classes != axis_size: |
660 | 660 | raise ValueError(f"Expected num_classes to match the size of axis {axis}, " |
661 | 661 | f"but {num_classes} != {axis_size}") from None |
662 | 662 | axis_idx = lax.axis_index(axis) |
663 | | - return jnp.asarray(x_arr == axis_idx, dtype=dtype) |
| 663 | + return jnp.asarray(x == axis_idx, dtype=dtype) |
664 | 664 | axis = operator.index(axis) # type: ignore[arg-type] |
665 | | - lhs = lax.expand_dims(x_arr, (axis,)) |
666 | | - rhs_shape = [1] * x_arr.ndim |
| 665 | + lhs = lax.expand_dims(x, (axis,)) |
| 666 | + rhs_shape = [1] * x.ndim |
667 | 667 | rhs_shape.insert(output_pos_axis, num_classes) |
668 | 668 | if config.sharding_in_types.value: |
669 | 669 | # TODO(yashkatariya): Maybe expose `out_sharding` on `one_hot` too? |
670 | | - rhs_sharding = NamedSharding(x_arr.sharding.mesh, P(*[None] * len(rhs_shape))) |
| 670 | + rhs_sharding = NamedSharding(x.sharding.mesh, P(*[None] * len(rhs_shape))) # pytype: disable=attribute-error |
671 | 671 | else: |
672 | 672 | rhs_sharding = None |
673 | | - rhs = lax.broadcasted_iota(x_arr.dtype, rhs_shape, output_pos_axis, |
| 673 | + rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis, |
674 | 674 | _sharding=rhs_sharding) |
675 | | - return jnp.asarray(lhs == rhs, dtype=dtype) |
| 675 | + return (lhs == rhs).astype(dtype) |
676 | 676 |
|
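The `_one_hot` body above builds the encoding by comparing the input, lifted with `lax.expand_dims`, against a `lax.broadcasted_iota` laid out along the new class axis. A minimal standalone sketch of that trick, assuming an integer input, a trailing class axis, and no sharding handling (`one_hot_sketch` is a hypothetical name, not part of this change):

```python
import jax.numpy as jnp
from jax import lax

def one_hot_sketch(x, num_classes, dtype=jnp.float32):
  # lhs: input with a length-1 class axis appended -> shape (..., 1).
  lhs = lax.expand_dims(x, (x.ndim,))
  # rhs: 0..num_classes-1 along that axis -> shape (1, ..., 1, num_classes).
  rhs = lax.broadcasted_iota(x.dtype, (1,) * x.ndim + (num_classes,), x.ndim)
  # Broadcasting the equality yields shape (..., num_classes).
  return (lhs == rhs).astype(dtype)

print(one_hot_sketch(jnp.array([0, 2, 1]), 3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
```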
677 | 677 | # TODO(slebedev): Change the type of `x` to `ArrayLike`. |
678 | 678 | def one_hot(x: Any, num_classes: int, *, |
@@ -703,7 +703,15 @@ def one_hot(x: Any, num_classes: int, *, |
703 | 703 | num_classes = core.concrete_dim_or_error( |
704 | 704 | num_classes, |
705 | 705 | "The error arose in jax.nn.one_hot argument `num_classes`.") |
706 | | - return _one_hot(x, num_classes, dtype=dtype, axis=axis) |
| 706 | + x_arr = jnp.asarray(x) |
| 707 | + if not jnp.isdtype(x_arr.dtype, "integral"): |
| 708 | + # Deprecated 2024-12-18 |
| 709 | + deprecations.warn( |
| 710 | + 'jax-nn-one-hot-float-input', |
| 711 | + f"jax.nn.one_hot input should be integer-typed; got dtype={x_arr.dtype}", |
| 712 | + stacklevel=1) |
| 713 | + x_arr = x_arr.astype('int32') |
| 714 | + return _one_hot(x_arr, num_classes, dtype=dtype, axis=axis) |
707 | 715 |
|
708 | 716 |
|
709 | 717 | @jax.custom_jvp |
|
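With this change, float inputs to `jax.nn.one_hot` still work but emit a deprecation warning before being cast to `int32`. A usage sketch, assuming the `jax-nn-one-hot-float-input` deprecation id registered above defaults to warning rather than erroring:

```python
import jax.numpy as jnp
from jax import nn

nn.one_hot(jnp.array([0, 1, 2]), 3)   # integer input: unchanged behavior

# Float input still works for now, but warns and is cast to int32 internally:
# DeprecationWarning: jax.nn.one_hot input should be integer-typed; got dtype=float32
nn.one_hot(jnp.array([0.0, 1.0, 2.0]), 3)
```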