@@ -165,8 +165,8 @@ def tensor_memory_kernel(layout: ttgl.constexpr, tmem_layout: ttgl.constexpr):
165
165
slice2 = mem .slice (YBLOCK // 2 , YBLOCK // 2 ) # noqa: F841
166
166
167
167
buffers = ttgl .nvidia .blackwell .allocate_tensor_memory (ttgl .float32 , [2 , XBLOCK , YBLOCK ], tmem_layout )
168
- for i in range (2 ):
169
- buffers .index (i ).load (layout )
168
+ for ivar in range (2 ):
169
+ buffers .index (ivar ).load (layout )
170
170
171
171
172
172
@pytest .mark .skipif (not is_blackwell (), reason = "Requires blackwell tensor cores" )
@@ -200,9 +200,9 @@ def test_tensor_memory(fresh_knobs):
200
200
%3 = arith.bitcast %c2_i32 : i32 to i32 loc(#loc)
201
201
%4 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc)
202
202
%5 = ub.poison : i32 loc(#loc)
203
- scf.for %arg0 = %2 to %3 step %4 : i32 {
203
+ scf.for %ivar = %2 to %3 step %4 : i32 {
204
204
%c0_i32_4 = arith.constant 0 : i32 loc(#loc)
205
- %6 = ttg.memdesc_subview %result_2[%arg0 , %c0_i32_4, %c0_i32_4] : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> loc(#loc)
205
+ %6 = ttg.memdesc_subview %result_2[%ivar , %c0_i32_4, %c0_i32_4] : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> loc(#loc)
206
206
%result_5 = ttng.tmem_load %6 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> -> tensor<128x128xf32, #blocked> loc(#loc)
207
207
} loc(#loc)
208
208
tt.return loc(#loc)
@@ -257,8 +257,8 @@ def test_shared_memory_subview(fresh_knobs):
257
257
@gluon .jit
258
258
def shared_memory_index_kernel (XBLOCK : ttgl .constexpr , layout : ttgl .constexpr , smem_layout : ttgl .constexpr ):
259
259
smem = ttgl .allocate_shared_memory (ttgl .int32 , [4 , XBLOCK ], smem_layout )
260
- for i in range (4 ):
261
- smem .index (i ).load (layout )
260
+ for ivar in range (4 ):
261
+ smem .index (ivar ).load (layout )
262
262
263
263
264
264
@pytest .mark .skipif (not is_cuda (), reason = "Requires CUDA" )
@@ -283,9 +283,9 @@ def test_shared_memory_index(fresh_knobs):
283
283
%2 = arith.bitcast %c4_i32 : i32 to i32 loc(#loc)
284
284
%3 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc)
285
285
%4 = ub.poison : i32 loc(#loc)
286
- scf.for %arg0 = %1 to %2 step %3 : i32 {
286
+ scf.for %ivar = %1 to %2 step %3 : i32 {
287
287
%c0_i32_0 = arith.constant 0 : i32 loc(#loc)
288
- %5 = ttg.memdesc_subview %0[%arg0 , %c0_i32_0] : !ttg.memdesc<4x256xi32, #shared, #smem, mutable> -> !ttg.memdesc<256xi32, #shared, #smem, mutable, 4x256> loc(#loc)
288
+ %5 = ttg.memdesc_subview %0[%ivar , %c0_i32_0] : !ttg.memdesc<4x256xi32, #shared, #smem, mutable> -> !ttg.memdesc<256xi32, #shared, #smem, mutable, 4x256> loc(#loc)
289
289
%6 = ttg.local_load %5 : !ttg.memdesc<256xi32, #shared, #smem, mutable, 4x256> -> tensor<256xi32, #blocked> loc(#loc)
290
290
} loc(#loc)
291
291
tt.return loc(#loc)
@@ -676,32 +676,33 @@ def test_async_tma(fresh_knobs):
676
676
h = async_tma_kernel .warmup (input_desc , XBLOCK , grid = (1 , ), num_warps = 4 )
677
677
expecttest .assert_expected_inline (
678
678
anonymize_ir (h .asm ["source" ]), """\
679
- #loc = loc(unknown )
679
+ #loc1 = loc("input_desc" )
680
680
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
681
681
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
682
682
#smem = #ttg.shared_memory
683
683
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
684
- tt.func public @async_tma_kernel(%arg0 : !tt.tensordesc<tensor<128x128xf16, #shared>> loc(unknown ), %arg1 : i32 loc(unknown ), %arg2 : i32 loc(unknown ), %arg3 : i64 loc(unknown ), %arg4 : i64 loc(unknown )) attributes {noinline = false} {
684
+ tt.func public @async_tma_kernel(%input_desc : !tt.tensordesc<tensor<128x128xf16, #shared>> loc("input_desc" ), %input_desc_0 : i32 loc("input_desc" ), %input_desc_1 : i32 loc("input_desc" ), %input_desc_2 : i64 loc("input_desc" ), %input_desc_3 : i64 loc("input_desc" )) attributes {noinline = false} {
685
685
%0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
686
686
%1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
687
687
ttng.init_barrier %1, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
688
688
%c0_i32 = arith.constant 0 : i32 loc(#loc)
689
- %c0_i32_0 = arith.constant 0 : i32 loc(#loc)
689
+ %c0_i32_4 = arith.constant 0 : i32 loc(#loc)
690
690
%true = arith.constant true loc(#loc)
691
- ttng.async_tma_copy_global_to_local %arg0 [%c0_i32, %c0_i32_0 ] %0, %1, %true : !tt.tensordesc<tensor<128x128xf16, #shared>>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
692
- %true_1 = arith.constant true loc(#loc)
693
- ttng.barrier_expect %1, 32768, %true_1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
694
- %c0_i32_2 = arith.constant 0 : i32 loc(#loc)
695
- %true_3 = arith.constant true loc(#loc)
696
- ttng.wait_barrier %1, %c0_i32_2 , %true_3 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
691
+ ttng.async_tma_copy_global_to_local %input_desc [%c0_i32, %c0_i32_4 ] %0, %1, %true : !tt.tensordesc<tensor<128x128xf16, #shared>>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
692
+ %true_5 = arith.constant true loc(#loc)
693
+ ttng.barrier_expect %1, 32768, %true_5 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
694
+ %c0_i32_6 = arith.constant 0 : i32 loc(#loc)
695
+ %true_7 = arith.constant true loc(#loc)
696
+ ttng.wait_barrier %1, %c0_i32_6 , %true_7 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
697
697
ttng.inval_barrier %1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
698
- %c0_i32_4 = arith.constant 0 : i32 loc(#loc)
699
- %c0_i32_5 = arith.constant 0 : i32 loc(#loc)
700
- ttng.async_tma_copy_local_to_global %arg0[%c0_i32_4 , %c0_i32_5 ] %0 : !tt.tensordesc<tensor<128x128xf16, #shared>>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
698
+ %c0_i32_8 = arith.constant 0 : i32 loc(#loc)
699
+ %c0_i32_9 = arith.constant 0 : i32 loc(#loc)
700
+ ttng.async_tma_copy_local_to_global %input_desc[%c0_i32_8 , %c0_i32_9 ] %0 : !tt.tensordesc<tensor<128x128xf16, #shared>>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
701
701
ttng.async_tma_store_wait {pendings = 0 : i32} loc(#loc)
702
702
tt.return loc(#loc)
703
703
} loc(#loc)
704
704
} loc(#loc)
705
+ #loc = loc(unknown)
705
706
""" )
706
707
707
708
@@ -736,31 +737,32 @@ def test_async_tma_blackwell(fresh_knobs):
736
737
expecttest .assert_expected_inline (
737
738
anonymize_ir (h .asm ["source" ]), """\
738
739
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
739
- #loc = loc(unknown )
740
+ #loc1 = loc("input_desc" )
740
741
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
741
742
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
742
743
#smem = #ttg.shared_memory
743
744
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
744
- tt.func public @async_tma_blackwell_kernel(%arg0 : !tt.tensordesc<tensor<1x128xf16, #shared>> loc(unknown ), %arg1 : i32 loc(unknown ), %arg2 : i32 loc(unknown ), %arg3 : i64 loc(unknown ), %arg4 : i64 loc(unknown )) attributes {noinline = false} {
745
+ tt.func public @async_tma_blackwell_kernel(%input_desc : !tt.tensordesc<tensor<1x128xf16, #shared>> loc("input_desc" ), %input_desc_0 : i32 loc("input_desc" ), %input_desc_1 : i32 loc("input_desc" ), %input_desc_2 : i64 loc("input_desc" ), %input_desc_3 : i64 loc("input_desc" )) attributes {noinline = false} {
745
746
%0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
746
747
%1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
747
748
ttng.init_barrier %1, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
748
749
%2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
749
750
%true = arith.constant true loc(#loc)
750
751
%c0_i32 = arith.constant 0 : i32 loc(#loc)
751
- ttng.async_tma_gather %arg0 [%2, %c0_i32] %0, %1, %true : !tt.tensordesc<tensor<1x128xf16, #shared>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, i1 loc(#loc)
752
- %true_0 = arith.constant true loc(#loc)
753
- ttng.barrier_expect %1, 32768, %true_0 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
754
- %c0_i32_1 = arith.constant 0 : i32 loc(#loc)
755
- %true_2 = arith.constant true loc(#loc)
756
- ttng.wait_barrier %1, %c0_i32_1 , %true_2 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
752
+ ttng.async_tma_gather %input_desc [%2, %c0_i32] %0, %1, %true : !tt.tensordesc<tensor<1x128xf16, #shared>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, i1 loc(#loc)
753
+ %true_4 = arith.constant true loc(#loc)
754
+ ttng.barrier_expect %1, 32768, %true_4 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
755
+ %c0_i32_5 = arith.constant 0 : i32 loc(#loc)
756
+ %true_6 = arith.constant true loc(#loc)
757
+ ttng.wait_barrier %1, %c0_i32_5 , %true_6 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
757
758
ttng.inval_barrier %1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
758
- %c0_i32_3 = arith.constant 0 : i32 loc(#loc)
759
- ttng.async_tma_scatter %arg0 [%2, %c0_i32_3 ] %0 : !tt.tensordesc<tensor<1x128xf16, #shared>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
759
+ %c0_i32_7 = arith.constant 0 : i32 loc(#loc)
760
+ ttng.async_tma_scatter %input_desc [%2, %c0_i32_7 ] %0 : !tt.tensordesc<tensor<1x128xf16, #shared>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
760
761
ttng.async_tma_store_wait {pendings = 0 : i32} loc(#loc)
761
762
tt.return loc(#loc)
762
763
} loc(#loc)
763
764
} loc(#loc)
765
+ #loc = loc(unknown)
764
766
""" )
765
767
766
768
@@ -972,8 +974,9 @@ def test_reduce(fresh_knobs):
972
974
anonymize_ir (h .asm ["ttgir" ]), """\
973
975
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
974
976
#loc = loc(unknown)
977
+ #loc1 = loc("out")
975
978
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
976
- tt.func public @reduce_kernel(%arg0 : !tt.ptr<f32> {tt.divisibility = 16 : i32} loc(unknown )) attributes {noinline = false} {
979
+ tt.func public @reduce_kernel(%out : !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("out" )) attributes {noinline = false} {
977
980
%cst = arith.constant dense<2.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
978
981
%cst_0 = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
979
982
%0 = "tt.reduce"(%cst_0) <{axis = 0 : i32}> ({
@@ -1003,7 +1006,7 @@ def test_reduce(fresh_knobs):
1003
1006
%7 = arith.addf %6, %4#0 : tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
1004
1007
%8 = arith.addf %7, %4#1 : tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
1005
1008
%9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
1006
- %10 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
1009
+ %10 = tt.splat %out : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
1007
1010
%11 = tt.addptr %10, %9 : tensor<16x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
1008
1011
tt.store %11, %8 : tensor<16x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
1009
1012
tt.return loc(#loc)
@@ -1202,16 +1205,17 @@ def test_async_copy(fresh_knobs):
1202
1205
expecttest .assert_expected_inline (
1203
1206
anonymize_ir (h .asm ["ttgir" ]), """\
1204
1207
#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
1205
- #loc = loc(unknown)
1208
+ #loc1 = loc("inp")
1209
+ #loc2 = loc("xnumel")
1206
1210
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
1207
1211
#smem = #ttg.shared_memory
1208
1212
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
1209
- tt.func public @async_copy_kernel(%arg0 : !tt.ptr<f16> {tt.divisibility = 16 : i32} loc(unknown ), %arg1 : i32 loc(unknown )) attributes {noinline = false} {
1213
+ tt.func public @async_copy_kernel(%inp : !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("inp" ), %xnumel : i32 loc("xnumel" )) attributes {noinline = false} {
1210
1214
%0 = ttg.local_alloc : () -> !ttg.memdesc<128xf16, #shared, #smem, mutable> loc(#loc)
1211
1215
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked> loc(#loc)
1212
- %2 = tt.splat %arg1 : i32 -> tensor<128xi32, #blocked> loc(#loc)
1216
+ %2 = tt.splat %xnumel : i32 -> tensor<128xi32, #blocked> loc(#loc)
1213
1217
%3 = arith.cmpi slt, %1, %2 {tt.constancy = dense<2> : tensor<1xi32>} : tensor<128xi32, #blocked> loc(#loc)
1214
- %4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x!tt.ptr<f16>, #blocked> loc(#loc)
1218
+ %4 = tt.splat %inp : !tt.ptr<f16> -> tensor<128x!tt.ptr<f16>, #blocked> loc(#loc)
1215
1219
%5 = tt.addptr %4, %1 : tensor<128x!tt.ptr<f16>, #blocked>, tensor<128xi32, #blocked> loc(#loc)
1216
1220
%6 = ttg.async_copy_global_to_local %5, %0 : tensor<128x!tt.ptr<f16>, #blocked> -> <128xf16, #shared, #smem, mutable> loc(#loc)
1217
1221
%7 = ttg.async_copy_global_to_local %5, %0 mask %3 cacheModifier = ca evictionPolicy = evict_last {isVolatile = true} : tensor<128x!tt.ptr<f16>, #blocked> -> <128xf16, #shared, #smem, mutable> loc(#loc)
@@ -1223,6 +1227,7 @@ def test_async_copy(fresh_knobs):
1223
1227
tt.return loc(#loc)
1224
1228
} loc(#loc)
1225
1229
} loc(#loc)
1230
+ #loc = loc(unknown)
1226
1231
""" )
1227
1232
1228
1233
0 commit comments