Skip to content

Commit 2c4c706

Browse files
authored
Skip test_dot_precision() with TF32 for Navi architecture (#520)
RDNA3 / Navi3x / gfx11xx and RDNA4 / Navi4x / gfx12xx don't support TF32/XF32 hence relevant `tests/pallas/pallas_test.py::PallasCallInterpretTest::test_dot_precision*` tests needs to be skipped.
1 parent e5fc666 commit 2c4c706

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

jax/_src/test_util.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,20 @@ def supported_dtypes():
395395
def is_device_rocm():
396396
return 'rocm' in xla_bridge.get_backend().platform_version
397397

398+
def any_rocm_device_has_gfx_prefix(gfx_prefixes: str | tuple[str])->bool:
399+
"""Returns true if any available device has gfx value that starts with any of
400+
prefixes in gfx_prefixes.
401+
Intended use is to skip tests that require certain features not present in
402+
devices."""
403+
if not is_device_rocm():
404+
return False
405+
if isinstance(gfx_prefixes, str):
406+
gfx_prefixes = (gfx_prefixes,)
407+
assert isinstance(gfx_prefixes, tuple), "argument must be a tuple of strings"
408+
assert all([isinstance(e, str) for e in gfx_prefixes])
409+
gfxs = frozenset([d.compute_capability for d in jax.devices()])
410+
return any([g.startswith(gfx_prefixes) for g in gfxs])
411+
398412
def get_rocm_version():
399413
rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm")
400414
version_path = Path(rocm_path) / ".info" / "version"

tests/pallas/pallas_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,12 @@ def test_dot_precision(self, dtype, precision):
721721
if not jtu.test_device_matches(["gpu"]):
722722
self.skipTest("`DotAlgorithmPreset` only supported on GPU.")
723723

724+
if precision in (
725+
jax.lax.DotAlgorithmPreset.TF32_TF32_F32,
726+
jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3,
727+
) and jtu.any_rocm_device_has_gfx_prefix(("gfx11","gfx12")):
728+
self.skipTest("Navi3x and Navi4x doesn't have hardware support for TF32")
729+
724730
@functools.partial(
725731
self.pallas_call,
726732
out_shape=jax.ShapeDtypeStruct((32, 64), jnp.float32),

0 commit comments

Comments
 (0)