Skip to content

Commit e5fc666

Browse files
rootgulsumgudukbay
authored andcommitted
adding additional skips only for MI350 for LLVM cannot scalarize ERROR
1 parent 9c088fc commit e5fc666

File tree

3 files changed

+12
-1
lines changed

3 files changed

+12
-1
lines changed

tests/pallas/gpu_attention_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ def test_mqa(
104104
kv_seq_len,
105105
return_residuals,
106106
):
107+
if jtu.is_device_rocm and 'gfx950' in [d.compute_capability for d in jax.devices()]:
108+
self.skipTest("Skip on ROCm: test_mqa: LLVM ERROR: Do not know how to scalarize the result of this operator!")
107109
del kwargs
108110
normalize_output = not return_residuals
109111

@@ -183,6 +185,8 @@ def test_gqa(
183185
kv_seq_len,
184186
return_residuals,
185187
):
188+
if jtu.is_device_rocm and 'gfx950' in [d.compute_capability for d in jax.devices()]:
189+
self.skipTest("Skip on ROCm: test_gqa: LLVM ERROR: Do not know how to scalarize the result of this operator!")
186190
del kwargs
187191
normalize_output = not return_residuals
188192

tests/pallas/gpu_ops_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ def test_fused_attention_fwd(
175175
use_fwd,
176176
use_segment_ids,
177177
):
178+
if jtu.is_device_rocm and 'gfx950' in [d.compute_capability for d in jax.devices()]:
179+
self.skipTest("Skip on ROCm: test_fused_attention_fwd: LLVM ERROR: Do not know how to scalarize the result of this operator!")
180+
178181
if jtu.is_device_rocm and batch_size == 2 and seq_len == 384 and num_heads == 8 and head_dim == 64 and block_sizes == (('block_q', 128), ('block_k', 128)) and causal and use_fwd and use_segment_ids:
179182
self.skipTest("Skip on ROCm: tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd4")
180183
k1, k2, k3 = random.split(random.key(0), 3)
@@ -292,7 +295,8 @@ def test_fused_attention_bwd(
292295
):
293296
test_name = str(self).split()[0]
294297
skip_suffix_list = [4, 6, 7, 8, 9]
295-
298+
if jtu.is_device_rocm and 'gfx950' in [d.compute_capability for d in jax.devices()]:
299+
self.skipTest("Skip on ROCm: test_fused_attention_bwd: LLVM ERROR: Do not know how to scalarize the result of this operator!")
296300
if jtu.is_device_rocm and self.INTERPRET and any(test_name.endswith(str(suffix)) for suffix in skip_suffix_list):
297301
self.skipTest("Skip on ROCm: tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd[4, 6, 7, 8, 9]")
298302

tests/pallas/gpu_paged_attention_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ def test_paged_attention(
122122
k_splits,
123123
attn_logits_soft_cap,
124124
):
125+
if jtu.is_device_rocm and 'gfx950' in [d.compute_capability for d in jax.devices()]:
126+
self.skipTest("Skip on ROCm: test_paged_attention: LLVM ERROR: Do not know how to scalarize the result of this operator!")
127+
125128
test_name = str(self).split()[0]
126129
skip_numbers = {0, 1, 3, 5, 6, 7, 9}
127130
if jtu.is_device_rocm and test_name in {f"test_paged_attention{i}" for i in skip_numbers}:

0 commit comments

Comments
 (0)