
Commit aaa3d82

[gluon] fix some AMD compilation issues + skip tests on AMD for now (#7215)
Fixes some minor AMD compilation issues. Some tests in test_frontend hardcode tpw=32 (and some use NVIDIA-specific layouts), so they are skipped on AMD for now; where possible, they should be re-enabled for AMD in the future.
1 parent c8a711d commit aaa3d82
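
For context, the skip pattern the commit applies in the test files looks roughly like the sketch below. The is_cuda() helper shown here is a hypothetical stand-in modeled on Triton's test utilities (the real tests import their own), and the warp widths reflect NVIDIA's 32-thread warps versus AMD CDNA's 64-thread wavefronts.

import pytest
import triton


def is_cuda():
    # Hypothetical helper: the active driver's target reports the backend
    # name ("cuda" on NVIDIA, "hip" on AMD).
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


# Layouts that hardcode a 32-wide warp only make sense on CUDA;
# AMD wavefronts are 64 threads wide.
copy_kernel_tpw = [32] if is_cuda() else [64]


@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
def test_cuda_only_example():
    # Placeholder body standing in for the real Gluon frontend tests.
    pass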

File tree

3 files changed (+31, -10 lines)

python/test/gluon/test_core.py

Lines changed: 11 additions & 8 deletions
@@ -17,15 +17,18 @@ def copy_kernel(Out, In, numel, XBLOCK: ttgl.constexpr, layout: ttgl.constexpr):
     ttgl.store(Out + xoffset, data, xmask)
 
 
+copy_kernel_tpw = [32] if is_cuda() else [64]
+
+
 @pytest.mark.parametrize("layout", [
-    ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[2], threads_per_warp=[32], warps_per_cta=[4], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[4], threads_per_warp=[32], warps_per_cta=[4], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[8], threads_per_warp=[32], warps_per_cta=[4], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[8], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[2], threads_per_warp=[32], warps_per_cta=[8], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[4], threads_per_warp=[32], warps_per_cta=[8], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[8], threads_per_warp=[32], warps_per_cta=[8], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=copy_kernel_tpw, warps_per_cta=[4], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[2], threads_per_warp=copy_kernel_tpw, warps_per_cta=[4], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[4], threads_per_warp=copy_kernel_tpw, warps_per_cta=[4], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[8], threads_per_warp=copy_kernel_tpw, warps_per_cta=[4], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=copy_kernel_tpw, warps_per_cta=[8], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[2], threads_per_warp=copy_kernel_tpw, warps_per_cta=[8], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[4], threads_per_warp=copy_kernel_tpw, warps_per_cta=[8], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[8], threads_per_warp=copy_kernel_tpw, warps_per_cta=[8], order=[0]),
 ])
 @pytest.mark.parametrize("XBLOCK", [128, 256, 512, 1024, 2048])
 def test_copy_kernel(layout, XBLOCK):
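
For reference, with copy_kernel_tpw == [64] on an AMD target, the first entry of the parametrize list above expands to a 64-thread-wide blocked layout. A sketch, assuming the gluon language module is importable as ttgl exactly as in the test file:

from triton.experimental.gluon import language as ttgl

# The blocked layout the first parametrize case produces on AMD:
# one element per thread, a 64-thread wavefront, four warps per CTA.
amd_layout = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[64],
                                warps_per_cta=[4], order=[0])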

python/test/gluon/test_frontend.py

Lines changed: 10 additions & 0 deletions
@@ -28,6 +28,7 @@ def convert_layout_kernel(XBLOCK: ttgl.constexpr, layout_a: ttgl.constexpr, layo
     res = ttgl.convert_layout(x, layout_b)  # noqa: F841
 
 
+@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
 def test_convert_layout(fresh_knobs):
     knobs.compilation.disable_line_info = True

@@ -70,6 +71,7 @@ def shared_memory_kernel(XBLOCK: ttgl.constexpr, YBLOCK: ttgl.constexpr, layout_
     unused._keep_alive()
 
 
+@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
 def test_shared_memory(fresh_knobs):
     knobs.compilation.disable_line_info = True

@@ -170,6 +172,7 @@ def shared_memory_subview_kernel(XBLOCK: ttgl.constexpr, layout: ttgl.constexpr,
     view.store(value.trans())
 
 
+@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
 def test_shared_memory_subview(fresh_knobs):
     knobs.compilation.disable_line_info = True

@@ -208,6 +211,7 @@ def shared_memory_index_kernel(XBLOCK: ttgl.constexpr, layout: ttgl.constexpr, s
     smem.index(i).load(layout)
 
 
+@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
 def test_shared_memory_index(fresh_knobs):
     knobs.compilation.disable_line_info = True

@@ -263,6 +267,7 @@ def shared_memory_cast_kernel():
     smem._reinterpret(ttgl.int8, [1024], ttgl.SwizzledSharedLayout(1, 1, 1, [0, 1]))
 
 
+@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
 def test_shared_memory_cast(fresh_knobs):
     expecttest.assert_expected_inline(
         anonymize_ir(run_parser(shared_memory_cast_kernel).str_nodebug()), """\

@@ -630,6 +635,7 @@ def broadcast_kernel():
     0 + a + b
 
 
+@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
 def test_broadcast(fresh_knobs):
     knobs.compilation.disable_line_info = True

@@ -684,6 +690,7 @@ def math_kernel():
     ttgl.fma(a, b, c)
 
 
+@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
 def test_math(fresh_knobs):
     knobs.compilation.disable_line_info = True

@@ -754,6 +761,7 @@ def reduce_kernel(out):
     tl.store(out + ttgl.arange(0, 16, s0.type.layout), result)
 
 
+@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
 def test_reduce(fresh_knobs):
     knobs.compilation.disable_line_info = True

@@ -802,6 +810,7 @@ def test_reduce(fresh_knobs):
 """)
 
 
+@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
 @filecheck_test
 @gluon.jit
 def test_elementwise_core():

@@ -829,6 +838,7 @@ def linear_layout_kernel():
     ttgl.arange(0, 256, layout=ll)
 
 
+@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
 def test_linear_layout(fresh_knobs):
     knobs.compilation.disable_line_info = True
     h = linear_layout_kernel.warmup(grid=(1, ))

python/triton/experimental/gluon/_runtime.py

Lines changed: 10 additions & 2 deletions
@@ -27,11 +27,19 @@ def make_ir(self, options, codegen_fns, module_map, context):
         target = triton.runtime.driver.active.get_current_target()
         backend = make_backend(target)
         target = backend.get_target_name(options)
+
         module.set_attr("ttg.target", builder.get_string_attr(target))
         module.set_attr("ttg.num-warps", builder.get_int32_attr(options.num_warps))
         module.set_attr("ttg.num-ctas", builder.get_int32_attr(options.num_ctas))
-        module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(32))
-        if options.maxnreg is not None:
+
+        is_cuda = options.backend_name == "cuda"
+
+        if is_cuda:
+            module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(32))
+        else:
+            module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(64))
+
+        if is_cuda and options.maxnreg is not None:
             module.set_attr("ttg.maxnreg", builder.get_int32_attr(options.maxnreg))
 
         module = ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
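
Rather than branching on options.backend_name, a hypothetical variant could read the warp width from the active target, since Triton's GPUTarget carries a warp_size field (32 on NVIDIA, 64 on AMD CDNA). A sketch:

import triton


def threads_per_warp() -> int:
    # Hypothetical alternative: derive the width from the target instead of
    # the backend name; warp_size is 32 for "cuda" targets and 64 for "hip".
    return triton.runtime.driver.active.get_current_target().warp_size

The committed version branches explicitly, which keeps the CUDA path unchanged and also restricts ttg.maxnreg to CUDA, since maxnreg is an NVIDIA-specific register-count option.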
