@@ -246,3 +246,48 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32} {
246
246
tt.return
247
247
}
248
248
}
249
+
250
+ // -----
251
+
252
+ // COM: Check codegen when base height is 1 and tile height is > 1.
253
+ #mma = #ttig.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 1 , threadsPerWarp = 16 , warpsPerCTA = [4 , 2 ], repCluster = [2 , 1 ], A = [16 , 8 ], B = [8 , 16 ], C = [16 , 16 ]}>
254
+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 32 : i32 , " ttg.threads-per-warp" = 16 : i32 , ttig.support_sg_2d_block } {
255
+ // CHECK-LABEL: @baseheight1
256
+ tt.func public @baseheight1 (%arg0: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }) {
257
+ %18 = tt.make_range {end = 32 : i32 , start = 0 : i32 } : tensor <32 xi32 , #ttg.slice <{dim = 0 , parent = #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>}>>
258
+ %19 = tt.expand_dims %18 {axis = 0 : i32 } : tensor <32 xi32 , #ttg.slice <{dim = 0 , parent = #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>}>> -> tensor <1 x32 xi32 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>>
259
+ %20 = tt.splat %arg0 : !tt.ptr <f32 > -> tensor <1 x32 x!tt.ptr <f32 >, #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>>
260
+ %21 = tt.addptr %20 , %19 : tensor <1 x32 x!tt.ptr <f32 >, #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>>, tensor <1 x32 xi32 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>>
261
+ %22 = tt.broadcast %21 : tensor <1 x32 x!tt.ptr <f32 >, #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>> -> tensor <64 x32 x!tt.ptr <f32 >, #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>>
262
+ %50 = tt.load %22 {ttig.block_io = " row_major" } : tensor <64 x32 x!tt.ptr <f32 >, #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>>
263
+ // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
264
+ // CHECK: [[LOAD:%.*]] = triton_gen.2Dblockload %{{.*}}, %{{.*}}, [[C1]], %{{.*}}, %{{.*}}, %{{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 16, v_blocks = 2
265
+
266
+ // CHECK: [[VEC:%.*]] = llvm.mlir.undef : vector<2xi32>
267
+
268
+ // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
269
+ // CHECK: [[OLDVAL:%.*]] = llvm.extractelement [[LOAD]][[[C0]] : i32] : vector<16xi32>
270
+ // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
271
+ // CHECK: [[THREADID_i64:%.*]] = llvm.call spir_funccc @_Z12get_local_idj([[C0]])
272
+ // CHECK: [[THREADID:%.*]] = llvm.trunc [[THREADID_i64]] : i64 to i32
273
+ // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
274
+ // CHECK: [[REM:%.*]] = llvm.urem [[THREADID]], [[C8]] : i32
275
+ // CHECK: [[NEWVAL:%.*]] = llvm.call spir_funccc @_Z17sub_group_shuffleij([[OLDVAL]], [[REM]])
276
+ // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
277
+ // CHECK: [[VEC1:%.*]] = llvm.insertelement [[NEWVAL]], [[VEC]][[[C0]] : i32] : vector<2xi32>
278
+
279
+ // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
280
+ // CHECK: [[OLDVAL:%.*]] = llvm.extractelement [[LOAD]][[[C8]] : i32] : vector<16xi32>
281
+ // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : i32
282
+ // CHECK: [[THREADID_i64:%.*]] = llvm.call spir_funccc @_Z12get_local_idj([[C0]])
283
+ // CHECK: [[THREADID:%.*]] = llvm.trunc [[THREADID_i64]] : i64 to i32
284
+ // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
285
+ // CHECK: [[REM:%.*]] = llvm.urem [[THREADID]], [[C8]] : i32
286
+ // CHECK: [[NEWVAL:%.*]] = llvm.call spir_funccc @_Z17sub_group_shuffleij([[OLDVAL]], [[REM]])
287
+ // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
288
+ // CHECK: [[VEC2:%.*]] = llvm.insertelement [[NEWVAL]], [[VEC1]][[[C1]] : i32] : vector<2xi32>
289
+
290
+ // CHECK: llvm.shufflevector [[VEC2]], [[VEC2]] [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
291
+ tt.return
292
+ }
293
+ }
0 commit comments