@@ -41,21 +41,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
41
41
#tmem = #ttng.tensor_memory_encoding <blockM = 64 , blockN = 64 , unpacked = true >
42
42
module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 8 : i32 } {
43
43
// CHECK-LABEL: @tc_gen5_mma_multi_m_n
44
- // CHECK-DAG: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32
45
- // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
46
- // CHECK-DAG: %[[C64:.+]] = llvm.mlir.constant(64 : i32) : i32
47
- // CHECK-DAG: %[[T0:.+]] = llvm.add %[[TMEM_BASE]], %[[C0]] : i32
48
- // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T0]]
49
- // CHECK: %[[T1:.+]] = llvm.add %[[TMEM_BASE]], %[[C64]] : i32
50
- // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T1]]
44
+ // CHECK: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32
45
+ // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]]
46
+ // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 64 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]]
51
47
// 1048576 = row << 16 + col = 16 << 16 + 0
52
- // CHECK: %[[C1048576:.+]] = llvm.mlir.constant(1048576 : i32) : i32
53
- // CHECK: %[[T2:.+]] = llvm.add %[[TMEM_BASE]], %[[C1048576]] : i32
54
- // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T2]]
48
+ // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 1048576 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]]
55
49
// 1048640 = row << 16 + col = 16 << 16 + 64
56
- // CHECK: %[[C1048640:.+]] = llvm.mlir.constant(1048640 : i32) : i32
57
- // CHECK: %[[T3:.+]] = llvm.add %[[TMEM_BASE]], %[[C1048640]] : i32
58
- // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T3]]
50
+ // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 1048640 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]]
59
51
60
52
tt.func @tc_gen5_mma_multi_m_n (%a: !ttg.memdesc <128 x16 xf16 , #shared , #ttg.shared_memory >,
61
53
%b: !ttg.memdesc <16 x128 xf16 , #shared1 , #ttg.shared_memory >,
@@ -82,21 +74,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
82
74
#tmem = #ttng.tensor_memory_encoding <blockM = 64 , blockN = 32 , unpacked = true , CTASplitN = 2 >
83
75
module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 8 : i32 } {
84
76
// CHECK-LABEL: @tc_gen5_mma_multi_ctas
85
- // CHECK-DAG: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32
86
- // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
87
- // CHECK-DAG: %[[C32:.+]] = llvm.mlir.constant(32 : i32) : i32
88
- // CHECK-DAG: %[[T0:.+]] = llvm.add %[[TMEM_BASE]], %[[C0]] : i32
89
- // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T0]]
90
- // CHECK: %[[T1:.+]] = llvm.add %[[TMEM_BASE]], %[[C32]] : i32
91
- // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T1]]
77
+ // CHECK: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32
78
+ // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]]
79
+ // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 32 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]]
92
80
// 1048576 = row << 16 + col = 16 << 16 + 0
93
- // CHECK: %[[C1048576:.+]] = llvm.mlir.constant(1048576 : i32) : i32
94
- // CHECK: %[[T2:.+]] = llvm.add %[[TMEM_BASE]], %[[C1048576]] : i32
95
- // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T2]]
81
+ // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 1048576 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]]
96
82
// 1048608 = row << 16 + col = 16 << 16 + 32
97
- // CHECK: %[[C1048608:.+]] = llvm.mlir.constant(1048608 : i32) : i32
98
- // CHECK: %[[T3:.+]] = llvm.add %[[TMEM_BASE]], %[[C1048608]] : i32
99
- // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[T3]]
83
+ // CHECK: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 1048608 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %[[TMEM_BASE]]
100
84
101
85
tt.func @tc_gen5_mma_multi_ctas (%a: !ttg.memdesc <128 x16 xf16 , #shared , #ttg.shared_memory >,
102
86
%b: !ttg.memdesc <16 x128 xf16 , #shared1 , #ttg.shared_memory >,
@@ -203,12 +187,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
203
187
// CHECK: %[[P0:.+]] = llvm.icmp "eq" %[[WID]], %[[C0]] : i32
204
188
// CHECK: %[[P1:.+]] = llvm.and %{{.*}}, %[[P0]] : i1
205
189
// CHECK: llvm.cond_br %[[P1]]
206
- // CHECK: %[[T0:.+]] = llvm.add %[[TMEM_BASE]], %[[C0]] : i32
207
190
// CHECK: %[[DESC0:.+]] = llvm.mlir.constant(144708608 : i32) : i32
208
- // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0 ]], %{{.+}}, %{{.+}}, %[[DESC0]], %{{.+}}, %{{.+}}, %arg5
191
+ // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[TMEM_BASE ]], %{{.+}}, %{{.+}}, %[[DESC0]], %{{.+}}, %{{.+}}, %arg5
209
192
// CHECK: %[[TRUE:.+]] = llvm.mlir.constant(true) : i1
210
193
// CHECK: %[[DESC1:.+]] = llvm.mlir.constant(681579536 : i32) : i32
211
- // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0 ]], %{{.+}}, %{{.+}}, %[[DESC1]], %{{.+}}, %{{.+}}, %[[TRUE]]
194
+ // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[TMEM_BASE ]], %{{.+}}, %{{.+}}, %[[DESC1]], %{{.+}}, %{{.+}}, %[[TRUE]]
212
195
tt.func @tc_gen5_mma_block_scale (%a: !ttg.memdesc <128 x64 xi8 , #shared , #ttg.shared_memory >,
213
196
%b: !ttg.memdesc <32 x128 xi8 , #shared1 , #ttg.shared_memory >,
214
197
%c: !ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable >,
@@ -320,12 +303,10 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.thr
320
303
#tmem_scales = #ttng.tensor_memory_scales_encoding <>
321
304
module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 } {
322
305
// CHECK-LABEL: @tc_gen5_mma_block_scale_nvfp4
323
- // CHECK-DAG: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32
324
- // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
325
- // CHECK: %[[T0:.+]] = llvm.add %[[TMEM_BASE]], %[[C0]] : i32
306
+ // CHECK: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32
326
307
// CHECK: %[[DESC0:.+]] = llvm.mlir.constant(138413184 : i32) : i32
327
- // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0 ]], %{{.+}}, %{{.+}}, %[[DESC0]]
328
- // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0 ]], %{{.+}}, %{{.+}}, %[[DESC0]]
308
+ // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[TMEM_BASE ]], %{{.+}}, %{{.+}}, %[[DESC0]]
309
+ // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[TMEM_BASE ]], %{{.+}}, %{{.+}}, %[[DESC0]]
329
310
tt.func @tc_gen5_mma_block_scale_nvfp4 (%a: !ttg.memdesc <128 x64 xi8 , #shared , #ttg.shared_memory >,
330
311
%b: !ttg.memdesc <64 x256 xi8 , #shared1 , #ttg.shared_memory >,
331
312
%c: !ttg.memdesc <128 x256 xf32 , #tmem , #ttng.tensor_memory , mutable >,
@@ -356,12 +337,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
356
337
module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 } {
357
338
// CHECK-LABEL: @tc_gen5_mma_block_scale_mxfp4
358
339
// CHECK-DAG: %[[TMEM_BASE:.+]] = llvm.ptrtoint %{{.*}} : !llvm.ptr<3> to i32
359
- // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
360
- // CHECK: %[[T0:.+]] = llvm.add %[[TMEM_BASE]], %[[C0]] : i32
361
340
// CHECK: %[[DESC0:.+]] = llvm.mlir.constant(146801792 : i32) : i32
362
- // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0 ]], %{{.+}}, %{{.+}}, %[[DESC0]]
341
+ // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[TMEM_BASE ]], %{{.+}}, %{{.+}}, %[[DESC0]]
363
342
// CHECK: %[[DESC1:.+]] = llvm.mlir.constant(1220543648 : i32) : i32
364
- // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[T0 ]], %{{.+}}, %{{.+}}, %[[DESC1]]
343
+ // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %[[TMEM_BASE ]], %{{.+}}, %{{.+}}, %[[DESC1]]
365
344
tt.func @tc_gen5_mma_block_scale_mxfp4 (%a: !ttg.memdesc <128 x64 xi8 , #shared , #ttg.shared_memory >,
366
345
%b: !ttg.memdesc <64 x256 xi8 , #shared1 , #ttg.shared_memory >,
367
346
%c: !ttg.memdesc <128 x256 xf32 , #tmem , #ttng.tensor_memory , mutable >,
0 commit comments