@@ -1347,6 +1347,32 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

 // -----

+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#dot_operand_a = #ttg.dot_op<{opIdx = 0, parent = #blocked}>
+#dot_operand_b = #ttg.dot_op<{opIdx = 1, parent = #blocked}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.target" = "cuda:70", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: matmul_fmadot_integer
+  tt.func @matmul_fmadot_integer(%ptr: !tt.ptr<i32> {tt.divisibility = 16 : i32},
+                                 %a: !ttg.memdesc<32x16xi32, #shared, #smem>, %b: !ttg.memdesc<16x32xi32, #shared, #smem>) {
+    %cst = arith.constant dense<0> : tensor<32x32xi32, #blocked>
+    // CHECK-NOT: llvm.intr.fmuladd
+    // CHECK: llvm.mul
+    // CHECK: llvm.add
+    %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xi32, #shared, #smem> -> tensor<32x16xi32, #dot_operand_a>
+    %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xi32, #shared, #smem> -> tensor<16x32xi32, #dot_operand_b>
+
+    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = ieee : tensor<32x16xi32, #dot_operand_a> * tensor<16x32xi32, #dot_operand_b> -> tensor<32x32xi32, #blocked>
+    %30 = tt.splat %ptr : !tt.ptr<i32> -> tensor<32x1x!tt.ptr<i32>, #blocked>
+    %36 = tt.broadcast %30 : tensor<32x1x!tt.ptr<i32>, #blocked> -> tensor<32x32x!tt.ptr<i32>, #blocked>
+    tt.store %36, %28 : tensor<32x32x!tt.ptr<i32>, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
 #mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
 #shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
@@ -2257,6 +2283,39 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar

 // -----

+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
+module attributes {"ttg.num-warps" = 8 : i32, ttg.target = "cuda:120"} {
+  // CHECK-LABEL: mmav2_e5m2_e5m2_fp16
+  tt.func public @mmav2_e5m2_e5m2_fp16(%arg0: tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>, %arg2: tensor<32x32xf16, #mma>) {
+    // CHECK: mma.{{.*}}.col.f16.e5m2.e5m2.f16
+    %0 = tt.dot %arg0, %arg1, %arg2 {maxNumImpreciseAcc = 1073741824 : i32} : tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf16, #mma>
+    tt.return
+  }
+
+  // CHECK-LABEL: mmav2_e5m2_e4m3_fp16
+  tt.func public @mmav2_e5m2_e4m3_fp16(%arg0: tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>, %arg2: tensor<32x32xf16, #mma>) {
+    // CHECK: mma.{{.*}}.col.f16.e5m2.e4m3.f16
+    %0 = tt.dot %arg0, %arg1, %arg2 {maxNumImpreciseAcc = 1073741824 : i32} : tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf16, #mma>
+    tt.return
+  }
+
+  // CHECK-LABEL: mmav2_e4m3_e5m2_fp16
+  tt.func public @mmav2_e4m3_e5m2_fp16(%arg0: tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>, %arg2: tensor<32x32xf16, #mma>) {
+    // CHECK: mma.{{.*}}.col.f16.e4m3.e5m2.f16
+    %0 = tt.dot %arg0, %arg1, %arg2 {maxNumImpreciseAcc = 1073741824 : i32} : tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf16, #mma>
+    tt.return
+  }
+
+  // CHECK-LABEL: mmav2_e4m3_e4m3_fp16
+  tt.func public @mmav2_e4m3_e4m3_fp16(%arg0: tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>, %arg2: tensor<32x32xf16, #mma>) {
+    // CHECK: mma.{{.*}}.col.f16.e4m3.e4m3.f16
+    %0 = tt.dot %arg0, %arg1, %arg2 {maxNumImpreciseAcc = 1073741824 : i32} : tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf16, #mma>
+    tt.return
+  }
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [4, 4, 2], warpsPerCTA = [8, 1, 1], order = [2, 1, 0]}>
 #linear = #ttg.linear<{register = [[0, 0], [0, 0], [0, 0], [0, 0]], lane = [[0, 0], [0, 1], [0, 2], [1, 0], [2, 0]], warp = [[4, 0], [8, 0], [16, 0]], block = []}>
