Skip to content

Commit bbb2285

Browse files
authored
[BACKEND] Optimize the lowering of tt.load with masks (#4539)
Use the block I/O hardware's boundary-protection capability to replace the branch when lowering tt.load with masks. Signed-off-by: Lu, Chengjun <[email protected]>
1 parent 93b2d49 commit bbb2285

File tree

2 files changed

+101
-47
lines changed

2 files changed

+101
-47
lines changed

test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,62 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32} {
187187
tt.return
188188
}
189189
}
190+
191+
// -----
192+
193+
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2]}>
194+
module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32} {
195+
// CHECK-LABEL: @regular_pointer_block_io
196+
tt.func public @regular_pointer_block_io(%arg0: tensor<256x64x!tt.ptr<f16>, #mma>) {
197+
198+
%a_mask = arith.constant dense<true> : tensor<256x64xi1, #mma>
199+
%a_other = arith.constant dense<0.00e+00> : tensor<256x64xf16, #mma>
200+
// CHECK-NOT: llvm.cond_br
201+
202+
// CHECK: %[[TOP_LEFT_MASK_BOOL_0:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.struct<(i1, i1, {{.*}}
203+
// CHECK: %[[TOP_LEFT_MASK_BOOL_32:.*]] = llvm.extractvalue {{.*}}[32] : !llvm.struct<(i1, i1, {{.*}}
204+
// CHECK: %[[TOP_LEFT_MASK_BOOL_64:.*]] = llvm.extractvalue {{.*}}[64] : !llvm.struct<(i1, i1, {{.*}}
205+
// CHECK: %[[TOP_LEFT_MASK_BOOL_96:.*]] = llvm.extractvalue {{.*}}[96] : !llvm.struct<(i1, i1, {{.*}}
206+
207+
208+
// CHECK: %[[BLOCK_SHAPE_Y:.*]] = llvm.mlir.constant(16 : i32) : i32
209+
// CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
210+
// CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
211+
// CHECK: %[[TOP_LEFT_MASK_0:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_0]] : i1 to i8
212+
// CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_0]], %[[CST0_1]])
213+
// CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
214+
// CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
215+
// CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
216+
// CHECK: llvm.select {{.*}}, %[[LOAD_0]], {{.*}} : i1, vector<32xf16>
217+
218+
// CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
219+
// CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
220+
// CHECK: %[[TOP_LEFT_MASK_1:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_64]] : i1 to i8
221+
// CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_1]], %[[CST0_1]])
222+
// CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
223+
// CHECK: %[[BASE_Y_1:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
224+
// CHECK: %[[LOAD_1:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_1]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
225+
// CHECK: llvm.select {{.*}}, %[[LOAD_1]], {{.*}} : i1, vector<32xf16>
226+
227+
// CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
228+
// CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
229+
// CHECK: %[[TOP_LEFT_MASK_2:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_32]] : i1 to i8
230+
// CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_2]], %[[CST0_1]])
231+
// CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
232+
// CHECK: %[[BASE_Y_2:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
233+
// CHECK: %[[LOAD_2:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_2]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
234+
// CHECK: llvm.select {{.*}}, %[[LOAD_2]], {{.*}} : i1, vector<32xf16>
235+
236+
// CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
237+
// CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
238+
// CHECK: %[[TOP_LEFT_MASK_3:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_96]] : i1 to i8
239+
// CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_3]], %[[CST0_1]])
240+
// CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
241+
// CHECK: %[[BASE_Y_3:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
242+
// CHECK: %[[LOAD_3:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_3]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
243+
// CHECK: llvm.select {{.*}}, %[[LOAD_3]], {{.*}} : i1, vector<32xf16>
244+
%0 = tt.load %arg0, %a_mask, %a_other {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma>
245+
246+
tt.return
247+
}
248+
}

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 42 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,11 +1208,47 @@ struct LoadOpToBlockIOConversion
12081208
<< ", loadOuter:" << loadOuter << " offset: [" << offsetM
12091209
<< ", " << offsetN << "]");
12101210

1211-
Value pred =
1212-
masks.size() ? masks[{offsetM, offsetN}] : b.int_val(1, 1);
1213-
pred = targetInfo.shuffleIdx(rewriter, loc, pred, 0);
1214-
Value other_ = b.undef(load2DGenXType);
1211+
Value offsetY = b.i32_val(0);
1212+
Value pred;
1213+
if (llMask) {
1214+
assert(masks.size() && "Invalid size of the masks.");
1215+
pred = targetInfo.shuffleIdx(rewriter, loc,
1216+
masks[{offsetM, offsetN}], 0);
1217+
// We leverage the GPU block I/O hardware out-of-bound protection
1218+
// feature by setting the offset to an invalid value when 'pred'
1219+
// is false (the HW will not read out-of-bounds values). Later on,
1220+
// after issuing the 2d block read operation, we will select the
1221+
// result of the load only if the mask evaluates to true; otherwise
1222+
// we will use 'other'.
1223+
offsetY = b.select(pred, offsetY, baseHeight);
1224+
}
1225+
1226+
// Use the top-left address of the block to load the data.
1227+
Value addrElem =
1228+
b.bitcast(ptrs[{offsetM, offsetN}], ptr_ty(ctx, 1 /*global*/));
1229+
addrElem = targetInfo.shuffleIdx(rewriter, loc, addrElem, 0);
1230+
1231+
Value ret = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
1232+
loc, load2DGenXType,
1233+
/*ptr*/ addrElem,
1234+
/*base_width*/ baseWidth,
1235+
/*base_height*/ baseHeight,
1236+
/*base_pitch*/ pitch,
1237+
/*x*/ b.i32_val(0),
1238+
/*y*/ offsetY,
1239+
/*elem_size_in_bits*/ elemSizeInBits,
1240+
/*tile_width*/ tileWidth,
1241+
/*tile_height*/ tileHeight,
1242+
/*v_blocks*/ vBlocks,
1243+
/*transpose*/ false,
1244+
/*vnni_transform*/
1245+
(usePackedType && opIdx == DpasEncodingAttr::OpIdx::OperandB &&
1246+
!isTransposeRequired && originalElemBits != 32));
1247+
12151248
if (others.size()) {
1249+
assert(masks.size() == others.size() &&
1250+
"The mask value has to be provided when "
1251+
"the other value is provided.");
12161252
VectorType vecTy =
12171253
vec_ty(eltTy, numValuesPerLoad * packedElemsNum);
12181254

@@ -1241,49 +1277,8 @@ struct LoadOpToBlockIOConversion
12411277
}
12421278
}
12431279
}
1244-
1245-
other_ = b.bitcast(v, load2DGenXType);
1246-
1247-
} else {
1248-
other_ = rewriter.create<LLVM::ConstantOp>(
1249-
loc, load2DGenXType, rewriter.getZeroAttr(load2DGenXType));
1250-
}
1251-
1252-
auto createLoadInstruction = [&]() -> SmallVector<Value, 1> {
1253-
// Use the top-left address of the block to load the data.
1254-
Value addrElem = b.bitcast(ptrs[{offsetM, offsetN}],
1255-
ptr_ty(ctx, 1 /*global*/));
1256-
addrElem = targetInfo.shuffleIdx(rewriter, loc, addrElem, 0);
1257-
1258-
auto load2dOp = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
1259-
loc, load2DGenXType,
1260-
/*ptr*/ addrElem,
1261-
/*base_width*/ baseWidth,
1262-
/*base_height*/ baseHeight,
1263-
/*base_pitch*/ pitch,
1264-
/*x*/ b.i32_val(0),
1265-
/*y*/ b.i32_val(0),
1266-
/*elem_size_in_bits*/ elemSizeInBits,
1267-
/*tile_width*/ tileWidth,
1268-
/*tile_height*/ tileHeight,
1269-
/*v_blocks*/ vBlocks,
1270-
/*transpose*/ false,
1271-
/*vnni_transform*/
1272-
(usePackedType &&
1273-
opIdx == DpasEncodingAttr::OpIdx::OperandB &&
1274-
!isTransposeRequired && originalElemBits != 32));
1275-
return {load2dOp};
1276-
};
1277-
1278-
Value ret;
1279-
// Create a predicated load operation.
1280-
if (llMask) {
1281-
Block &endBlock = LLVM::intel::createPredicatedBlock(
1282-
rewriter, loc, pred, SmallVector<Value, 1>{other_},
1283-
createLoadInstruction);
1284-
ret = *endBlock.args_begin();
1285-
} else {
1286-
ret = createLoadInstruction()[0];
1280+
Value others = b.bitcast(v, load2DGenXType);
1281+
ret = b.select(pred, ret, others);
12871282
}
12881283

12891284
unsigned numOperandsM = opIdx != DpasEncodingAttr::OpIdx::OperandB

0 commit comments

Comments
 (0)