
Commit 0ec902d

apaszke authored and Google-ML-Automation committed
[Mosaic TPU] Add support for bf16 abs
PiperOrigin-RevId: 707041113
1 parent 0361255 commit 0ec902d

File tree

1 file changed: +25 −19 lines

tests/pallas/ops_test.py

Lines changed: 25 additions & 19 deletions
@@ -803,7 +803,15 @@ def kernel(x_ref, o_ref):
   ELEMENTWISE_OPS = [
       (
           [jnp.abs, jnp.negative],
-          ["int16", "int32", "int64", "float16", "float32", "float64"],
+          [
+              "int16",
+              "int32",
+              "int64",
+              "bfloat16",
+              "float16",
+              "float32",
+              "float64",
+          ],
       ),
       ([jnp.ceil, jnp.floor], ["bfloat16", "float32", "float64", "int32"]),
       (
@@ -819,7 +827,7 @@ def kernel(x_ref, o_ref):
           ["float32", "float64"],
       ),
       ([lax.population_count, lax.clz, jnp.invert], ["int32", "int64"]),
-      ([jnp.logical_not], ["bool"])
+      ([jnp.logical_not], ["bool"]),
   ]

   @parameterized.named_parameters(
@@ -831,8 +839,21 @@ def test_elementwise(self, fn, dtype):
     if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
       self.skipTest("64-bit types require x64_enabled")

-    if jtu.test_device_matches(["tpu"]) and dtype in ("int16", "float16"):
-      self.skipTest("int16 and float16 are not supported on TPU")
+    if jtu.test_device_matches(["tpu"]):
+      if dtype in ("int16", "float16"):
+        self.skipTest("int16 and float16 are not supported on TPU")
+      if (
+          fn in (jnp.ceil, jnp.floor, jnp.negative)
+          and dtype == "bfloat16"
+          and not jtu.is_device_tpu_at_least(6)
+      ):
+        self.skipTest(f"bfloat16 {fn.__name__} is only supported on TPU v6+")
+      # TODO(b/370578663): implement these lowerings on TPU
+      if fn in (
+          jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan, jnp.atanh,
+          jnp.cbrt, jnp.cosh, jnp.expm1, jnp.sinh,
+      ):
+        self.skipTest(f"{fn.__name__} not implemented on TPU")

     if (
         jtu.test_device_matches(["gpu"])
@@ -841,21 +862,6 @@ def test_elementwise(self, fn, dtype):
     ):
       self.skipTest(f"bfloat16 {fn.__name__} is not supported on GPU")

-    if (
-        jtu.test_device_matches(["tpu"])
-        and not jtu.is_device_tpu_at_least(6)
-        and fn in (jnp.ceil, jnp.floor)
-        and dtype == "bfloat16"
-    ):
-      self.skipTest(f"bfloat16 {fn.__name__} is only supported on TPU v6+")
-
-    # TODO(b/370578663): implement these lowerings on TPU
-    if jtu.test_device_matches(["tpu"]) and fn in (
-        jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan, jnp.atanh,
-        jnp.cbrt, jnp.cosh, jnp.expm1, jnp.sinh,
-    ):
-      self.skipTest(f"{fn.__name__} not implemented on TPU")
-
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct((8, 128), dtype),
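
For reference, the new "bfloat16" entry for jnp.abs is exercised through the same pallas_call harness used by test_elementwise. A minimal standalone sketch of the operation that entry covers (not part of this commit; shapes and names are illustrative) might look like:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def abs_kernel(x_ref, o_ref):
  # Elementwise absolute value; with this change a bfloat16 input lowers on Mosaic TPU.
  o_ref[...] = jnp.abs(x_ref[...])


x = jnp.arange(-512, 512, dtype=jnp.float32).reshape(8, 128).astype(jnp.bfloat16)
y = pl.pallas_call(
    abs_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.bfloat16),
)(x)

On other backends the same sketch can be run in interpreter mode by passing interpret=True to pl.pallas_call.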
