@@ -246,3 +246,48 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32} {
246246 tt.return
247247 }
248248}
249+
// -----

// COM: Check codegen when base height is 1 and tile height is > 1.
// COM: The pointer tensor is built as a single 1x32 row (%21) and then
// COM: broadcast to 64x32 (%22), so every row of the loaded tile refers to the
// COM: same 32 addresses.
// COM: NOTE(review): the DPAS layout declares warpsPerCTA = [4, 2] while the
// COM: module sets "ttg.num-warps" = 32 -- kept as in the original test; confirm
// COM: this combination is intentional.
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {
  // CHECK-LABEL: @baseheight1
  tt.func public @baseheight1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
    // COM: Column offsets 0..31, reshaped to a 1x32 row in the dot-operand-A layout.
    %18 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>}>>
    %19 = tt.expand_dims %18 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>}>> -> tensor<1x32xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
    // COM: Row of pointers base + [0..31], then replicated along dim 0 to 64 rows.
    %20 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1x32x!tt.ptr<f32>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
    %21 = tt.addptr %20, %19 : tensor<1x32x!tt.ptr<f32>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<1x32xi32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
    %22 = tt.broadcast %21 : tensor<1x32x!tt.ptr<f32>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> -> tensor<64x32x!tt.ptr<f32>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
    %50 = tt.load %22 {ttig.block_io = "row_major"} : tensor<64x32x!tt.ptr<f32>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
    // COM: The block load must receive the constant 1 as its third operand
    // COM: (presumably the base height -- confirm against the op definition)
    // COM: while the tile itself is 16 rows high with 2 vertical blocks.
    // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
    // CHECK: [[LOAD:%.*]] = triton_gen.2Dblockload %{{.*}}, %{{.*}}, [[C1]], %{{.*}}, %{{.*}}, %{{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 16, v_blocks = 2

    // COM: A 2-element vector collects one value per loaded block.
    // CHECK: [[VEC:%.*]] = llvm.mlir.undef : vector<2xi32>

    // COM: First value: element 0 of the load result, redistributed across the
    // COM: sub-group with a shuffle indexed by (local id % 8), stored at index 0.
    // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
    // CHECK: [[OLDVAL:%.*]] = llvm.extractelement [[LOAD]][[[C0]] : i32] : vector<16xi32>
    // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
    // CHECK: [[THREADID_i64:%.*]] = llvm.call spir_funccc @_Z12get_local_idj([[C0]])
    // CHECK: [[THREADID:%.*]] = llvm.trunc [[THREADID_i64]] : i64 to i32
    // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
    // CHECK: [[REM:%.*]] = llvm.urem [[THREADID]], [[C8]] : i32
    // CHECK: [[NEWVAL:%.*]] = llvm.call spir_funccc @_Z17sub_group_shuffleij([[OLDVAL]], [[REM]])
    // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
    // CHECK: [[VEC1:%.*]] = llvm.insertelement [[NEWVAL]], [[VEC]][[[C0]] : i32] : vector<2xi32>

    // COM: Second value: same sequence starting from element 8 of the load
    // COM: result, stored at index 1.
    // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
    // CHECK: [[OLDVAL:%.*]] = llvm.extractelement [[LOAD]][[[C8]] : i32] : vector<16xi32>
    // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
    // CHECK: [[THREADID_i64:%.*]] = llvm.call spir_funccc @_Z12get_local_idj([[C0]])
    // CHECK: [[THREADID:%.*]] = llvm.trunc [[THREADID_i64]] : i64 to i32
    // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
    // CHECK: [[REM:%.*]] = llvm.urem [[THREADID]], [[C8]] : i32
    // CHECK: [[NEWVAL:%.*]] = llvm.call spir_funccc @_Z17sub_group_shuffleij([[OLDVAL]], [[REM]])
    // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
    // CHECK: [[VEC2:%.*]] = llvm.insertelement [[NEWVAL]], [[VEC1]][[[C1]] : i32] : vector<2xi32>

    // COM: Each of the two values is replicated 8 times to form the 16-element
    // COM: result vector.
    // CHECK: llvm.shufflevector [[VEC2]], [[VEC2]] [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
    tt.return
  }
}
0 commit comments