Skip to content

Commit 3cf7490

Browse files
committed
Enable neural network tests on ROCm
Enable NN tests to run on ROCm by using XLA implementation instead of cuDNN (which is NVIDIA-only) and fixing compute capability skip checks. Changes: - testScaledMatmul: skip compute capability check on ROCm (works on ROCm) - testScaledDotGeneral: skip compute capability check on ROCm (works on ROCm) - testDotProductAttention: add ROCm skip for cuDNN impl (XLA impl still runs) - testDotProductAttentionMask: use XLA instead of cuDNN on ROCm - testDotProductAttentionBiasGradient: use XLA instead of cuDNN on ROCm Cherry-picked from ROCm/jax PRs #604, #637, #640
1 parent 09e023e commit 3cf7490

File tree

1 file changed

+27
-15
lines changed

1 file changed

+27
-15
lines changed

tests/nn_test.py

Lines changed: 27 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -113,7 +113,8 @@ class NNFunctionsTest(jtu.JaxTestCase):
113113
dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
114114
)
115115
def testScaledMatmul(self, contract, lhs_non_contract, dtype):
116-
if not jtu.is_cuda_compute_capability_at_least("10.0"):
116+
# ROCm: scaled_matmul works, skip only CUDA compute capability check
117+
if not jtu.is_device_rocm() and not jtu.is_cuda_compute_capability_at_least("10.0"):
117118
raise unittest.SkipTest("Needs compute capability 10.0 or higher.")
118119
# Check if float8_e8m0fnu is available
119120
configs = create_mxfp8_configs_if_available()
@@ -136,7 +137,8 @@ def testScaledMatmul(self, contract, lhs_non_contract, dtype):
136137
)
137138
def testScaledDotGeneral(
138139
self, is_training, output_type):
139-
if not jtu.is_cuda_compute_capability_at_least("10.0"):
140+
# ROCm: scaled_dot_general works, skip only CUDA compute capability check
141+
if not jtu.is_device_rocm() and not jtu.is_cuda_compute_capability_at_least("10.0"):
140142
raise unittest.SkipTest("Needs compute capability 10.0 or higher.")
141143

142144
configs = create_mxfp8_configs_if_available()
@@ -193,6 +195,8 @@ def fwd(a, b, is_ref=False):
193195
impl=['cudnn', 'xla'],
194196
)
195197
def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
198+
if impl == 'cudnn' and jtu.is_device_rocm():
199+
raise unittest.SkipTest("cuDNN not available on ROCm.")
196200
if impl == 'cudnn' and not jtu.is_cuda_compute_capability_at_least("8.0"):
197201
raise unittest.SkipTest("Needs compute capability 8.0 or higher.")
198202
if impl == 'cudnn' and dtype == jnp.float32:
@@ -265,10 +269,13 @@ def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
265269
def testDotProductAttentionMask(self, mask_mode):
266270
if isinstance(mask_mode, str):
267271
mask_mode = (mask_mode,)
268-
if not jtu.is_cuda_compute_capability_at_least("8.0"):
269-
raise unittest.SkipTest("Requires compute capability 8.0 or higher.")
270-
if jtu.is_cuda_version_at_least(13, 0):
271-
raise unittest.SkipTest("cuDNN creates no execution plans on CUDA 13.0.")
272+
# ROCm: use XLA implementation instead of cuDNN
273+
use_cudnn = not jtu.is_device_rocm()
274+
if use_cudnn:
275+
if not jtu.is_cuda_compute_capability_at_least("8.0"):
276+
raise unittest.SkipTest("Requires compute capability 8.0 or higher.")
277+
if jtu.is_cuda_version_at_least(13, 0):
278+
raise unittest.SkipTest("cuDNN creates no execution plans on CUDA 13.0.")
272279

273280
dtype = jnp.bfloat16
274281
B, S, T, N, H = 2, 128, 128, 4, 32
@@ -295,8 +302,9 @@ def testDotProductAttentionMask(self, mask_mode):
295302
window_size = (3, 2) if is_causal else (3, 0)
296303

297304
sdpa = nn.dot_product_attention
305+
impl = 'cudnn' if use_cudnn else 'xla'
298306
sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
299-
sdpa_ans = partial(sdpa, is_causal=is_causal, implementation='cudnn')
307+
sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)
300308

301309
args = (Q, K, V, bias, mask)
302310
kwargs = {'query_seq_lengths': q_seqlen, 'key_value_seq_lengths': kv_seqlen}
@@ -315,9 +323,10 @@ def testDotProductAttentionMask(self, mask_mode):
315323
dQ_ref, dK_ref, dV_ref, dbias_ref = sdpa_vjp_ref(grad)[:4]
316324
dQ_ans, dK_ans, dV_ans, dbias_ans = sdpa_vjp_ans(grad)[:4]
317325

318-
# Check if cudnn backend is called.
319-
self.assertTrue(_check_cudnn_backend(sdpa_ans, *args, **kwargs))
320-
self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad))
326+
# Check if cudnn backend is called (only on CUDA).
327+
if use_cudnn:
328+
self.assertTrue(_check_cudnn_backend(sdpa_ans, *args, **kwargs))
329+
self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad))
321330

322331
self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
323332
self.assertAllClose(dQ_ref, dQ_ans, rtol=.02, atol=.02)
@@ -330,10 +339,13 @@ def testDotProductAttentionMask(self, mask_mode):
330339
use_vmap=[False, True],
331340
)
332341
def testDotProductAttentionBiasGradient(self, batch_size, use_vmap):
333-
if not jtu.is_cuda_compute_capability_at_least("8.0"):
334-
raise unittest.SkipTest("Requires compute capability 8.0 or higher.")
335-
if jtu.is_cuda_version_at_least(13, 0):
336-
raise unittest.SkipTest("cuDNN creates no execution plans on CUDA 13.0.")
342+
# ROCm: use XLA implementation instead of cuDNN
343+
use_cudnn = not jtu.is_device_rocm()
344+
if use_cudnn:
345+
if not jtu.is_cuda_compute_capability_at_least("8.0"):
346+
raise unittest.SkipTest("Requires compute capability 8.0 or higher.")
347+
if jtu.is_cuda_version_at_least(13, 0):
348+
raise unittest.SkipTest("cuDNN creates no execution plans on CUDA 13.0.")
337349

338350
dtype = jnp.bfloat16
339351
B, S, N, H = batch_size, 128, 4, 32
@@ -353,7 +365,7 @@ def attention(x, bias, mask, impl):
353365
implementation=impl,
354366
)
355367
attn_ref = partial(attention, impl=None)
356-
attn_ans = partial(attention, impl='cudnn')
368+
attn_ans = partial(attention, impl='cudnn' if use_cudnn else 'xla')
357369
if use_vmap:
358370
attn_batched_ref = jax.vmap(attn_ref, in_axes=(0, 0, None))
359371
attn_batched_ans = jax.vmap(attn_ans, in_axes=(0, 0, None))

0 commit comments

Comments (0)