@@ -113,7 +113,8 @@ class NNFunctionsTest(jtu.JaxTestCase):
       dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
   )
   def testScaledMatmul(self, contract, lhs_non_contract, dtype):
-    if not jtu.is_cuda_compute_capability_at_least("10.0"):
+    # ROCm: scaled_matmul works, skip only CUDA compute capability check
+    if not jtu.is_device_rocm() and not jtu.is_cuda_compute_capability_at_least("10.0"):
       raise unittest.SkipTest("Needs compute capability 10.0 or higher.")
     # Check if float8_e8m0fnu is available
     configs = create_mxfp8_configs_if_available()
@@ -136,7 +137,8 @@ def testScaledMatmul(self, contract, lhs_non_contract, dtype):
   )
   def testScaledDotGeneral(
       self, is_training, output_type):
-    if not jtu.is_cuda_compute_capability_at_least("10.0"):
+    # ROCm: scaled_dot_general works, skip only CUDA compute capability check
+    if not jtu.is_device_rocm() and not jtu.is_cuda_compute_capability_at_least("10.0"):
       raise unittest.SkipTest("Needs compute capability 10.0 or higher.")

     configs = create_mxfp8_configs_if_available()
@@ -193,6 +195,8 @@ def fwd(a, b, is_ref=False):
       impl=['cudnn', 'xla'],
   )
   def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
+    if impl == 'cudnn' and jtu.is_device_rocm():
+      raise unittest.SkipTest("cuDNN not available on ROCm.")
     if impl == 'cudnn' and not jtu.is_cuda_compute_capability_at_least("8.0"):
       raise unittest.SkipTest("Needs compute capability 8.0 or higher.")
     if impl == 'cudnn' and dtype == jnp.float32:
@@ -265,10 +269,13 @@ def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
   def testDotProductAttentionMask(self, mask_mode):
     if isinstance(mask_mode, str):
       mask_mode = (mask_mode,)
-    if not jtu.is_cuda_compute_capability_at_least("8.0"):
-      raise unittest.SkipTest("Requires compute capability 8.0 or higher.")
-    if jtu.is_cuda_version_at_least(13, 0):
-      raise unittest.SkipTest("cuDNN creates no execution plans on CUDA 13.0.")
+    # ROCm: use XLA implementation instead of cuDNN
+    use_cudnn = not jtu.is_device_rocm()
+    if use_cudnn:
+      if not jtu.is_cuda_compute_capability_at_least("8.0"):
+        raise unittest.SkipTest("Requires compute capability 8.0 or higher.")
+      if jtu.is_cuda_version_at_least(13, 0):
+        raise unittest.SkipTest("cuDNN creates no execution plans on CUDA 13.0.")

     dtype = jnp.bfloat16
     B, S, T, N, H = 2, 128, 128, 4, 32
@@ -295,8 +302,9 @@ def testDotProductAttentionMask(self, mask_mode):
     window_size = (3, 2) if is_causal else (3, 0)

     sdpa = nn.dot_product_attention
+    impl = 'cudnn' if use_cudnn else 'xla'
     sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
-    sdpa_ans = partial(sdpa, is_causal=is_causal, implementation='cudnn')
+    sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)

     args = (Q, K, V, bias, mask)
     kwargs = {'query_seq_lengths': q_seqlen, 'key_value_seq_lengths': kv_seqlen}
@@ -315,9 +323,10 @@ def testDotProductAttentionMask(self, mask_mode):
     dQ_ref, dK_ref, dV_ref, dbias_ref = sdpa_vjp_ref(grad)[:4]
     dQ_ans, dK_ans, dV_ans, dbias_ans = sdpa_vjp_ans(grad)[:4]

-    # Check if cudnn backend is called.
-    self.assertTrue(_check_cudnn_backend(sdpa_ans, *args, **kwargs))
-    self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad))
+    # Check if cudnn backend is called (only on CUDA).
+    if use_cudnn:
+      self.assertTrue(_check_cudnn_backend(sdpa_ans, *args, **kwargs))
+      self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad))

     self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
     self.assertAllClose(dQ_ref, dQ_ans, rtol=.02, atol=.02)
@@ -330,10 +339,13 @@ def testDotProductAttentionMask(self, mask_mode):
       use_vmap=[False, True],
   )
   def testDotProductAttentionBiasGradient(self, batch_size, use_vmap):
-    if not jtu.is_cuda_compute_capability_at_least("8.0"):
-      raise unittest.SkipTest("Requires compute capability 8.0 or higher.")
-    if jtu.is_cuda_version_at_least(13, 0):
-      raise unittest.SkipTest("cuDNN creates no execution plans on CUDA 13.0.")
+    # ROCm: use XLA implementation instead of cuDNN
+    use_cudnn = not jtu.is_device_rocm()
+    if use_cudnn:
+      if not jtu.is_cuda_compute_capability_at_least("8.0"):
+        raise unittest.SkipTest("Requires compute capability 8.0 or higher.")
+      if jtu.is_cuda_version_at_least(13, 0):
+        raise unittest.SkipTest("cuDNN creates no execution plans on CUDA 13.0.")

     dtype = jnp.bfloat16
     B, S, N, H = batch_size, 128, 4, 32
@@ -353,7 +365,7 @@ def attention(x, bias, mask, impl):
           implementation=impl,
       )
     attn_ref = partial(attention, impl=None)
-    attn_ans = partial(attention, impl='cudnn')
+    attn_ans = partial(attention, impl='cudnn' if use_cudnn else 'xla')
    if use_vmap:
       attn_batched_ref = jax.vmap(attn_ref, in_axes=(0, 0, None))
       attn_batched_ans = jax.vmap(attn_ans, in_axes=(0, 0, None))