@@ -187,3 +187,62 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32} {
187
187
tt.return
188
188
}
189
189
}
190
+
191
+ // -----
192
+
193
// COM: Lowering test for a masked tt.load on a tensor of pointers that carries
// COM: the ttig.block_io = "row_major" attribute, with a DPAS layout (#mma) and
// COM: the ttig.support_sg_2d_block module attribute set.
// COM:
// COM: Inputs built inside the function:
// COM:   - %a_mask  : constant all-true i1 mask, same shape/layout as the pointers.
// COM:   - %a_other : constant 0.0 f16 tensor used as the "other" operand.
// COM:
// COM: What the directives below verify on the lowered IR:
// COM:   - no llvm.cond_br is emitted between the label and the first match;
// COM:   - mask elements at struct indices 0, 32, 64 and 96 are extracted;
// COM:   - for each load, one of those mask bits is zero-extended to i8, passed
// COM:     through sub_group_shuffle with lane index 0, truncated back to i1 and
// COM:     used to select the last 2D-block-load operand: 0 when the predicate
// COM:     is true, otherwise the constant 16 captured as BLOCK_SHAPE_Y;
// COM:   - each triton_gen.2Dblockload uses 16-bit elements, a 16x16 tile and
// COM:     v_blocks = 2, and its result feeds an llvm.select on vector<32xf16>.
// COM: The mask bits are consumed in the order 0, 64, 32, 96.
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2]}>
module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32} {
  // CHECK-LABEL: @regular_pointer_block_io
  tt.func public @regular_pointer_block_io(%arg0: tensor<256x64x!tt.ptr<f16>, #mma>) {

    %a_mask = arith.constant dense<true> : tensor<256x64xi1, #mma>
    %a_other = arith.constant dense<0.00e+00> : tensor<256x64xf16, #mma>
    // CHECK-NOT: llvm.cond_br

    // COM: Mask bits later used as per-load predicates.
    // CHECK: %[[TOP_LEFT_MASK_BOOL_0:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.struct<(i1, i1, {{.*}}
    // CHECK: %[[TOP_LEFT_MASK_BOOL_32:.*]] = llvm.extractvalue {{.*}}[32] : !llvm.struct<(i1, i1, {{.*}}
    // CHECK: %[[TOP_LEFT_MASK_BOOL_64:.*]] = llvm.extractvalue {{.*}}[64] : !llvm.struct<(i1, i1, {{.*}}
    // CHECK: %[[TOP_LEFT_MASK_BOOL_96:.*]] = llvm.extractvalue {{.*}}[96] : !llvm.struct<(i1, i1, {{.*}}


    // COM: First load, predicated on mask bit 0.
    // CHECK: %[[BLOCK_SHAPE_Y:.*]] = llvm.mlir.constant(16 : i32) : i32
    // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
    // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
    // CHECK: %[[TOP_LEFT_MASK_0:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_0]] : i1 to i8
    // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_0]], %[[CST0_1]])
    // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
    // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
    // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
    // CHECK: llvm.select {{.*}}, %[[LOAD_0]], {{.*}} : i1, vector<32xf16>

    // COM: Second load, predicated on mask bit 64.
    // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
    // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
    // CHECK: %[[TOP_LEFT_MASK_1:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_64]] : i1 to i8
    // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_1]], %[[CST0_1]])
    // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
    // CHECK: %[[BASE_Y_1:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
    // CHECK: %[[LOAD_1:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_1]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
    // CHECK: llvm.select {{.*}}, %[[LOAD_1]], {{.*}} : i1, vector<32xf16>

    // COM: Third load, predicated on mask bit 32.
    // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
    // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
    // CHECK: %[[TOP_LEFT_MASK_2:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_32]] : i1 to i8
    // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_2]], %[[CST0_1]])
    // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
    // CHECK: %[[BASE_Y_2:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
    // CHECK: %[[LOAD_2:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_2]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
    // CHECK: llvm.select {{.*}}, %[[LOAD_2]], {{.*}} : i1, vector<32xf16>

    // COM: Fourth load, predicated on mask bit 96.
    // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
    // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
    // CHECK: %[[TOP_LEFT_MASK_3:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_96]] : i1 to i8
    // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_3]], %[[CST0_1]])
    // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
    // CHECK: %[[BASE_Y_3:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
    // CHECK: %[[LOAD_3:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_3]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
    // CHECK: llvm.select {{.*}}, %[[LOAD_3]], {{.*}} : i1, vector<32xf16>
    %0 = tt.load %arg0, %a_mask, %a_other {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma>

    tt.return
  }
}
0 commit comments