Skip to content

Commit 5cd910c

Browse files
Only create a predicated block when the load has a mask (#4535)
CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15744389236
16% improvement on GEMM with tensors of pointers.
![Screenshot 2025-06-18 204424](https://github.com/user-attachments/assets/84a50381-d447-4464-aea1-50d3db6f78b4)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 7c4300b commit 5cd910c

File tree

2 files changed

+36
-28
lines changed

2 files changed

+36
-28
lines changed

test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32} {
145145
%arg1: tensor<256x64x!tt.ptr<f16>, #mma_1>,
146146
%arg2: tensor<128x64x!tt.ptr<f16>, #mma_2>,
147147
%arg3: tensor<256x64x!tt.ptr<f16>, #mma_2>) {
148+
// CHECK-NOT: llvm.cond_br
148149
// CHECK-COUNT-4: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
149150
%0 = tt.load %arg0 {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma>
150151

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,35 +1240,42 @@ struct LoadOpToBlockIOConversion
12401240
loc, load2DGenXType, rewriter.getZeroAttr(load2DGenXType));
12411241
}
12421242

1243+
auto createLoadInstruction = [&]() -> SmallVector<Value, 1> {
1244+
// Use the top-left address of the block to load the data.
1245+
Value addrElem = b.bitcast(ptrs[{offsetM, offsetN}],
1246+
ptr_ty(ctx, 1 /*global*/));
1247+
addrElem = targetInfo.shuffleIdx(rewriter, loc, addrElem, 0);
1248+
1249+
auto load2dOp = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
1250+
loc, load2DGenXType,
1251+
/*ptr*/ addrElem,
1252+
/*base_width*/ baseWidth,
1253+
/*base_height*/ baseHeight,
1254+
/*base_pitch*/ pitch,
1255+
/*x*/ b.i32_val(0),
1256+
/*y*/ b.i32_val(0),
1257+
/*elem_size_in_bits*/ elemSizeInBits,
1258+
/*tile_width*/ tileWidth,
1259+
/*tile_height*/ tileHeight,
1260+
/*v_blocks*/ vBlocks,
1261+
/*transpose*/ false,
1262+
/*vnni_transform*/
1263+
(usePackedType &&
1264+
opIdx == DpasEncodingAttr::OpIdx::OperandB &&
1265+
!isTransposeRequired && originalElemBits != 32));
1266+
return {load2dOp};
1267+
};
1268+
1269+
Value ret;
12431270
// Create a predicated load operation.
1244-
Block &endBlock = LLVM::intel::createPredicatedBlock(
1245-
rewriter, loc, pred, SmallVector<Value, 1>{other_}, [&]() {
1246-
// Use the top-left address of the block to load the data.
1247-
Value addrElem = b.bitcast(ptrs[{offsetM, offsetN}],
1248-
ptr_ty(ctx, 1 /*global*/));
1249-
addrElem = targetInfo.shuffleIdx(rewriter, loc, addrElem, 0);
1250-
1251-
auto load2dOp =
1252-
rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
1253-
loc, load2DGenXType,
1254-
/*ptr*/ addrElem,
1255-
/*base_width*/ baseWidth,
1256-
/*base_height*/ baseHeight,
1257-
/*base_pitch*/ pitch,
1258-
/*x*/ b.i32_val(0),
1259-
/*y*/ b.i32_val(0),
1260-
/*elem_size_in_bits*/ elemSizeInBits,
1261-
/*tile_width*/ tileWidth,
1262-
/*tile_height*/ tileHeight,
1263-
/*v_blocks*/ vBlocks,
1264-
/*transpose*/ false,
1265-
/*vnni_transform*/
1266-
(usePackedType &&
1267-
opIdx == DpasEncodingAttr::OpIdx::OperandB &&
1268-
!isTransposeRequired && originalElemBits != 32));
1269-
return SmallVector<Value, 1>{load2dOp};
1270-
});
1271-
Value ret = *endBlock.args_begin();
1271+
if (llMask) {
1272+
Block &endBlock = LLVM::intel::createPredicatedBlock(
1273+
rewriter, loc, pred, SmallVector<Value, 1>{other_},
1274+
createLoadInstruction);
1275+
ret = *endBlock.args_begin();
1276+
} else {
1277+
ret = createLoadInstruction()[0];
1278+
}
12721279

12731280
unsigned numOperandsM = opIdx != DpasEncodingAttr::OpIdx::OperandB
12741281
? numOperandsOuterDimPerLoad

0 commit comments

Comments
 (0)