import expecttest
import torch
import pytest
+import re

from triton import knobs
from triton.experimental import gluon
from triton.tools.tensor_descriptor import TensorDescriptor
from triton.compiler.errors import CompilationError

+TARGET_PAT = re.compile('ttg.target = "[^"]*"')
+
+
+def anonymize_ir(ir):
+    return TARGET_PAT.sub('ttg.target = "..."', ir)
+

@gluon.jit
def convert_layout_kernel(XBLOCK: ttgl.constexpr, layout_a: ttgl.constexpr, layout_b: ttgl.constexpr):
@@ -28,10 +35,10 @@ def test_convert_layout(fresh_knobs):
        1, ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[1, 32], warps_per_cta=[1, 4], order=[1, 0]))
    h = convert_layout_kernel.warmup(128, layout_a, layout_b, num_warps=layout_a.warps_per_cta[0], grid=(1, ))
    expecttest.assert_expected_inline(
-        h.asm["source"], """\
+        anonymize_ir(h.asm["source"]), """\
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
-module attributes {"ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @convert_layout_kernel() attributes {noinline = false} {
    %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked> loc(#loc)
    %1 = ttg.convert_layout %0 : tensor<128xi32, #blocked> -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc)
@@ -41,8 +48,8 @@ def test_convert_layout(fresh_knobs):
#loc = loc(unknown)
""")
    expecttest.assert_expected_inline(
-        h.asm["ttgir"], """\
-module attributes {"ttg.num-warps" = 4 : i32} {
+        anonymize_ir(h.asm["ttgir"]), """\
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @convert_layout_kernel() attributes {noinline = false} {
    tt.return loc(#loc)
  } loc(#loc)
@@ -71,12 +78,12 @@ def test_shared_memory(fresh_knobs):
    h = shared_memory_kernel.warmup(8, 32, layout_a, layout_b, smem_layout, num_warps=layout_a.warps_per_cta[0],
                                    grid=(1, ))
    expecttest.assert_expected_inline(
-        h.asm["source"], """\
+        anonymize_ir(h.asm["source"]), """\
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
#smem = #ttg.shared_memory
-module attributes {"ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @shared_memory_kernel() attributes {noinline = false} {
    %0 = ttg.local_alloc : () -> !ttg.memdesc<8x32xi32, #shared, #smem, mutable> loc(#loc)
    %c0_i32 = arith.constant 0 : i32 loc(#loc)
@@ -118,10 +125,10 @@ def test_tensor_memory(fresh_knobs):
    tmem_layout = ttgl.nvidia.blackwell.TensorMemoryLayout(block=[128, 128], unpacked=True)
    h = tensor_memory_kernel.warmup(layout, tmem_layout, num_warps=4, grid=(1, ))
    expecttest.assert_expected_inline(
-        h.asm["source"], """\
+        anonymize_ir(h.asm["source"]), """\
#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
-module attributes {"ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @tensor_memory_kernel() attributes {noinline = false} {
    %c0_i32 = arith.constant 0 : i32 loc(#loc)
    %cst = arith.constant dense<0> : tensor<128x128xi32, #blocked> loc(#loc)
@@ -154,7 +161,7 @@ def test_tensor_memory(fresh_knobs):

@gluon.jit
def shared_memory_subview_kernel(XBLOCK: ttgl.constexpr, layout: ttgl.constexpr, smem_layout: ttgl.constexpr):
-    XHALF: tl.constexpr = XBLOCK // 2
+    XHALF: ttgl.constexpr = XBLOCK // 2
    smem = ttgl.allocate_shared_memory(ttgl.int32, [XBLOCK, XBLOCK], smem_layout)
    view = smem.split(XHALF, XHALF, dim=1)
    value = view.load(layout)
@@ -169,12 +176,12 @@ def test_shared_memory_subview(fresh_knobs):
    smem_layout = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0])
    h = shared_memory_subview_kernel.warmup(256, layout, smem_layout, num_warps=4, grid=(1, ))
    expecttest.assert_expected_inline(
-        h.asm["source"], """\
+        anonymize_ir(h.asm["source"]), """\
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
-module attributes {"ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @shared_memory_subview_kernel() attributes {noinline = false} {
    %0 = ttg.local_alloc : () -> !ttg.memdesc<256x256xi32, #shared, #smem, mutable> loc(#loc)
    %c0_i32 = arith.constant 0 : i32 loc(#loc)
@@ -207,11 +214,11 @@ def test_shared_memory_subslice(fresh_knobs):
    smem_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=32, rank=2)
    h = shared_memory_subslice_kernel.warmup(256, layout, smem_layout, num_warps=4, grid=(1, ))
    expecttest.assert_expected_inline(
-        h.asm["source"], """\
+        anonymize_ir(h.asm["source"]), """\
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
#smem = #ttg.shared_memory
-module attributes {"ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @shared_memory_subslice_kernel() attributes {noinline = false} {
    %0 = ttg.local_alloc : () -> !ttg.memdesc<4x256xi32, #shared, #smem, mutable> loc(#loc)
    %c0_i32 = arith.constant 0 : i32 loc(#loc)
@@ -254,14 +261,14 @@ def shared_memory_cast_kernel():

def test_shared_memory_cast(fresh_knobs):
    expecttest.assert_expected_inline(
-        run_parser(shared_memory_cast_kernel).str_nodebug(), """\
+        anonymize_ir(run_parser(shared_memory_cast_kernel).str_nodebug()), """\
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 8}>
#shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16, CTAsPerCGA = [1, 1, 1, 1], CTASplitNum = [1, 1, 1, 1], CTAOrder = [3, 2, 1, 0]}>
#shared3 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16}>
#shared4 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
#smem = #ttg.shared_memory
-module {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @shared_memory_cast_kernel() attributes {noinline = false} {
    %0 = ttg.local_alloc : () -> !ttg.memdesc<256x128xi8, #shared, #smem, mutable>
    %1 = ttg.memdesc_trans %0 {order = array<i32: 1, 0>} : !ttg.memdesc<256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<128x256xi8, #shared1, #smem, mutable>
@@ -307,6 +314,7 @@ def anchor(x):
@filecheck_test
@gluon.jit
def test_warp_specialize():
+    # CHECK: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
    # CHECK-LABEL: test_warp_specialize
    # CHECK-NEXT: [[A:%.*]] = tt.make_range {end = 1 : i32, start = 0 : i32}
    # CHECK-NEXT: [[B:%.*]] = tt.make_range {end = 2 : i32, start = 0 : i32}
@@ -316,19 +324,23 @@ def test_warp_specialize():
    # CHECK-NEXT: [[RESULTS:%.*]]:3 = tt.call @{{.*}}warp_specialize_default{{.*}}([[A]], [[B]], [[C]])
    # CHECK-NEXT: warp_yield [[RESULTS]]#0, [[RESULTS]]#1, [[RESULTS]]#2
    # CHECK-NEXT: }
-    # CHECK-NEXT: partition0(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>, %arg2: tensor<4xi32>) num_warps(4) {
+    # CHECK-NEXT: partition0(%arg0: tensor<1xi32, [[BLOCKED]]>, %arg1: tensor<2xi32, [[BLOCKED]]>, %arg2: tensor<4xi32, [[BLOCKED]]>) num_warps(4) {
    # CHECK-NEXT: call @{{.*}}warp_specialize_worker0{{.*}}(%arg0, %arg1, %arg2)
    # CHECK-NEXT: warp_return
    # CHECK-NEXT: }
-    # CHECK-NEXT: partition1(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>, %arg2: tensor<4xi32>) num_warps(4) {
+    # CHECK-NEXT: partition1(%arg0: tensor<1xi32, [[BLOCKED]]>, %arg1: tensor<2xi32, [[BLOCKED]]>, %arg2: tensor<4xi32, [[BLOCKED]]>) num_warps(4) {
    # CHECK-NEXT: call @{{.*}}warp_specialize_worker1{{.*}}(%arg0, %arg1, %arg2)
    # CHECK-NEXT: warp_return
    # CHECK-NEXT: }
    # CHECK-NEXT: call @{{.*}}anchor{{.*}}([[OUTS]]#0)
    # CHECK-NEXT: call @{{.*}}anchor{{.*}}([[OUTS]]#1, [[OUTS]]#2)
-    pair = Pair(tl.arange(0, 1), tl.arange(0, 2))
-    a, b = ttgl.warp_specialize((pair, tl.arange(0, 4)), warp_specialize_default,
-                                [warp_specialize_worker0, warp_specialize_worker1], [4, 4], [24, 48])
+    layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [4], [0])
+    a = ttgl.arange(0, 1, layout=layout)
+    b = ttgl.arange(0, 2, layout=layout)
+    c = ttgl.arange(0, 4, layout=layout)
+    pair = Pair(a, b)
+    a, b = ttgl.warp_specialize((pair, c), warp_specialize_default, [warp_specialize_worker0, warp_specialize_worker1],
+                                [4, 4], [24, 48])
    anchor(a)
    anchor(b)
@@ -350,10 +362,10 @@ def test_mbarrier(fresh_knobs):

    h = mbarrier_kernel.warmup(grid=(1, ))
    expecttest.assert_expected_inline(
-        h.asm["source"], """\
+        anonymize_ir(h.asm["source"]), """\
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
#smem = #ttg.shared_memory
-module attributes {"ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @mbarrier_kernel() attributes {noinline = false} {
    %0 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared, #smem, mutable> loc(#loc)
    ttng.init_barrier %0, 1 : !ttg.memdesc<1xi64, #shared, #smem, mutable> loc(#loc)
@@ -390,11 +402,11 @@ def test_tcgen05_mma(fresh_knobs):

    h = tcgen05_mma_kernel.warmup(nvmma_layout, acc_layout, grid=(1, ))
    expecttest.assert_expected_inline(
-        h.asm["source"], """\
+        anonymize_ir(h.asm["source"]), """\
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#smem = #ttg.shared_memory
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
-module attributes {"ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @tcgen05_mma_kernel() attributes {noinline = false} {
    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
    %1 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
@@ -436,12 +448,12 @@ def test_async_tma(fresh_knobs):

    h = async_tma_kernel.warmup(input_desc, XBLOCK, shared_layout, grid=(1, ), num_warps=4)
    expecttest.assert_expected_inline(
-        h.asm["source"], """\
+        anonymize_ir(h.asm["source"]), """\
#loc = loc(unknown)
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
#smem = #ttg.shared_memory
-module attributes {"ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @async_tma_kernel(%arg0: !tt.tensordesc<tensor<128x128xf16>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
    %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
@@ -472,7 +484,7 @@ def async_tma_blackwell_kernel(input_desc, XBLOCK: ttgl.constexpr, smem_layout:
    bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(bar, count=1)

-    offset_layout: tl.constexpr = ttgl.BlockedLayout([1, 4], [32, 1], [1, 4], [1, 0])
+    offset_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 4], [32, 1], [1, 4], [1, 0])
    x_offsets = ttgl.arange(0, XBLOCK, layout=ttgl.SliceLayout(0, offset_layout))
    tma.async_gather(input_desc, x_offsets, 0, bar, smem)
    mbarrier.expect(bar, XBLOCK * XBLOCK * ttgl.float16.primitive_bitwidth // 8)
@@ -495,13 +507,13 @@ def test_async_tma_blackwell(fresh_knobs):

    h = async_tma_blackwell_kernel.warmup(input_desc, XBLOCK, shared_layout, grid=(1, ), num_warps=4)
    expecttest.assert_expected_inline(
-        h.asm["source"], """\
+        anonymize_ir(h.asm["source"]), """\
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
#loc = loc(unknown)
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
#smem = #ttg.shared_memory
-module attributes {"ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @async_tma_blackwell_kernel(%arg0: !tt.tensordesc<tensor<1x128xf16>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
    %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
@@ -546,9 +558,9 @@ def tmem_subslice_kernel():

def test_tmem_subslice_constexpr():
    expecttest.assert_expected_inline(
-        run_parser(tmem_subslice_kernel).str_nodebug(), """\
+        anonymize_ir(run_parser(tmem_subslice_kernel).str_nodebug()), """\
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
-module {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @tmem_subslice_kernel() attributes {noinline = false} {
    %result = ttng.tmem_alloc : () -> !ttg.memdesc<2x256x256xi32, #tmem, #ttng.tensor_memory, mutable>
    %c0_i32 = arith.constant 0 : i32
@@ -574,10 +586,10 @@ def kernel():
        smem_and_layout_user(smem, a)

    expecttest.assert_expected_inline(
-        run_parser(kernel).str_nodebug(), """\
+        anonymize_ir(run_parser(kernel).str_nodebug()), """\
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
-module {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @kernel() attributes {noinline = false} {
    %0 = ttg.local_alloc : () -> !ttg.memdesc<32x32xi32, #shared, #smem, mutable>
    tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
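
Aside from the mechanical baseline updates, the substantive addition in this diff is the anonymize_ir helper. Below is a minimal standalone sketch of how it behaves; the sample module line and the "cuda:90" target value are invented for illustration, and only the pattern and substitution are taken from the diff.

import re

# Same pattern and substitution as the helper added in this diff.
TARGET_PAT = re.compile('ttg.target = "[^"]*"')

def anonymize_ir(ir):
    # Replace the hardware-specific target string so the expecttest
    # baselines do not depend on which GPU generated the IR dump.
    return TARGET_PAT.sub('ttg.target = "..."', ir)

# Hypothetical input; "cuda:90" is a placeholder target value.
sample = 'module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:90"} {'
print(anonymize_ir(sample))
# -> module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "..."} {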