Commit d629bda

[Gluon] Add missing module attributes + verify IR in frontend tests (#7057)
1 parent 65a80e4 commit d629bda

6 files changed (+72 lines, -35 lines)

python/src/ir.cc

Lines changed: 4 additions & 0 deletions
@@ -756,6 +756,10 @@ void init_triton_ir(py::module &&m) {
            [](TritonOpBuilder &self, int32_t value) {
              return self.getBuilder().getI32IntegerAttr(value);
            })
+      .def("get_string_attr",
+           [](TritonOpBuilder &self, std::string value) -> Attribute {
+             return self.getBuilder().getStringAttr(value);
+           })
       // Use arith.ConstantOp to create constants
       // Constants
       .def("get_int1",

python/test/gluon/test_frontend.py

Lines changed: 45 additions & 33 deletions
@@ -1,6 +1,7 @@
 import expecttest
 import torch
 import pytest
+import re

 from triton import knobs
 from triton.experimental import gluon
@@ -13,6 +14,12 @@
 from triton.tools.tensor_descriptor import TensorDescriptor
 from triton.compiler.errors import CompilationError

+TARGET_PAT = re.compile('ttg.target = "[^"]*"')
+
+
+def anonymize_ir(ir):
+    return TARGET_PAT.sub('ttg.target = "..."', ir)
+

 @gluon.jit
 def convert_layout_kernel(XBLOCK: ttgl.constexpr, layout_a: ttgl.constexpr, layout_b: ttgl.constexpr):
@@ -28,10 +35,10 @@ def test_convert_layout(fresh_knobs):
         1, ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[1, 32], warps_per_cta=[1, 4], order=[1, 0]))
     h = convert_layout_kernel.warmup(128, layout_a, layout_b, num_warps=layout_a.warps_per_cta[0], grid=(1, ))
     expecttest.assert_expected_inline(
-        h.asm["source"], """\
+        anonymize_ir(h.asm["source"]), """\
 #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
-module attributes {"ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @convert_layout_kernel() attributes {noinline = false} {
     %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked> loc(#loc)
     %1 = ttg.convert_layout %0 : tensor<128xi32, #blocked> -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc)
@@ -41,8 +48,8 @@ def test_convert_layout(fresh_knobs):
 #loc = loc(unknown)
 """)
     expecttest.assert_expected_inline(
-        h.asm["ttgir"], """\
-module attributes {"ttg.num-warps" = 4 : i32} {
+        anonymize_ir(h.asm["ttgir"]), """\
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @convert_layout_kernel() attributes {noinline = false} {
     tt.return loc(#loc)
   } loc(#loc)
@@ -71,12 +78,12 @@ def test_shared_memory(fresh_knobs):
     h = shared_memory_kernel.warmup(8, 32, layout_a, layout_b, smem_layout, num_warps=layout_a.warps_per_cta[0],
                                     grid=(1, ))
     expecttest.assert_expected_inline(
-        h.asm["source"], """\
+        anonymize_ir(h.asm["source"]), """\
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
 #smem = #ttg.shared_memory
-module attributes {"ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @shared_memory_kernel() attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<8x32xi32, #shared, #smem, mutable> loc(#loc)
     %c0_i32 = arith.constant 0 : i32 loc(#loc)
@@ -118,10 +125,10 @@ def test_tensor_memory(fresh_knobs):
     tmem_layout = ttgl.nvidia.blackwell.TensorMemoryLayout(block=[128, 128], unpacked=True)
     h = tensor_memory_kernel.warmup(layout, tmem_layout, num_warps=4, grid=(1, ))
     expecttest.assert_expected_inline(
-        h.asm["source"], """\
+        anonymize_ir(h.asm["source"]), """\
 #blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
-module attributes {"ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @tensor_memory_kernel() attributes {noinline = false} {
     %c0_i32 = arith.constant 0 : i32 loc(#loc)
     %cst = arith.constant dense<0> : tensor<128x128xi32, #blocked> loc(#loc)
@@ -154,7 +161,7 @@ def test_tensor_memory(fresh_knobs):

 @gluon.jit
 def shared_memory_subview_kernel(XBLOCK: ttgl.constexpr, layout: ttgl.constexpr, smem_layout: ttgl.constexpr):
-    XHALF: tl.constexpr = XBLOCK // 2
+    XHALF: ttgl.constexpr = XBLOCK // 2
     smem = ttgl.allocate_shared_memory(ttgl.int32, [XBLOCK, XBLOCK], smem_layout)
     view = smem.split(XHALF, XHALF, dim=1)
     value = view.load(layout)
@@ -169,12 +176,12 @@ def test_shared_memory_subview(fresh_knobs):
     smem_layout = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0])
     h = shared_memory_subview_kernel.warmup(256, layout, smem_layout, num_warps=4, grid=(1, ))
     expecttest.assert_expected_inline(
-        h.asm["source"], """\
+        anonymize_ir(h.asm["source"]), """\
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #smem = #ttg.shared_memory
-module attributes {"ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @shared_memory_subview_kernel() attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<256x256xi32, #shared, #smem, mutable> loc(#loc)
     %c0_i32 = arith.constant 0 : i32 loc(#loc)
@@ -207,11 +214,11 @@ def test_shared_memory_subslice(fresh_knobs):
     smem_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=32, rank=2)
     h = shared_memory_subslice_kernel.warmup(256, layout, smem_layout, num_warps=4, grid=(1, ))
     expecttest.assert_expected_inline(
-        h.asm["source"], """\
+        anonymize_ir(h.asm["source"]), """\
 #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
 #smem = #ttg.shared_memory
-module attributes {"ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @shared_memory_subslice_kernel() attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<4x256xi32, #shared, #smem, mutable> loc(#loc)
     %c0_i32 = arith.constant 0 : i32 loc(#loc)
@@ -254,14 +261,14 @@ def shared_memory_cast_kernel():

 def test_shared_memory_cast(fresh_knobs):
     expecttest.assert_expected_inline(
-        run_parser(shared_memory_cast_kernel).str_nodebug(), """\
+        anonymize_ir(run_parser(shared_memory_cast_kernel).str_nodebug()), """\
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
 #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 8}>
 #shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16, CTAsPerCGA = [1, 1, 1, 1], CTASplitNum = [1, 1, 1, 1], CTAOrder = [3, 2, 1, 0]}>
 #shared3 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16}>
 #shared4 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
 #smem = #ttg.shared_memory
-module {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @shared_memory_cast_kernel() attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<256x128xi8, #shared, #smem, mutable>
     %1 = ttg.memdesc_trans %0 {order = array<i32: 1, 0>} : !ttg.memdesc<256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<128x256xi8, #shared1, #smem, mutable>
@@ -307,6 +314,7 @@ def anchor(x):
 @filecheck_test
 @gluon.jit
 def test_warp_specialize():
+    # CHECK: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
     # CHECK-LABEL: test_warp_specialize
     # CHECK-NEXT: [[A:%.*]] = tt.make_range {end = 1 : i32, start = 0 : i32}
     # CHECK-NEXT: [[B:%.*]] = tt.make_range {end = 2 : i32, start = 0 : i32}
@@ -316,19 +324,23 @@ def test_warp_specialize():
     # CHECK-NEXT: [[RESULTS:%.*]]:3 = tt.call @{{.*}}warp_specialize_default{{.*}}([[A]], [[B]], [[C]])
     # CHECK-NEXT: warp_yield [[RESULTS]]#0, [[RESULTS]]#1, [[RESULTS]]#2
     # CHECK-NEXT: }
-    # CHECK-NEXT: partition0(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>, %arg2: tensor<4xi32>) num_warps(4) {
+    # CHECK-NEXT: partition0(%arg0: tensor<1xi32, [[BLOCKED]]>, %arg1: tensor<2xi32, [[BLOCKED]]>, %arg2: tensor<4xi32, [[BLOCKED]]>) num_warps(4) {
     # CHECK-NEXT: call @{{.*}}warp_specialize_worker0{{.*}}(%arg0, %arg1, %arg2)
     # CHECK-NEXT: warp_return
     # CHECK-NEXT: }
-    # CHECK-NEXT: partition1(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>, %arg2: tensor<4xi32>) num_warps(4) {
+    # CHECK-NEXT: partition1(%arg0: tensor<1xi32, [[BLOCKED]]>, %arg1: tensor<2xi32, [[BLOCKED]]>, %arg2: tensor<4xi32, [[BLOCKED]]>) num_warps(4) {
     # CHECK-NEXT: call @{{.*}}warp_specialize_worker1{{.*}}(%arg0, %arg1, %arg2)
     # CHECK-NEXT: warp_return
     # CHECK-NEXT: }
     # CHECK-NEXT: call @{{.*}}anchor{{.*}}([[OUTS]]#0)
     # CHECK-NEXT: call @{{.*}}anchor{{.*}}([[OUTS]]#1, [[OUTS]]#2)
-    pair = Pair(tl.arange(0, 1), tl.arange(0, 2))
-    a, b = ttgl.warp_specialize((pair, tl.arange(0, 4)), warp_specialize_default,
-                                [warp_specialize_worker0, warp_specialize_worker1], [4, 4], [24, 48])
+    layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [4], [0])
+    a = ttgl.arange(0, 1, layout=layout)
+    b = ttgl.arange(0, 2, layout=layout)
+    c = ttgl.arange(0, 4, layout=layout)
+    pair = Pair(a, b)
+    a, b = ttgl.warp_specialize((pair, c), warp_specialize_default, [warp_specialize_worker0, warp_specialize_worker1],
+                                [4, 4], [24, 48])
     anchor(a)
     anchor(b)

@@ -350,10 +362,10 @@ def test_mbarrier(fresh_knobs):

     h = mbarrier_kernel.warmup(grid=(1, ))
     expecttest.assert_expected_inline(
-        h.asm["source"], """\
+        anonymize_ir(h.asm["source"]), """\
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
-module attributes {"ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @mbarrier_kernel() attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared, #smem, mutable> loc(#loc)
     ttng.init_barrier %0, 1 : !ttg.memdesc<1xi64, #shared, #smem, mutable> loc(#loc)
@@ -390,11 +402,11 @@ def test_tcgen05_mma(fresh_knobs):

     h = tcgen05_mma_kernel.warmup(nvmma_layout, acc_layout, grid=(1, ))
     expecttest.assert_expected_inline(
-        h.asm["source"], """\
+        anonymize_ir(h.asm["source"]), """\
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #smem = #ttg.shared_memory
 #tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
-module attributes {"ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @tcgen05_mma_kernel() attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
     %1 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
@@ -436,12 +448,12 @@ def test_async_tma(fresh_knobs):

     h = async_tma_kernel.warmup(input_desc, XBLOCK, shared_layout, grid=(1, ), num_warps=4)
     expecttest.assert_expected_inline(
-        h.asm["source"], """\
+        anonymize_ir(h.asm["source"]), """\
 #loc = loc(unknown)
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
-module attributes {"ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @async_tma_kernel(%arg0: !tt.tensordesc<tensor<128x128xf16>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
     %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
@@ -472,7 +484,7 @@ def async_tma_blackwell_kernel(input_desc, XBLOCK: ttgl.constexpr, smem_layout:
     bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
     mbarrier.init(bar, count=1)

-    offset_layout: tl.constexpr = ttgl.BlockedLayout([1, 4], [32, 1], [1, 4], [1, 0])
+    offset_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 4], [32, 1], [1, 4], [1, 0])
     x_offsets = ttgl.arange(0, XBLOCK, layout=ttgl.SliceLayout(0, offset_layout))
     tma.async_gather(input_desc, x_offsets, 0, bar, smem)
     mbarrier.expect(bar, XBLOCK * XBLOCK * ttgl.float16.primitive_bitwidth // 8)
@@ -495,13 +507,13 @@ def test_async_tma_blackwell(fresh_knobs):

     h = async_tma_blackwell_kernel.warmup(input_desc, XBLOCK, shared_layout, grid=(1, ), num_warps=4)
     expecttest.assert_expected_inline(
-        h.asm["source"], """\
+        anonymize_ir(h.asm["source"]), """\
 #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
 #loc = loc(unknown)
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
-module attributes {"ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @async_tma_blackwell_kernel(%arg0: !tt.tensordesc<tensor<1x128xf16>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
     %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
@@ -546,9 +558,9 @@ def tmem_subslice_kernel():

 def test_tmem_subslice_constexpr():
     expecttest.assert_expected_inline(
-        run_parser(tmem_subslice_kernel).str_nodebug(), """\
+        anonymize_ir(run_parser(tmem_subslice_kernel).str_nodebug()), """\
 #tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
-module {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @tmem_subslice_kernel() attributes {noinline = false} {
     %result = ttng.tmem_alloc : () -> !ttg.memdesc<2x256x256xi32, #tmem, #ttng.tensor_memory, mutable>
     %c0_i32 = arith.constant 0 : i32
@@ -574,10 +586,10 @@ def kernel():
     smem_and_layout_user(smem, a)

     expecttest.assert_expected_inline(
-        run_parser(kernel).str_nodebug(), """\
+        anonymize_ir(run_parser(kernel).str_nodebug()), """\
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #smem = #ttg.shared_memory
-module {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @kernel() attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<32x32xi32, #shared, #smem, mutable>
     tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()

python/triton/_filecheck.py

Lines changed: 6 additions & 2 deletions
@@ -11,6 +11,7 @@
 import triton
 from triton.compiler import ASTSource, make_backend
 from triton.backends.compiler import GPUTarget
+from triton.experimental.gluon._runtime import GluonASTSource
 from triton._C.libtriton import ir

 # ===-----------------------------------------------------------------------===#
@@ -50,7 +51,8 @@ def run_parser(kernel_fn):
     sigkeys = [x.name for x in kernel_fn.params]
     sigvals = [f"arg{i}" for i in range(len(sigkeys))]
     signature = {k: v for (k, v) in zip(sigkeys, sigvals)}
-    src = ASTSource(fn=kernel_fn, signature=signature)
+    source_cls = GluonASTSource if kernel_fn.is_gluon() else ASTSource
+    src = source_cls(fn=kernel_fn, signature=signature)

     context = ir.context()
     ir.load_dialects(context)
@@ -60,7 +62,9 @@ def run_parser(kernel_fn):
     options = stub_backend.parse_options(dict(**extra_options))
     codegen_fns = stub_backend.get_codegen_implementation(options)
     module_map = stub_backend.get_module_map()
-    return src.make_ir(options, codegen_fns, module_map, context)
+    module = src.make_ir(options, codegen_fns, module_map, context)
+    assert module.verify()
+    return module


 def run_filecheck_test(kernel_fn):
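
Taken on its own, the run_parser change is a class dispatch plus a verification step; a minimal sketch (assuming kernel_fn is a JIT-decorated function exposing the is_gluon() predicate used above and signature is already built):

    from triton.compiler import ASTSource
    from triton.experimental.gluon._runtime import GluonASTSource

    def pick_source(kernel_fn, signature):
        # Gluon kernels go through GluonASTSource so make_ir stamps the ttg.* module
        # attributes; regular Triton kernels keep using plain ASTSource.
        source_cls = GluonASTSource if kernel_fn.is_gluon() else ASTSource
        return source_cls(fn=kernel_fn, signature=signature)

Verifying right after parsing (assert module.verify()) turns malformed frontend IR into an immediate test failure instead of an error further down the pipeline.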

python/triton/experimental/gluon/_runtime.py

Lines changed: 10 additions & 0 deletions
@@ -1,3 +1,4 @@
+import triton
 from triton.compiler.code_generator import ast_to_ttir
 from triton.compiler.compiler import ASTSource
 from triton.backends.compiler import Language
@@ -16,10 +17,19 @@ def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
         self.ext = "ttgir"

     def make_ir(self, options, codegen_fns, module_map, context):
+        from triton.compiler.compiler import make_backend
         module = ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
                              module_map=module_map)
         builder = ir.builder(context)
+        target = triton.runtime.driver.active.get_current_target()
+        backend = make_backend(target)
+        target = backend.get_target_name(options)
+        module.set_attr("ttg.target", builder.get_string_attr(target))
         module.set_attr("ttg.num-warps", builder.get_int32_attr(options.num_warps))
+        module.set_attr("ttg.num-ctas", builder.get_int32_attr(options.num_ctas))
+        module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(32))
+        if options.maxnreg is not None:
+            module.set_attr("ttg.maxnreg", builder.get_int32_attr(options.maxnreg))
         return module

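
For orientation, these are the attributes a Gluon module carries after make_ir, with illustrative values for a hypothetical CUDA capability-90 target and default options (num_warps=4, num_ctas=1, maxnreg unset); the frontend tests above anonymize the ttg.target value:

    # Illustrative only: the attribute names are the ones set in make_ir above,
    # the concrete values depend on the active GPU and the compile options.
    gluon_module_attrs = {
        "ttg.target": "cuda:90",       # backend.get_target_name(options)
        "ttg.num-warps": 4,            # options.num_warps
        "ttg.num-ctas": 1,             # options.num_ctas
        "ttg.threads-per-warp": 32,    # fixed to 32 here
        # "ttg.maxnreg": ...           # only set when options.maxnreg is not None
    }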

third_party/amd/backend/compiler.py

Lines changed: 3 additions & 0 deletions
@@ -98,6 +98,9 @@ def __init__(self, target: GPUTarget) -> None:
         assert isinstance(target.arch, str)
         self.binary_ext = "hsaco"

+    def get_target_name(self, options) -> str:
+        return f"hip:{options.arch}"
+
     def parse_options(self, opts) -> Any:
         args = {'arch': knobs.runtime.override_arch or self.target.arch}

third_party/nvidia/backend/compiler.py

Lines changed: 4 additions & 0 deletions
@@ -151,6 +151,10 @@ def _parse_arch(self, arch):
             raise ValueError(f"TRITON_OVERRIDE_ARCH must have the form {pattern}")
         return int(match.group(1))

+    def get_target_name(self, options) -> str:
+        capability = self._parse_arch(options.arch)
+        return f"cuda:{capability}"
+
     def __init__(self, target: GPUTarget) -> None:
         super().__init__(target)
         self.binary_ext = "cubin"
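
To make the format concrete, here is a small standalone sketch of the strings the two new get_target_name hooks produce (the sm_90 and gfx942 architectures are illustrative, not taken from this commit):

    def nvidia_target_name(capability: int) -> str:
        # Mirrors the NVIDIA hook once options.arch has been parsed into a
        # compute capability (e.g. sm_90 -> 90).
        return f"cuda:{capability}"

    def amd_target_name(arch: str) -> str:
        # Mirrors the AMD hook; options.arch is already the gfx string.
        return f"hip:{arch}"

    print(nvidia_target_name(90))     # cuda:90
    print(amd_target_name("gfx942"))  # hip:gfx942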
