@@ -57,7 +57,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
5757 %65 = tt.splat %64 : i32 -> tensor <1 x64 xi32 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>>
5858 %66 = arith.cmpi slt , %38 , %65 : tensor <1 x64 xi32 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>>
5959 %67 = tt.broadcast %66 : tensor <1 x64 xi1 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>> -> tensor <128 x64 xi1 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>>
60- // CHECK-COUNT-16: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 2
60+ // CHECK-COUNT-8: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 32, v_blocks = 1
6161 %68 = tt.load %60 , %67 , %cst_3 {ttig.block_io = " row_major" } : tensor <128 x64 x!tt.ptr <f16 >, #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>>
6262 %74 = tt.addptr %60 , %cst_0 : tensor <128 x64 x!tt.ptr <f16 >, #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>>, tensor <128 x64 xi32 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>>
6363 %76 = arith.addi %58 , %c1_i32 : i32
@@ -69,72 +69,6 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
6969
7070// -----
7171
72- #mma = #ttig.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 2 , threadsPerWarp = 16 , warpsPerCTA = [2 , 4 ], repCluster = [4 , 2 ], A = [32 , 16 ], B = [16 , 32 ], C = [32 , 32 ]}>
73- module attributes {ttig.min_sg_size = 16 : i32 , ttig.support_bf16_conversion , ttig.support_dpas , ttig.support_sg_2d_block , ttig.target_arch = " spir64" , " ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 8 : i32 , ttg.shared = 33280 : i32 , ttg.target = " xpu" , " ttg.threads-per-warp" = 16 : i32 } {
74- tt.func public @matmul_tensor_pointer_kernel (%arg0: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg1: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg2: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %arg3: i32 {tt.divisibility = 16 : i32 }, %arg4: i32 {tt.divisibility = 16 : i32 }, %arg5: i32 {tt.divisibility = 16 : i32 }, %arg6: i32 {tt.divisibility = 16 : i32 }, %arg7: i32 {tt.divisibility = 16 : i32 }, %arg8: i32 {tt.divisibility = 16 : i32 }, %arg9: !llvm.ptr <3 >) attributes {noinline = false } {
75- %c63_i32 = arith.constant 63 : i32
76- %c255_i32 = arith.constant 255 : i32
77- %c127_i32 = arith.constant 127 : i32
78- %c1_i32 = arith.constant 1 : i32
79- %c0_i32 = arith.constant 0 : i32
80- %c64_i32 = arith.constant 64 : i32
81- %c8_i32 = arith.constant 8 : i32
82- %c128_i32 = arith.constant 128 : i32
83- %c256_i32 = arith.constant 256 : i32
84- %cst_1 = arith.constant dense <0 > : tensor <256 xi32 , #ttg.slice <{dim = 0 , parent = #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>}>>
85- %cst_4 = arith.constant dense <0.000000e+00 > : tensor <64 x256 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
86- %0 = tt.get_program_id x : i32
87- %1 = arith.addi %arg3 , %c127_i32 : i32
88- %2 = arith.divsi %1 , %c128_i32 : i32
89- %3 = arith.addi %arg4 , %c255_i32 : i32
90- %4 = arith.divsi %3 , %c256_i32 : i32
91- %5 = arith.muli %4 , %c8_i32 : i32
92- %6 = arith.divsi %0 , %5 : i32
93- %7 = arith.muli %6 , %c8_i32 : i32
94- %8 = arith.subi %2 , %7 : i32
95- %9 = arith.minsi %8 , %c8_i32 : i32
96- %12 = arith.remsi %0 , %5 : i32
97- %13 = arith.divsi %12 , %9 : i32
98- %15 = arith.muli %13 , %c256_i32 : i32
99- %22 = tt.make_range {end = 256 : i32 , start = 0 : i32 } : tensor <256 xi32 , #ttg.slice <{dim = 0 , parent = #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>}>>
100- %24 = tt.splat %15 : i32 -> tensor <256 xi32 , #ttg.slice <{dim = 0 , parent = #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>}>>
101- %26 = arith.addi %24 , %22 : tensor <256 xi32 , #ttg.slice <{dim = 0 , parent = #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>}>>%31 = tt.splat %arg4 : i32 -> tensor <256 xi32 , #ttg.slice <{dim = 0 , parent = #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>}>>
102- %44 = tt.make_range {end = 64 : i32 , start = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 1 , parent = #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>}>>
103- %45 = tt.expand_dims %44 {axis = 1 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 1 , parent = #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>}>> -> tensor <64 x1 xi32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
104- %cst_2 = arith.constant dense <512 > : tensor <64 x1 xi32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
105- %47 = arith.muli %45 , %cst_2 : tensor <64 x1 xi32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
106- %48 = tt.expand_dims %26 {axis = 0 : i32 } : tensor <256 xi32 , #ttg.slice <{dim = 0 , parent = #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>}>> -> tensor <1 x256 xi32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
107- %49 = tt.broadcast %47 : tensor <64 x1 xi32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <64 x256 xi32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
108- %50 = tt.broadcast %48 : tensor <1 x256 xi32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <64 x256 xi32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
109- %51 = arith.addi %49 , %50 : tensor <64 x256 xi32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
110- %52 = tt.splat %arg1 : !tt.ptr <f16 > -> tensor <64 x256 x!tt.ptr <f16 >, #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
111- %53 = tt.addptr %52 , %51 : tensor <64 x256 x!tt.ptr <f16 >, #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>, tensor <64 x256 xi32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
112- %54 = arith.addi %arg5 , %c63_i32 : i32
113- %55 = arith.divsi %54 , %c64_i32 : i32
114- %56 = arith.muli %arg7 , %c64_i32 : i32
115- %57 = tt.splat %56 : i32 -> tensor <64 x256 xi32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
116- cf.br ^bb1 (%c0_i32 , %53 : i32 , tensor <64 x256 x!tt.ptr <f16 >, #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>)
117- ^bb1 (%58: i32 , %61: tensor <64 x256 x!tt.ptr <f16 >, #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>): // 2 preds: ^bb0, ^bb2
118- %62 = arith.cmpi slt , %58 , %55 : i32
119- cf.cond_br %62 , ^bb2 , ^bb3
120- ^bb2 : // pred: ^bb1
121- %63 = arith.muli %58 , %c64_i32 : i32
122- %64 = arith.subi %arg5 , %63 : i32
123- %69 = tt.splat %64 : i32 -> tensor <64 x1 xi32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
124- %70 = arith.cmpi slt , %45 , %69 : tensor <64 x1 xi32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
125- %71 = tt.broadcast %70 : tensor <64 x1 xi1 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <64 x256 xi1 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
126- // CHECK-COUNT-8: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 32, v_blocks = 1
127- %72 = tt.load %61 , %71 , %cst_4 {ttig.block_io = " row_major" } : tensor <64 x256 x!tt.ptr <f16 >, #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
128- %75 = tt.addptr %61 , %57 : tensor <64 x256 x!tt.ptr <f16 >, #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>, tensor <64 x256 xi32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
129- %76 = arith.addi %58 , %c1_i32 : i32
130- cf.br ^bb1 (%76 , %75 : i32 , tensor <64 x256 x!tt.ptr <f16 >, #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>)
131- ^bb3 : // pred: ^bb1
132- tt.return
133- }
134- }
135-
136- // -----
137-
13872#mma = #ttig.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 2 , threadsPerWarp = 16 , warpsPerCTA = [8 , 1 ], repCluster = [2 , 2 ]}>
13973#mma_1 = #ttig.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 2 , threadsPerWarp = 16 , warpsPerCTA = [4 , 2 ], repCluster = [1 , 1 ]}>
14074#mma_2 = #ttig.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 2 , threadsPerWarp = 16 , warpsPerCTA = [8 , 1 ], repCluster = [4 , 2 ]}>
@@ -259,6 +193,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
259193
260194
261195 // CHECK: %[[BLOCK_SHAPE_Y:.*]] = llvm.mlir.constant(16 : i32) : i32
196+ // CHECK-COUNT-2: llvm.mlir.constant(0 : i32) : i32
262197 // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
263198 // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
264199 // CHECK: %[[TOP_LEFT_MASK_0:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_0]] : i1 to i8
@@ -267,22 +202,25 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
267202 // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
268203 // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
269204
205+ // CHECK-COUNT-3: llvm.mlir.constant(0 : i32) : i32
270206 // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
271207 // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
272- // CHECK: %[[TOP_LEFT_MASK_1:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_64]] : i1 to i8
208+ // CHECK: %[[TOP_LEFT_MASK_1:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_32]] : i1 to i8
273209 // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_1]], %[[CST0_1]])
274210 // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
275211 // CHECK: %[[BASE_Y_1:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
276212 // CHECK: %[[LOAD_1:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_1]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
277213
214+ // CHECK-COUNT-3: llvm.mlir.constant(0 : i32) : i32
278215 // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
279216 // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
280- // CHECK: %[[TOP_LEFT_MASK_2:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_32]] : i1 to i8
217+ // CHECK: %[[TOP_LEFT_MASK_2:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_64]] : i1 to i8
281218 // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_2]], %[[CST0_1]])
282219 // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
283220 // CHECK: %[[BASE_Y_2:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
284221 // CHECK: %[[LOAD_2:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_2]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
285222
223+ // CHECK-COUNT-3: llvm.mlir.constant(0 : i32) : i32
286224 // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
287225 // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
288226 // CHECK: %[[TOP_LEFT_MASK_3:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_96]] : i1 to i8
0 commit comments