
Commit 8402a98

Merge pull request #25590 from jakevdp:fix-one-hot-float
PiperOrigin-RevId: 707922981
2 parents 23000a3 + 8c3c441

File tree: 5 files changed (+34, −15 lines)

  jax/_src/deprecations.py
  jax/_src/nn/functions.py
  jax/experimental/jax2tf/tests/shape_poly_test.py
  tests/nn_test.py
  tests/shape_poly_test.py

jax/_src/deprecations.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -125,6 +125,7 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None:
 # always registered by the time `accelerate` and `is_accelerated` are called.
 register('jax-aval-named-shape')
 register('jax-dlpack-import-legacy')
+register('jax-nn-one-hot-float-input')
 register("jax-numpy-astype-complex-to-real")
 register("jax-numpy-array-none")
 register('jax-numpy-clip-args')
```
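For orientation, here is a minimal sketch of the registry pattern this file implements, assuming a simplified stand-in for `jax._src.deprecations` (only the `register`/`warn`/`is_accelerated` names and the warn-then-error lifecycle are taken from this change; the real module's internals differ):

```python
import warnings

# deprecation_id -> whether the deprecation is "accelerated" (raises an
# error instead of warning). Hypothetical simplified state.
_registry: dict[str, bool] = {}

def register(deprecation_id: str) -> None:
  _registry[deprecation_id] = False

def is_accelerated(deprecation_id: str) -> bool:
  return _registry[deprecation_id]

def warn(deprecation_id: str, message: str, stacklevel: int) -> None:
  # Accelerated deprecations fail loudly; others emit a DeprecationWarning.
  if is_accelerated(deprecation_id):
    raise ValueError(message)
  warnings.warn(message, DeprecationWarning, stacklevel=stacklevel + 1)

register('jax-nn-one-hot-float-input')
```

Registering `'jax-nn-one-hot-float-input'` here is what lets the `deprecations.warn(...)` call added in `jax/_src/nn/functions.py` below resolve the id.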

jax/_src/nn/functions.py

Lines changed: 18 additions & 10 deletions
```diff
@@ -29,6 +29,7 @@
 from jax import lax
 from jax._src import config
 from jax._src import core
+from jax._src import deprecations
 from jax._src import dtypes
 from jax._src import util
 from jax._src.core import AxisName
```
```diff
@@ -645,34 +646,33 @@ def standardize(x: ArrayLike,

 # TODO(slebedev): Change the type of `x` to `ArrayLike`.
 @partial(jax.jit, static_argnames=("num_classes", "dtype", "axis"))
-def _one_hot(x: Any, num_classes: int, *,
+def _one_hot(x: Array, num_classes: int, *,
              dtype: Any, axis: int | AxisName) -> Array:
   num_classes = core.concrete_dim_or_error(
       num_classes,
       "The error arose in jax.nn.one_hot argument `num_classes`.")
   dtype = dtypes.canonicalize_dtype(dtype)
-  x_arr = jnp.asarray(x)
   try:
-    output_pos_axis = util.canonicalize_axis(axis, x_arr.ndim + 1)
+    output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
   except TypeError:
     axis_size = lax.psum(1, axis)
     if num_classes != axis_size:
       raise ValueError(f"Expected num_classes to match the size of axis {axis}, "
                        f"but {num_classes} != {axis_size}") from None
     axis_idx = lax.axis_index(axis)
-    return jnp.asarray(x_arr == axis_idx, dtype=dtype)
+    return jnp.asarray(x == axis_idx, dtype=dtype)
   axis = operator.index(axis)  # type: ignore[arg-type]
-  lhs = lax.expand_dims(x_arr, (axis,))
-  rhs_shape = [1] * x_arr.ndim
+  lhs = lax.expand_dims(x, (axis,))
+  rhs_shape = [1] * x.ndim
   rhs_shape.insert(output_pos_axis, num_classes)
   if config.sharding_in_types.value:
     # TODO(yashkatariya): Maybe expose `out_sharding` on `one_hot` too?
-    rhs_sharding = NamedSharding(x_arr.sharding.mesh, P(*[None] * len(rhs_shape)))
+    rhs_sharding = NamedSharding(x.sharding.mesh, P(*[None] * len(rhs_shape)))  # pytype: disable=attribute-error
   else:
     rhs_sharding = None
-  rhs = lax.broadcasted_iota(x_arr.dtype, rhs_shape, output_pos_axis,
+  rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis,
                              _sharding=rhs_sharding)
-  return jnp.asarray(lhs == rhs, dtype=dtype)
+  return (lhs == rhs).astype(dtype)

 # TODO(slebedev): Change the type of `x` to `ArrayLike`.
 def one_hot(x: Any, num_classes: int, *,
```
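Aside: the unchanged core of `_one_hot` builds the output by comparing the expanded indices against an iota along the class axis. A standalone sketch of that trick (shapes chosen for illustration):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 0])        # indices, int32[3]
num_classes = 3
lhs = lax.expand_dims(x, (1,))  # int32[3, 1]
# iota along the class axis: [[0, 1, 2]], broadcast against lhs
rhs = lax.broadcasted_iota(x.dtype, (1, num_classes), 1)
print((lhs == rhs).astype(jnp.float32))
# [[0. 1. 0.]
#  [0. 0. 1.]
#  [1. 0. 0.]]
```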
```diff
@@ -703,7 +703,15 @@ def one_hot(x: Any, num_classes: int, *,
   num_classes = core.concrete_dim_or_error(
       num_classes,
       "The error arose in jax.nn.one_hot argument `num_classes`.")
-  return _one_hot(x, num_classes, dtype=dtype, axis=axis)
+  x_arr = jnp.asarray(x)
+  if not jnp.isdtype(x_arr.dtype, "integral"):
+    # Deprecated 2024-12-18
+    deprecations.warn(
+        'jax-nn-one-hot-float-input',
+        f"jax.nn.one_hot input should be integer-typed; got dtype={x_arr.dtype}",
+        stacklevel=1)
+    x_arr = x_arr.astype('int32')
+  return _one_hot(x_arr, num_classes, dtype=dtype, axis=axis)


 @jax.custom_jvp
```
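The user-visible effect of this file's change, as a usage sketch (the warning text comes from the diff; the printed values follow from one-hot semantics):

```python
import warnings
import jax.numpy as jnp
from jax import nn

# Integer indices: the supported path, unchanged.
print(nn.one_hot(jnp.array([0, 2]), 3))
# [[1. 0. 0.]
#  [0. 0. 1.]]

# Float indices: now emit a DeprecationWarning and are cast to int32
# before dispatch to _one_hot (or raise a ValueError once the
# deprecation is accelerated).
with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  out = nn.one_hot(jnp.array([1.0]), 3)
print(caught[0].category)  # <class 'DeprecationWarning'>
print(out)                 # [[0. 1. 0.]]
```

Because the dtype check and the `jnp.asarray` conversion now live in the un-jitted `one_hot` wrapper, the jitted `_one_hot` helper can assume an already-converted integer `Array`, which is what the tightened `x: Array` annotation reflects.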

jax/experimental/jax2tf/tests/shape_poly_test.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1959,11 +1959,11 @@ def f2(z, w):  # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10]
                 expect_error=expect_error_associative_scan),
     PolyHarness("one_hot", "poly_num_classes",
                 lambda x, y: jax.nn.one_hot(x, y.shape[0]),
-                arg_descriptors=[np.arange(16, dtype=_f32), RandArg((16,), _f32)],
+                arg_descriptors=[np.arange(16, dtype=_i32), RandArg((16,), _f32)],
                 polymorphic_shapes=[None, "b0, ..."]),
     PolyHarness("one_hot", "all_poly",
                 lambda x, y: jax.nn.one_hot(x, y.shape[0]),
-                arg_descriptors=[np.arange(16, dtype=_f32), RandArg((16,), _f32)],
+                arg_descriptors=[np.arange(16, dtype=_i32), RandArg((16,), _f32)],
                 polymorphic_shapes=["b, ...", "b, ..."]),
     PolyHarness("ones", "",
                 lambda x: jnp.ones(x.shape, dtype=_f32) + x,
```

tests/nn_test.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -24,10 +24,11 @@

 import scipy.stats

+from jax._src import ad_checkpoint
 from jax._src import config
 from jax._src import core
+from jax._src import deprecations
 from jax._src import test_util as jtu
-from jax._src import ad_checkpoint
 from jax._src.interpreters import mlir
 from jax._src.lib import cuda_versions
 from jax.test_util import check_grads
```
```diff
@@ -530,6 +531,15 @@ def testOneHotAxis(self):
     actual = nn.one_hot(jnp.array([1, 2, 0]), 3, axis=-2)
     self.assertAllClose(actual, expected, check_dtypes=False)

+  def testOneHotNonInteger(self):
+    def assert_warns_or_errors(msg):
+      if deprecations.is_accelerated("jax-nn-one-hot-float-input"):
+        return self.assertRaisesRegex(ValueError, msg)
+      else:
+        return self.assertWarnsRegex(DeprecationWarning, msg)
+    with assert_warns_or_errors("jax.nn.one_hot input should be integer-typed"):
+      nn.one_hot(jnp.array([1.0]), 3)
+
   def testTanhExists(self):
     nn.tanh  # doesn't crash

```
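The two branches of `assert_warns_or_errors` mirror the deprecation lifecycle. A sketch of exercising both phases outside the test suite, assuming the `accelerate` function mentioned in the `deprecations.py` comment flips a registered id into error mode (treat the exact `deprecations.accelerate` call as an assumption):

```python
import warnings
import jax.numpy as jnp
from jax import nn
from jax._src import deprecations

# Warning phase (default): float input warns and is cast to int32.
with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  nn.one_hot(jnp.array([1.0]), 3)
assert caught and issubclass(caught[0].category, DeprecationWarning)

# Accelerated phase: the same call raises instead of warning.
deprecations.accelerate('jax-nn-one-hot-float-input')  # assumed API
try:
  nn.one_hot(jnp.array([1.0]), 3)
except ValueError as e:
  print(e)  # jax.nn.one_hot input should be integer-typed; got dtype=float32
```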
tests/shape_poly_test.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -2760,11 +2760,11 @@ def f(x_ref):
                 expect_error=expect_error_associative_scan),
     PolyHarness("one_hot", "poly_num_classes",
                 lambda x, y: jax.nn.one_hot(x, y.shape[0]),
-                arg_descriptors=[np.arange(16, dtype=_f32), RandArg((16,), _f32)],
+                arg_descriptors=[np.arange(16, dtype=_i32), RandArg((16,), _f32)],
                 polymorphic_shapes=[None, "b0, ..."]),
     PolyHarness("one_hot", "all_poly",
                 lambda x, y: jax.nn.one_hot(x, y.shape[0]),
-                arg_descriptors=[np.arange(16, dtype=_f32), RandArg((16,), _f32)],
+                arg_descriptors=[np.arange(16, dtype=_i32), RandArg((16,), _f32)],
                 polymorphic_shapes=["b, ...", "b, ..."]),
     PolyHarness("ones", "",
                 lambda x: jnp.ones(x.shape, dtype=_f32) + x,
```
