Skip to content

Commit 613a0cd

Browse files
ayaka14732Google-ML-Automation
authored andcommitted
[Pallas] Fix lowering tests for reduction ops
Remove unnecessary skip statements. Also added tests for bf16 types. PiperOrigin-RevId: 707739536
1 parent f65eced commit 613a0cd

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

tests/pallas/ops_test.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1895,35 +1895,35 @@ def reduce(x_ref, y_ref):
18951895
for axis in [0, 1, (1,), (0, 1)]
18961896
for dtype in [
18971897
"float16",
1898+
"bfloat16",
18981899
"float32",
18991900
"float64",
19001901
"int32",
19011902
"int64",
19021903
"uint32",
19031904
"uint64",
19041905
]
1905-
if isinstance(axis, int) or "arg" not in op_name
19061906
])
19071907
def test_array_reduce(self, op, dtype, axis):
1908-
if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
1909-
self.skipTest("16-bit types are not supported on TPU")
1908+
if not isinstance(axis, int):
1909+
self.skipTest("TODO: tuple axes are not yet supported")
19101910

19111911
if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
19121912
self.skipTest("64-bit types require x64_enabled")
19131913

1914+
# The Pallas TPU lowering currently supports only blocks of rank >= 1
1915+
if jtu.test_device_matches(["tpu"]):
1916+
self.skipTest("Not implemented on TPU")
1917+
19141918
# Skip argmin/argmax on GPU in 64-bit mode because Pallas expects
19151919
# `index_type` to be i32
19161920
if (
19171921
jax.config.x64_enabled
19181922
and jtu.test_device_matches(["gpu"])
1919-
and op in {jnp.argmin, jnp.argmax}
1923+
and op in (jnp.argmin, jnp.argmax)
19201924
):
19211925
self.skipTest("Not supported on GPU in 64-bit mode")
19221926

1923-
# The Pallas TPU lowering currently supports only blocks of rank >= 1
1924-
if jtu.test_device_matches(["tpu"]):
1925-
self.skipTest("Not supported on TPU")
1926-
19271927
m, n = 32, 8
19281928

19291929
def make_x(key):
@@ -1955,7 +1955,7 @@ def reduce(x_ref, y_ref):
19551955
x = make_x(key)
19561956
y = reduce(x)
19571957
y_ref = op(x, axis=axis)
1958-
np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i)
1958+
self.assertAllClose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i)
19591959

19601960
@parameterized.product(
19611961
axis=[0, 1],

0 commit comments

Comments
 (0)