@@ -175,10 +175,10 @@ def test_fused_attention_fwd(
175175 use_fwd ,
176176 use_segment_ids ,
177177 ):
178- if jtu .is_device_rocm and 'gfx950' in [d .compute_capability for d in jax .devices ()]:
178+ if jtu .is_device_rocm () and 'gfx950' in [d .compute_capability for d in jax .devices ()]:
179179 self .skipTest ("Skip on ROCm: test_fused_attention_fwd: LLVM ERROR: Do not know how to scalarize the result of this operator!" )
180180
181- if jtu .is_device_rocm and batch_size == 2 and seq_len == 384 and num_heads == 8 and head_dim == 64 and block_sizes == (('block_q' , 128 ), ('block_k' , 128 )) and causal and use_fwd and use_segment_ids :
181+ if jtu .is_device_rocm () and batch_size == 2 and seq_len == 384 and num_heads == 8 and head_dim == 64 and block_sizes == (('block_q' , 128 ), ('block_k' , 128 )) and causal and use_fwd and use_segment_ids :
182182 self .skipTest ("Skip on ROCm: tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd4" )
183183 k1 , k2 , k3 = random .split (random .key (0 ), 3 )
184184 q = random .normal (
@@ -232,7 +232,7 @@ def impl(q, k, v):
232232 # RESOURCE_EXHAUSTED: Shared memory size limit exceeded" error.
233233 @jtu .sample_product (
234234 batch_size = (1 , 2 ),
235- seq_len = (32 , 64 ) if jtu .is_device_rocm else (128 , 384 ),
235+ seq_len = (32 , 64 ) if jtu .is_device_rocm () else (128 , 384 ),
236236 num_heads = (1 , 2 ),
237237 head_dim = (32 , 64 , 128 ,),
238238 block_sizes = (
@@ -253,7 +253,7 @@ def impl(q, k, v):
253253 ("block_kv_dq" , 32 ),
254254 ),
255255 )
256- if jtu .is_device_rocm else (
256+ if jtu .is_device_rocm () else (
257257 (
258258 ("block_q" , 128 ),
259259 ("block_k" , 128 ),
@@ -295,9 +295,9 @@ def test_fused_attention_bwd(
295295 ):
296296 test_name = str (self ).split ()[0 ]
297297 skip_suffix_list = [4 , 6 , 7 , 8 , 9 ]
298- if jtu .is_device_rocm and 'gfx950' in [d .compute_capability for d in jax .devices ()]:
298+ if jtu .is_device_rocm () and 'gfx950' in [d .compute_capability for d in jax .devices ()]:
299299 self .skipTest ("Skip on ROCm: test_fused_attention_bwd: LLVM ERROR: Do not know how to scalarize the result of this operator!" )
300- if jtu .is_device_rocm and self .INTERPRET and any (test_name .endswith (str (suffix )) for suffix in skip_suffix_list ):
300+ if jtu .is_device_rocm () and self .INTERPRET and any (test_name .endswith (str (suffix )) for suffix in skip_suffix_list ):
301301 self .skipTest ("Skip on ROCm: tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd[4, 6, 7, 8, 9]" )
302302
303303 k1 , k2 , k3 = random .split (random .key (0 ), 3 )
0 commit comments