Commit 13f8d31

Run CI on mi325x (#441)
1 parent 0b7cc5f commit 13f8d31

10 files changed: +58 lines added, 0 removed

.github/matrix.json

Lines changed: 9 additions & 0 deletions
@@ -53,6 +53,15 @@
       "runtime-version": "cu129",
       "container-options": "--gpus all",
       "alias": "b200"
+    },
+    {
+      "runner": "linux.rocm.gpu.gfx942.2",
+      "python-version": "3.12",
+      "ref-mode": "none",
+      "image": "rocm/dev-ubuntu-24.04:6.2.4",
+      "runtime-version": "rocm6.4",
+      "container-options": "--device=/dev/kfd --device=/dev/dri",
+      "alias": "mi325x"
     }
   ]
 }
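
The new runner entry passes the ROCm device nodes (/dev/kfd and /dev/dri) into the container, the ROCm counterpart of --gpus all on the CUDA runners. Below is a minimal sketch of how the entry could be inspected locally; only the field names come from the diff, and the assumption that the file holds a single top-level list of runner configs is mine.

import json

# Illustrative only: load the CI matrix and print the new ROCm runner entry.
with open(".github/matrix.json") as f:
    matrix = json.load(f)

# Assumed structure: one top-level key mapping to the list of runner configs.
(entries,) = matrix.values()
for entry in entries:
    if entry.get("alias") == "mi325x":
        # The ROCm container gets the KFD/DRI device nodes rather than --gpus all.
        print(entry["runner"], entry["container-options"])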

helion/_testing.py

Lines changed: 5 additions & 0 deletions
@@ -45,6 +45,11 @@ def skipIfNormalMode(reason: str) -> Callable[[Callable], Callable]:
     return unittest.skipIf(os.environ.get("HELION_INTERPRET") != "1", reason)


+def skipIfRocm(reason: str) -> Callable[[Callable], Callable]:
+    """Skip test if running with rocm"""
+    return unittest.skipIf(torch.version.hip is not None, reason)  # pyright: ignore[reportAttributeAccessIssue]
+
+
 @contextlib.contextmanager
 def track_run_ref_calls() -> Generator[list[int], None, None]:
     """Context manager that tracks BoundKernel.run_ref calls.

test/test_autotuner.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
 from helion._testing import RefEagerTestDisabled
 from helion._testing import TestCase
 from helion._testing import import_path
+from helion._testing import skipIfRocm
 from helion.autotuner import DifferentialEvolutionSearch
 from helion.autotuner.config_generation import ConfigGeneration
 from helion.autotuner.random_search import RandomSearch
@@ -36,6 +37,7 @@ def setUp(self):
     @patch.object(_compat, "_supports_tensor_descriptor", lambda: True)
     @patch.object(_compat, "_min_dot_size", lambda *args: (16, 16, 16))
     @patch.object(loops, "_supports_warp_specialize", lambda: True)
+    @skipIfRocm("failure on rocm")
     def test_config_fragment0(self):
         args = (
             torch.randn([512, 512], device=DEVICE),

test/test_dot.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 from helion._testing import TestCase
 from helion._testing import code_and_output
 from helion._testing import skipIfRefEager
+from helion._testing import skipIfRocm
 import helion.language as hl


@@ -82,6 +83,7 @@ def make_test_function(input_dtype, acc_dtype, static_shapes_option):
     """Create a test function for a specific combination of parameters."""
     combo = (input_dtype, input_dtype, acc_dtype)

+    @skipIfRocm("Core dumps with rocm -- https://github.com/pytorch/helion/issues/445")
     def test_impl(self):
         # Skip FP8 tests if GPU doesn't support it
         if (
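
Here the decorator wraps a plain function returned by make_test_function rather than a method defined directly on the test class; unittest's skip decorators only set attributes on the function object, so the skip still applies once the function is attached to the class. A minimal sketch of that pattern follows (the helper and test names are hypothetical, not the actual code in test_dot.py).

import unittest


def make_parametrized_test(value):
    # Decorate the generated test before it is bound to a TestCase class.
    @unittest.skipIf(value < 0, "hypothetical skip condition")
    def test_impl(self):
        self.assertGreaterEqual(value, 0)

    return test_impl


class TestGenerated(unittest.TestCase):
    pass


for i, v in enumerate([3, -1]):
    setattr(TestGenerated, f"test_generated_{i}", make_parametrized_test(v))

if __name__ == "__main__":
    unittest.main()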

test/test_examples.py

Lines changed: 6 additions & 0 deletions
@@ -12,6 +12,7 @@
 from helion._testing import check_example
 from helion._testing import import_path
 from helion._testing import skipIfRefEager
+from helion._testing import skipIfRocm

 torch.backends.cuda.matmul.fp32_precision = "tf32"
 torch.backends.cudnn.conv.fp32_precision = "tf32"
@@ -44,6 +45,7 @@ def test_matmul(self):
             )
         )

+    @skipIfRocm("failure on rocm")
     def test_matmul_layernorm_static_shapes(self):
         args = (
             torch.randn([128, 256], device=DEVICE, dtype=torch.float32),
@@ -66,6 +68,7 @@ def test_matmul_layernorm_static_shapes(self):
             )
         )

+    @skipIfRocm("failure on rocm")
     def test_matmul_layernorm_dynamic_shapes(self):
         args = (
             torch.randn([128, 256], device=DEVICE, dtype=torch.float32),
@@ -110,6 +113,7 @@ def test_bmm(self):
         not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9,
         "FP8 requires GPU with compute capability >= 9.0 (e.g., H100)",
     )
+    @skipIfRocm("failure on rocm")
     def test_fp8_gemm(self):
         # Create FP32 tensors and convert to FP8
         x = torch.randn([256, 256], device=DEVICE, dtype=torch.float32)
@@ -334,6 +338,7 @@ def test_embedding_block_ptr(self):
             )
         )

+    @skipIfRocm("failure on rocm")
     def test_attention_pointer(self):
         args = (
             torch.randn(1, 32, 512, 64, dtype=torch.float32, device=DEVICE),
@@ -568,6 +573,7 @@ def test_attention_persistent_interleaved_l2_grouping(self):
         not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9,
         "FP8 requires GPU with compute capability >= 9.0 (e.g., H100)",
     )
+    @skipIfRocm("failure on rocm")
     def test_fp8_attention(self):
         batch = 2
         heads = 4

test/test_indexing.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@
 from helion._testing import code_and_output
 from helion._testing import skipIfNormalMode
 from helion._testing import skipIfRefEager
+from helion._testing import skipIfRocm
 import helion.language as hl


@@ -626,6 +627,7 @@ def kernel(buf: torch.Tensor, zeros: torch.Tensor) -> torch.Tensor:
         expected = torch.zeros([N], device=DEVICE)
         torch.testing.assert_close(result, expected)

+    @skipIfRocm("failure on rocm")
     def test_1d_indexed_value_from_slice(self):
         """buf2[i] = buf[:] - Assign slice to indexed value"""

test/test_inline_asm_elementwise.py

Lines changed: 7 additions & 0 deletions
@@ -10,13 +10,15 @@
 from helion._testing import RefEagerTestDisabled
 from helion._testing import TestCase
 from helion._testing import code_and_output
+from helion._testing import skipIfRocm
 import helion.language as hl


 class TestInlineAsmElementwise(RefEagerTestDisabled, TestCase):
     @pytest.mark.skipif(
         DEVICE.type != "cuda", reason="inline_asm_elementwise is only supported on CUDA"
     )
+    @skipIfRocm("only works on cuda")
     def test_inline_asm_simple(self):
         """Test basic inline_asm_elementwise with simple assembly"""

@@ -45,6 +47,7 @@ def kernel_simple_asm(x: torch.Tensor) -> torch.Tensor:
     @pytest.mark.skipif(
         DEVICE.type != "cuda", reason="inline_asm_elementwise is only supported on CUDA"
     )
+    @skipIfRocm("only works on cuda")
     def test_inline_asm_shift_operation(self):
         """Test inline_asm_elementwise with shift operation (similar to Triton test)"""

@@ -82,6 +85,7 @@ def kernel_shift_asm(x: torch.Tensor, y: torch.Tensor, n: int) -> torch.Tensor:
     @pytest.mark.skipif(
         DEVICE.type != "cuda", reason="inline_asm_elementwise is only supported on CUDA"
     )
+    @skipIfRocm("only works on cuda")
     def test_inline_asm_multiple_outputs(self):
         """Test inline_asm_elementwise with multiple outputs"""

@@ -130,6 +134,7 @@ def kernel_multiple_outputs(
     @pytest.mark.skipif(
         DEVICE.type != "cuda", reason="inline_asm_elementwise is only supported on CUDA"
     )
+    @skipIfRocm("only works on cuda")
     def test_inline_asm_packed(self):
         """Test inline_asm_elementwise with pack > 1"""

@@ -186,6 +191,7 @@ def kernel_invalid_asm(x: torch.Tensor) -> torch.Tensor:
     @pytest.mark.skipif(
         DEVICE.type != "cuda", reason="inline_asm_elementwise is only supported on CUDA"
     )
+    @skipIfRocm("only works on cuda")
     def test_inline_asm_empty_args(self):
         """Test inline_asm_elementwise with empty args (should work like Triton)"""

@@ -214,6 +220,7 @@ def kernel_empty_args(x: torch.Tensor) -> torch.Tensor:
         expected = torch.full([16], 42, dtype=torch.int32, device=DEVICE)
         torch.testing.assert_close(result, expected)

+    @skipIfRocm("only works on cuda")
     def test_inline_asm_basic_compilation(self):
         """Test that inline_asm_elementwise compiles without errors (no CUDA requirement)"""

test/test_print.py

Lines changed: 9 additions & 0 deletions
@@ -13,6 +13,7 @@
 from helion._testing import RefEagerTestDisabled
 from helion._testing import TestCase
 from helion._testing import code_and_output
+from helion._testing import skipIfRocm
 import helion.language as hl


@@ -106,6 +107,7 @@ def run_test_with_and_without_triton_interpret_envvar(self, test_func):
         else:
             os.environ["TRITON_INTERPRET"] = original_env

+    @skipIfRocm("failure on rocm")
     def test_basic_print(self):
         """Test basic print with prefix and tensor values"""

@@ -142,6 +144,7 @@ def print_kernel(x: torch.Tensor) -> torch.Tensor:

         self.run_test_with_and_without_triton_interpret_envvar(run_test)

+    @skipIfRocm("failure on rocm")
     def test_print_multiple_tensors(self):
         """Test print with multiple tensor arguments"""

@@ -248,6 +251,7 @@ def print_shape_kernel(x: torch.Tensor) -> torch.Tensor:

         self.run_test_with_and_without_triton_interpret_envvar(run_test)

+    @skipIfRocm("failure on rocm")
     def test_print_prefix_only(self):
         def run_test(interpret_mode):
             @helion.kernel
@@ -280,6 +284,7 @@ def print_message_kernel(x: torch.Tensor) -> torch.Tensor:

         self.run_test_with_and_without_triton_interpret_envvar(run_test)

+    @skipIfRocm("failure on rocm")
     def test_print_in_nested_loops(self):
         def run_test(interpret_mode):
             @helion.kernel
@@ -372,6 +377,7 @@ def print_outside_kernel(x: torch.Tensor) -> torch.Tensor:

         self.run_test_with_and_without_triton_interpret_envvar(run_test)

+    @skipIfRocm("failure on rocm")
     def test_print_with_conditional(self):
         """Test print with conditional statements"""

@@ -431,6 +437,7 @@ def print_conditional_kernel(x: torch.Tensor) -> torch.Tensor:

         self.run_test_with_and_without_triton_interpret_envvar(run_test)

+    @skipIfRocm("failure on rocm")
     def test_print_computed_values(self):
         """Test print with computed/derived values"""

@@ -523,6 +530,7 @@ def print_reduction_kernel(x: torch.Tensor) -> torch.Tensor:

         self.run_test_with_and_without_triton_interpret_envvar(run_test)

+    @skipIfRocm("failure on rocm")
     def test_print_multiple_data_types(self):
         """Test print with different tensor data types"""

@@ -580,6 +588,7 @@ def print_dtypes_kernel(

         self.run_test_with_and_without_triton_interpret_envvar(run_test)

+    @skipIfRocm("failure on rocm")
     def test_print_with_starred_args(self):
         """Test print with starred/unpacked arguments"""
test/test_register_tunable.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from helion._testing import RefEagerTestBase
1212
from helion._testing import TestCase
1313
from helion._testing import code_and_output
14+
from helion._testing import skipIfRocm
1415
from helion.autotuner import EnumFragment
1516
from helion.autotuner import IntegerFragment
1617
from helion.autotuner import PowerOfTwoFragment
@@ -106,6 +107,7 @@ def fn(x: torch.Tensor):
106107
self.assertExpectedJournal(code)
107108
torch.testing.assert_close(result, x.sum())
108109

110+
@skipIfRocm("failure on rocm")
109111
def test_matmul_split_k(self):
110112
"""Test matmul_split_k kernel with register_tunable"""
111113

test/test_signal_wait.py

Lines changed: 14 additions & 0 deletions
@@ -9,10 +9,12 @@
 from helion._testing import RefEagerTestDisabled
 from helion._testing import TestCase
 from helion._testing import code_and_output
+from helion._testing import skipIfRocm
 import helion.language as hl


 class TestWait(RefEagerTestDisabled, TestCase):
+    @skipIfRocm("only works on cuda")
     def test_wait_basic(self):
         @helion.kernel
         def gmem_wait_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
@@ -32,6 +34,7 @@ def gmem_wait_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
         self.maxDiff = None
         self.assertExpectedJournal(code)

+    @skipIfRocm("only works on cuda")
     def test_wait_2d_tile(self):
         @helion.kernel
         def wait_for_2d_tile_kernel(
@@ -55,6 +58,7 @@ def wait_for_2d_tile_kernel(
         torch.testing.assert_close(result, x)
         self.assertExpectedJournal(code)

+    @skipIfRocm("only works on cuda")
     def test_wait_multi_bar(self):
         @helion.kernel
         def gmem_wait_multi_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
@@ -78,6 +82,7 @@ def gmem_wait_multi_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
         self.maxDiff = None
         self.assertExpectedJournal(code)

+    @skipIfRocm("only works on cuda")
     def test_wait_multi_bar_cas(self):
         @helion.kernel
         def gmem_wait_multi_bar_kernel_cas(signal_pad: torch.Tensor) -> torch.Tensor:
@@ -99,6 +104,7 @@ def gmem_wait_multi_bar_kernel_cas(signal_pad: torch.Tensor) -> torch.Tensor:
         self.maxDiff = None
         self.assertExpectedJournal(code)

+    @skipIfRocm("only works on cuda")
     def test_signal_basic(self):
         @helion.kernel
         def gmem_signal_scalar_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
@@ -114,6 +120,7 @@ def gmem_signal_scalar_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
         )
         self.assertExpectedJournal(code)

+    @skipIfRocm("only works on cuda")
     def test_signal_cas(self):
         @helion.kernel
         def gmem_signal_cas_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
@@ -129,6 +136,7 @@ def gmem_signal_cas_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
         )
         self.assertExpectedJournal(code)

+    @skipIfRocm("only works on cuda")
     def test_signal_multiple(self):
         @helion.kernel
         def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
@@ -148,6 +156,7 @@ def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
         )
         self.assertExpectedJournal(code)

+    @skipIfRocm("only works on cuda")
     def test_signal_multiple_cas(self):
         @helion.kernel
         def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
@@ -167,6 +176,7 @@ def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
         )
         self.assertExpectedJournal(code)

+    @skipIfRocm("only works on cuda")
     def test_send_recieve_cta(self):
         @helion.kernel
         def gmem_signal_n_wait_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
@@ -186,6 +196,7 @@ def gmem_signal_n_wait_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
         self.assertIn("helion.runtime.triton_send_signal", code)
         self.assertIn("helion.runtime.triton_wait_signal", code)

+    @skipIfRocm("only works on cuda")
     def test_global_sync(self):
         @helion.kernel
         def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
@@ -207,6 +218,7 @@ def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
         )
         self.assertExpectedJournal(code)

+    @skipIfRocm("only works on cuda")
     def test_global_sync_cas(self):
         @helion.kernel
         def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
@@ -232,6 +244,7 @@ def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
         )
         self.assertIn("atomic_cas", code)

+    @skipIfRocm("only works on cuda")
     def test_wait_stack_signalpad(self):
         @helion.kernel
         def gmem_wait_pointers_kernel(
@@ -259,6 +272,7 @@ def gmem_wait_pointers_kernel(
         )
         self.assertExpectedJournal(code)

+    @skipIfRocm("only works on cuda")
     def test_signal_stack_signalpad(self):
         @helion.kernel
         def gmem_signal_pointers_kernel(
