Skip to content

Commit 8e26be5

Browse files
[LoadOpToBlockIOConversion] Improve codegen for other (#5141)
Optimizes block load lowering when the optional "other" value is a non-zero constant splat by materializing a single repeated constant instead of unpacking LLVM elements, and updates tests accordingly by removing prior select expectations. Key changes focus on specializing constant handling and pruning test expectations tied to the old expansion path. --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent bd6de9e commit 8e26be5

File tree

2 files changed

+23
-18
lines changed

2 files changed

+23
-18
lines changed

test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,6 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
266266
// CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
267267
// CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
268268
// CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
269-
// CHECK: llvm.select {{.*}}, %[[LOAD_0]], {{.*}} : i1, vector<32xf16>
270269

271270
// CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
272271
// CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -275,7 +274,6 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
275274
// CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
276275
// CHECK: %[[BASE_Y_1:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
277276
// CHECK: %[[LOAD_1:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_1]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
278-
// CHECK: llvm.select {{.*}}, %[[LOAD_1]], {{.*}} : i1, vector<32xf16>
279277

280278
// CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
281279
// CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -284,7 +282,6 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
284282
// CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
285283
// CHECK: %[[BASE_Y_2:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
286284
// CHECK: %[[LOAD_2:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_2]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
287-
// CHECK: llvm.select {{.*}}, %[[LOAD_2]], {{.*}} : i1, vector<32xf16>
288285

289286
// CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
290287
// CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -293,7 +290,6 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
293290
// CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
294291
// CHECK: %[[BASE_Y_3:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32
295292
// CHECK: %[[LOAD_3:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_3]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
296-
// CHECK: llvm.select {{.*}}, %[[LOAD_3]], {{.*}} : i1, vector<32xf16>
297293
%11 = tt.load %10, %a_mask, %a_other {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma>
298294

299295
tt.return

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,7 +1863,6 @@ struct LoadOpToBlockIOConversion
18631863
return rewriteTensorPointerLoad(op, adaptor, rewriter);
18641864

18651865
Value mask = op.getMask();
1866-
Value other = op.getOther();
18671866
Type resultType = op.getType();
18681867
auto tensorType = cast<RankedTensorType>(resultType);
18691868

@@ -2056,16 +2055,12 @@ struct LoadOpToBlockIOConversion
20562055
unsigned instWidth = dpasInstShape[threadOrder[rank - 2]];
20572056
unsigned instHeight = dpasInstShape[threadOrder[rank - 1]];
20582057

2059-
bool otherIsSplatConstInt = false;
2060-
int64_t splatVal = 0;
2061-
20622058
std::map<SmallVector<unsigned>, Value> ptrs;
20632059
std::map<SmallVector<unsigned>, Value> masks;
20642060
std::map<SmallVector<unsigned>, Value> others;
20652061

20662062
Value llPtr = adaptor.getPtr();
20672063
Value llMask = adaptor.getMask();
2068-
Value llOther = adaptor.getOther();
20692064

20702065
SmallVector<Value> ptrElems, maskElems, otherElems;
20712066
// Get the LLVM values for pointers
@@ -2101,16 +2096,30 @@ struct LoadOpToBlockIOConversion
21012096
return failure();
21022097

21032098
// Get the LLVM values for `other`
2099+
Value other = op.getOther();
2100+
Value llOther = adaptor.getOther();
21042101
DenseElementsAttr constAttr;
2105-
if (other && isa<IntegerType>(eltTy) &&
2106-
matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat() &&
2107-
isa<IntegerType>(constAttr.getElementType())) {
2108-
otherIsSplatConstInt = true;
2109-
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
2110-
}
2111-
if (other) {
2112-
otherElems = unpackLLElements(loc, llOther, rewriter);
2113-
}
2102+
if (other)
2103+
if (matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat()) {
2104+
Type elemTy = constAttr.getElementType();
2105+
auto handleSplatValue = [&](auto splatVal) {
2106+
if (!splatVal.isZero()) {
2107+
otherElems = SmallVector<Value>(
2108+
numElems,
2109+
rewriter.create<LLVM::ConstantOp>(loc, elemTy, splatVal));
2110+
}
2111+
};
2112+
2113+
TypeSwitch<mlir::Type>(elemTy)
2114+
.Case<FloatType>([&](FloatType) {
2115+
handleSplatValue(constAttr.getSplatValue<APFloat>());
2116+
})
2117+
.Case<IntegerType>([&](IntegerType) {
2118+
handleSplatValue(constAttr.getSplatValue<APInt>());
2119+
});
2120+
} else {
2121+
otherElems = unpackLLElements(loc, llOther, rewriter);
2122+
}
21142123

21152124
// re-arrange the ptrs and masks to for large 2D block IO.
21162125
// Layout is unrelated to the scalar type.

0 commit comments

Comments
 (0)