@@ -187,3 +187,62 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32} {
187187 tt.return
188188 }
189189}
190+
191+ // -----
192+
193+ #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2]}>
194+ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32} {
195+ // CHECK-LABEL: @regular_pointer_block_io
196+ tt.func public @regular_pointer_block_io(%arg0: tensor<256x64x!tt.ptr<f16>, #mma>) {
197+ // Test input: a masked tt.load on a tensor of regular pointers tagged ttig.block_io = "row_major"; the mask is a splat of true and the fallback value is 0.0.
198+ %a_mask = arith.constant dense<true> : tensor<256x64xi1, #mma>
199+ %a_other = arith.constant dense<0.00e+00> : tensor<256x64xf16, #mma>
200+ // CHECK-NOT: llvm.cond_br
201+ // No llvm.cond_br may appear before the mask extraction; elements 0, 32, 64 and 96 of the mask struct are captured here and each one is consumed by one load group below.
202+ // CHECK: %[[TOP_LEFT_MASK_BOOL_0:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.struct<(i1, i1, {{.*}}
203+ // CHECK: %[[TOP_LEFT_MASK_BOOL_32:.*]] = llvm.extractvalue {{.*}}[32] : !llvm.struct<(i1, i1, {{.*}}
204+ // CHECK: %[[TOP_LEFT_MASK_BOOL_64:.*]] = llvm.extractvalue {{.*}}[64] : !llvm.struct<(i1, i1, {{.*}}
205+ // CHECK: %[[TOP_LEFT_MASK_BOOL_96:.*]] = llvm.extractvalue {{.*}}[96] : !llvm.struct<(i1, i1, {{.*}}
206+ // Per load group: the mask bit is widened to i8, passed through sub_group_shuffle with index 0, truncated back to i1, and used to select the Y offset of the 2D block load (0 when set, BLOCK_SHAPE_Y = 16 otherwise); the loaded vector then feeds an llvm.select.
207+ // NOTE(review): presumably the non-zero Y offset pushes a masked-off block out of bounds so no branch is needed - confirm against the lowering. Load group 0 uses mask element 0.
208+ // CHECK: %[[BLOCK_SHAPE_Y:.*]] = llvm.mlir.constant(16 : i32) : i32
209+ // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
210+ // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
211+ // CHECK: %[[TOP_LEFT_MASK_0:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_0]] : i1 to i8
212+ // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_0]], %[[CST0_1]])
213+ // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
214+ // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
215+ // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
216+ // CHECK: llvm.select {{.*}}, %[[LOAD_0]], {{.*}} : i1, vector<32xf16>
217+ // Load group 1 uses mask element 64 (the expected emission order is 0, 64, 32, 96).
218+ // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
219+ // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
220+ // CHECK: %[[TOP_LEFT_MASK_1:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_64]] : i1 to i8
221+ // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_1]], %[[CST0_1]])
222+ // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
223+ // CHECK: %[[BASE_Y_1:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
224+ // CHECK: %[[LOAD_1:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_1]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
225+ // CHECK: llvm.select {{.*}}, %[[LOAD_1]], {{.*}} : i1, vector<32xf16>
226+ // Load group 2 uses mask element 32.
227+ // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
228+ // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
229+ // CHECK: %[[TOP_LEFT_MASK_2:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_32]] : i1 to i8
230+ // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_2]], %[[CST0_1]])
231+ // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
232+ // CHECK: %[[BASE_Y_2:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
233+ // CHECK: %[[LOAD_2:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_2]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
234+ // CHECK: llvm.select {{.*}}, %[[LOAD_2]], {{.*}} : i1, vector<32xf16>
235+ // Load group 3 uses mask element 96.
236+ // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
237+ // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
238+ // CHECK: %[[TOP_LEFT_MASK_3:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_96]] : i1 to i8
239+ // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_3]], %[[CST0_1]])
240+ // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
241+ // CHECK: %[[BASE_Y_3:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
242+ // CHECK: %[[LOAD_3:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_3]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
243+ // CHECK: llvm.select {{.*}}, %[[LOAD_3]], {{.*}} : i1, vector<32xf16>
244+ %0 = tt.load %arg0, %a_mask, %a_other {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma>
245+
246+ tt.return
247+ }
248+ }
0 commit comments