
Commit 2b7b074

ayaka14732 authored and Google-ML-Automation committed
[Pallas TPU] Add lowerings for bf16 jnp.ceil and jnp.floor in TPU v6+
This PR is similar to jax-ml#24284.

Note that `np.testing.assert_allclose()` is changed to `self.assertAllClose()` because the latter is a wrapper with bfloat16 support.

PiperOrigin-RevId: 688581914
1 parent 2596a40 commit 2b7b074
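
For context, here is a minimal sketch of the pattern the updated test exercises: a Pallas kernel applying `jnp.ceil` to a bfloat16 input. This is illustrative only and not code from the commit; it assumes a TPU v6+ backend (the lowering this change adds), or interpret mode elsewhere.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def ceil_kernel(x_ref, o_ref):
  # Element-wise ceil over the whole block; input and output share a shape.
  o_ref[...] = jnp.ceil(x_ref[...])

x = jnp.array([0.42, 2.4], dtype=jnp.bfloat16)
out = pl.pallas_call(
    ceil_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
)(x)
# Expected result: [1, 3] in bfloat16.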

File tree

1 file changed: +27 −5 lines

tests/pallas/ops_test.py

Lines changed: 27 additions & 5 deletions
@@ -742,7 +742,7 @@ def kernel(x_ref, o_ref):
         [jnp.abs, jnp.negative],
         ["int16", "int32", "int64", "float16", "float32", "float64"],
     ),
-    ([jnp.ceil, jnp.floor], ["float32", "float64", "int32"]),
+    ([jnp.ceil, jnp.floor], ["bfloat16", "float32", "float64", "int32"]),
     (
         [jnp.exp, jnp.exp2, jnp.sin, jnp.cos, jnp.log, jnp.sqrt],
         ["float16", "float32", "float64"],
@@ -767,8 +767,23 @@ def test_elementwise(self, fn, dtype):
     if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
       self.skipTest("64-bit types require x64_enabled")
 
-    if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
-      self.skipTest("16-bit types are not supported on TPU")
+    if jtu.test_device_matches(["tpu"]) and dtype in ("int16", "float16"):
+      self.skipTest("int16 and float16 are not supported on TPU")
+
+    if (
+        jtu.test_device_matches(["gpu"])
+        and fn in (jnp.ceil, jnp.floor)
+        and dtype == "bfloat16"
+    ):
+      self.skipTest(f"bfloat16 {fn.__name__} is not supported on GPU")
+
+    if (
+        jtu.test_device_matches(["tpu"])
+        and not jtu.is_device_tpu_at_least(6)
+        and fn in (jnp.ceil, jnp.floor)
+        and dtype == "bfloat16"
+    ):
+      self.skipTest(f"bfloat16 {fn.__name__} is only supported on TPU v6+")
 
     # TODO(b/370578663): implement these lowerings on TPU
     if jtu.test_device_matches(["tpu"]) and fn in (
@@ -784,7 +799,7 @@ def kernel(x_ref, o_ref):
       o_ref[:] = fn(x_ref[...])
 
     x = jnp.array([0.42, 2.4]).astype(dtype)
-    np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)
+    self.assertAllClose(kernel(x), fn(x), rtol=1e-6)
 
   @parameterized.named_parameters(
       (f"{fn.__name__}_{dtype}", fn, dtype)
@@ -798,6 +813,13 @@ def test_elementwise_scalar(self, fn, dtype):
     if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
       self.skipTest("16-bit types are not supported on TPU")
 
+    if (
+        jtu.test_device_matches(["gpu"])
+        and fn in (jnp.ceil, jnp.floor)
+        and dtype == "bfloat16"
+    ):
+      self.skipTest(f"bfloat16 {fn.__name__} is not supported on GPU")
+
     if (
         jtu.test_device_matches(["tpu"])
         and fn == lax.population_count
@@ -826,7 +848,7 @@ def kernel(x_ref, o_ref):
       o_ref[1] = fn(x_ref[1])
 
     x = jnp.array([0.42, 2.4]).astype(dtype)
-    np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)
+    self.assertAllClose(kernel(x), fn(x), rtol=1e-6)
 
   def test_abs_weak_type(self):
     # see https://github.com/jax-ml/jax/issues/23191
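
As a usage note on the assertion swap above, here is a hedged, standalone sketch in the same style as ops_test.py. The class and test names are hypothetical and not part of the commit; `self.assertAllClose` is the `jtu.JaxTestCase` wrapper the commit switches to because, per the commit message, it supports bfloat16 comparisons, unlike a bare `np.testing.assert_allclose`.

from absl.testing import absltest
import jax.numpy as jnp
from jax._src import test_util as jtu

class Bf16RoundingTest(jtu.JaxTestCase):

  def test_bf16_ceil_reference(self):
    x = jnp.array([0.42, 2.4], dtype=jnp.bfloat16)
    # assertAllClose handles bfloat16 operands and checks dtypes as well.
    self.assertAllClose(jnp.ceil(x), jnp.array([1.0, 3.0], dtype=jnp.bfloat16))

if __name__ == "__main__":
  absltest.main()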
