Skip to content

Commit 41a32e3

Browse files
unskip pallas tests that were skipped due to LLVm error
1 parent 248cf43 commit 41a32e3

File tree

3 files changed

+0
-10
lines changed

3 files changed

+0
-10
lines changed

tests/pallas/gpu_attention_test.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,6 @@ def test_mqa(
104104
kv_seq_len,
105105
return_residuals,
106106
):
107-
if jtu.is_device_rocm() and 'gfx950' in [d.compute_capability for d in jax.devices()]:
108-
self.skipTest("Skip on ROCm: test_mqa: LLVM ERROR: Do not know how to scalarize the result of this operator!")
109107
del kwargs
110108
normalize_output = not return_residuals
111109

@@ -185,8 +183,6 @@ def test_gqa(
185183
kv_seq_len,
186184
return_residuals,
187185
):
188-
if jtu.is_device_rocm() and 'gfx950' in [d.compute_capability for d in jax.devices()]:
189-
self.skipTest("Skip on ROCm: test_gqa: LLVM ERROR: Do not know how to scalarize the result of this operator!")
190186
del kwargs
191187
normalize_output = not return_residuals
192188

tests/pallas/gpu_ops_test.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,6 @@ def test_fused_attention_fwd(
175175
use_fwd,
176176
use_segment_ids,
177177
):
178-
if jtu.is_device_rocm() and 'gfx950' in [d.compute_capability for d in jax.devices()]:
179-
self.skipTest("Skip on ROCm: test_fused_attention_fwd: LLVM ERROR: Do not know how to scalarize the result of this operator!")
180178

181179
if jtu.is_device_rocm() and batch_size == 2 and seq_len == 384 and num_heads == 8 and head_dim == 64 and block_sizes == (('block_q', 128), ('block_k', 128)) and causal and use_fwd and use_segment_ids:
182180
self.skipTest("Skip on ROCm: tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd4")
@@ -295,8 +293,6 @@ def test_fused_attention_bwd(
295293
):
296294
test_name = str(self).split()[0]
297295
skip_suffix_list = [4, 6, 7, 8, 9]
298-
if jtu.is_device_rocm() and 'gfx950' in [d.compute_capability for d in jax.devices()]:
299-
self.skipTest("Skip on ROCm: test_fused_attention_bwd: LLVM ERROR: Do not know how to scalarize the result of this operator!")
300296
if jtu.is_device_rocm() and self.INTERPRET and any(test_name.endswith(str(suffix)) for suffix in skip_suffix_list):
301297
self.skipTest("Skip on ROCm: tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd[4, 6, 7, 8, 9]")
302298

tests/pallas/gpu_paged_attention_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,6 @@ def test_paged_attention(
122122
k_splits,
123123
attn_logits_soft_cap,
124124
):
125-
if jtu.is_device_rocm() and 'gfx950' in [d.compute_capability for d in jax.devices()]:
126-
self.skipTest("Skip on ROCm: test_paged_attention: LLVM ERROR: Do not know how to scalarize the result of this operator!")
127125

128126
test_name = str(self).split()[0]
129127
skip_numbers = {0, 1, 3, 5, 6, 7, 9}

0 commit comments

Comments
 (0)