44 changes: 29 additions & 15 deletions tests/nn_test.py
@@ -113,7 +113,9 @@ class NNFunctionsTest(jtu.JaxTestCase):
dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
)
def testScaledMatmul(self, contract, lhs_non_contract, dtype):
if not jtu.is_cuda_compute_capability_at_least("10.0"):
if not jtu.test_device_matches(["gpu"]):
raise unittest.SkipTest("Test requires GPU.")
if jtu.is_device_cuda() and not jtu.is_cuda_compute_capability_at_least("10.0"):
raise unittest.SkipTest("Needs compute capability 10.0 or higher.")
# Check if float8_e8m0fnu is available
configs = create_mxfp8_configs_if_available()
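(For reference, the guard added here can be read as a standalone check. A minimal sketch follows; the helper name is hypothetical, and jtu is assumed to be JAX's test_util module as imported at the top of nn_test.py.)

import unittest
from jax._src import test_util as jtu  # assumed import alias used by the test file

def skip_unless_supported_gpu():
  # Hypothetical helper mirroring the guard above: require a GPU backend, and on
  # CUDA additionally require compute capability 10.0 or newer.
  if not jtu.test_device_matches(["gpu"]):
    raise unittest.SkipTest("Test requires GPU.")
  if jtu.is_device_cuda() and not jtu.is_cuda_compute_capability_at_least("10.0"):
    raise unittest.SkipTest("Needs compute capability 10.0 or higher.")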
@@ -136,7 +138,9 @@ def testScaledMatmul(self, contract, lhs_non_contract, dtype):
)
def testScaledDotGeneral(
self, is_training, output_type):
if not jtu.is_cuda_compute_capability_at_least("10.0"):
if not jtu.test_device_matches(["gpu"]):
raise unittest.SkipTest("Test requires GPU.")
if jtu.is_device_cuda() and not jtu.is_cuda_compute_capability_at_least("10.0"):
raise unittest.SkipTest("Needs compute capability 10.0 or higher.")

configs = create_mxfp8_configs_if_available()
@@ -193,6 +197,8 @@ def fwd(a, b, is_ref=False):
impl=['cudnn', 'xla'],
)
def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
if impl == 'cudnn' and jtu.is_device_rocm():
raise unittest.SkipTest("cuDNN not available on ROCm.")
if impl == 'cudnn' and not jtu.is_cuda_compute_capability_at_least("8.0"):
raise unittest.SkipTest("Needs compute capability 8.0 or higher.")
if impl == 'cudnn' and dtype == jnp.float32:
@@ -265,10 +271,13 @@ def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
def testDotProductAttentionMask(self, mask_mode):
if isinstance(mask_mode, str):
mask_mode = (mask_mode,)
if not jtu.is_cuda_compute_capability_at_least("8.0"):
raise unittest.SkipTest("Requires compute capability 8.0 or higher.")
if jtu.is_cuda_version_at_least(13, 0):
raise unittest.SkipTest("cuDNN creates no execution plans on CUDA 13.0.")
# ROCm: use XLA implementation instead of cuDNN
use_cudnn = jtu.is_device_cuda()
if use_cudnn:
if not jtu.is_cuda_compute_capability_at_least("8.0"):
raise unittest.SkipTest("Requires compute capability 8.0 or higher.")
if jtu.is_cuda_version_at_least(13, 0):
raise unittest.SkipTest("cuDNN creates no execution plans on CUDA 13.0.")

dtype = jnp.bfloat16
B, S, T, N, H = 2, 128, 128, 4, 32
@@ -295,8 +304,9 @@ def testDotProductAttentionMask(self, mask_mode):
window_size = (3, 2) if is_causal else (3, 0)

sdpa = nn.dot_product_attention
impl = 'cudnn' if use_cudnn else 'xla'
sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
sdpa_ans = partial(sdpa, is_causal=is_causal, implementation='cudnn')
sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)

args = (Q, K, V, bias, mask)
kwargs = {'query_seq_lengths': q_seqlen, 'key_value_seq_lengths': kv_seqlen}
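(Aside: the implementation argument is how these tests steer jax.nn.dot_product_attention between backends; None lets JAX choose. A minimal, self-contained illustration with made-up shapes, not taken from the test:)

import jax
import jax.numpy as jnp

B, T, N, H = 2, 16, 4, 32
q = k = v = jnp.zeros((B, T, N, H), dtype=jnp.bfloat16)
# implementation=None picks a backend automatically; 'xla' or 'cudnn' force one.
out_ref = jax.nn.dot_product_attention(q, k, v, implementation=None)
out_xla = jax.nn.dot_product_attention(q, k, v, implementation='xla')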
@@ -315,9 +325,10 @@ def testDotProductAttentionMask(self, mask_mode):
dQ_ref, dK_ref, dV_ref, dbias_ref = sdpa_vjp_ref(grad)[:4]
dQ_ans, dK_ans, dV_ans, dbias_ans = sdpa_vjp_ans(grad)[:4]

# Check if cudnn backend is called.
self.assertTrue(_check_cudnn_backend(sdpa_ans, *args, **kwargs))
self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad))
# Check if cudnn backend is called (only on CUDA).
if use_cudnn:
self.assertTrue(_check_cudnn_backend(sdpa_ans, *args, **kwargs))
self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad))

self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
self.assertAllClose(dQ_ref, dQ_ans, rtol=.02, atol=.02)
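(_check_cudnn_backend is a helper defined elsewhere in nn_test.py and is not part of this diff. A rough sketch of how such a check could work; the custom-call prefix is an assumption, not something shown here:)

import jax

def check_cudnn_backend(fn, *args, **kwargs):
  # Hypothetical: lower the function and look for a cuDNN fused-attention
  # custom call in the StableHLO text.
  hlo_text = jax.jit(fn).lower(*args, **kwargs).as_text()
  return '__cudnn$fmha' in hlo_text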
@@ -330,10 +341,13 @@ def testDotProductAttentionMask(self, mask_mode):
use_vmap=[False, True],
)
def testDotProductAttentionBiasGradient(self, batch_size, use_vmap):
if not jtu.is_cuda_compute_capability_at_least("8.0"):
raise unittest.SkipTest("Requires compute capability 8.0 or higher.")
if jtu.is_cuda_version_at_least(13, 0):
raise unittest.SkipTest("cuDNN creates no execution plans on CUDA 13.0.")
# ROCm: use XLA implementation instead of cuDNN
use_cudnn = jtu.is_device_cuda()
if use_cudnn:
if not jtu.is_cuda_compute_capability_at_least("8.0"):
raise unittest.SkipTest("Requires compute capability 8.0 or higher.")
if jtu.is_cuda_version_at_least(13, 0):
raise unittest.SkipTest("cuDNN creates no execution plans on CUDA 13.0.")

dtype = jnp.bfloat16
B, S, N, H = batch_size, 128, 4, 32
@@ -353,7 +367,7 @@ def attention(x, bias, mask, impl):
implementation=impl,
)
attn_ref = partial(attention, impl=None)
attn_ans = partial(attention, impl='cudnn')
attn_ans = partial(attention, impl='cudnn' if use_cudnn else 'xla')
if use_vmap:
attn_batched_ref = jax.vmap(attn_ref, in_axes=(0, 0, None))
attn_batched_ans = jax.vmap(attn_ans, in_axes=(0, 0, None))
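(The in_axes=(0, 0, None) spec maps over the leading axis of x and bias while broadcasting a single mask to every slice. A small self-contained sketch with illustrative shapes, assuming the XLA backend so it runs on any device:)

import jax
import jax.numpy as jnp

def attention(x, bias, mask):
  # x: (S, N, H); bias: (N, S, S); mask is shared across the vmapped axis.
  return jax.nn.dot_product_attention(x, x, x, bias=bias, mask=mask,
                                      implementation='xla')

G, S, N, H = 3, 8, 2, 4
x = jnp.zeros((G, S, N, H), jnp.bfloat16)
bias = jnp.zeros((G, N, S, S), jnp.bfloat16)
mask = jnp.ones((1, 1, S, S), dtype=jnp.bool_)
out = jax.vmap(attention, in_axes=(0, 0, None))(x, bias, mask)  # (G, S, N, H)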