diff --git a/tests/ffi_test.py b/tests/ffi_test.py index d6270d14e5d6..28415e4d5db4 100644 --- a/tests/ffi_test.py +++ b/tests/ffi_test.py @@ -276,10 +276,12 @@ def test_invalid_result_type(self): @jtu.run_on_devices("gpu", "cpu") def test_shard_map(self): - if jtu.is_device_rocm: - self.skipTest("Skip on ROCm: tests/ffi_test.py::FfiTest::test_shard_map") + # if jtu.is_device_rocm: + # self.skipTest("Skip on ROCm: tests/ffi_test.py::FfiTest::test_shard_map") mesh = jtu.create_mesh((len(jax.devices()),), ("i",)) x = self.rng().randn(8, 4, 5).astype(np.float32) + n = len(jax.devices()) + x = x[:(x.shape[0] // n) * n] @partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i")) def f(x):