
Commit 20be069

[intel] improve pitch and width constexpr folding (#5489)
This PR improves constant expression folding for pitch and width parameters in Intel GPU block I/O operations. The changes introduce a more robust constant evaluation mechanism that handles multiple levels of type casts and operation folding, addressing issue #5338.
1 parent c860a38 commit 20be069
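The core of the change is easiest to see in isolation. Below is a hedged, dependency-free C++ model of the new helper; Expr, Kind, skipCast, foldOnce, and getFoldedConstant are illustrative stand-ins rather than MLIR or Triton APIs. It shows the depth-limited loop that alternately peels a cast and folds the defining operation until a constant emerges; the real implementation lands in third_party/intel/lib/Utils/Utility.cpp below.

#include <cstdint>
#include <memory>
#include <optional>

// Stand-in "IR": a value is either a constant, a cast of another value, or a
// multiply of two values. This models only enough structure for the sketch.
enum class Kind { Constant, Cast, Mul };

struct Expr {
  Kind kind;
  int64_t value = 0;              // meaningful only when kind == Constant
  std::shared_ptr<Expr> lhs, rhs; // operand(s) for Cast and Mul
};
using ExprPtr = std::shared_ptr<Expr>;

// Mirrors skipCasts() from the patch: peel one cast-like wrapper, if any.
ExprPtr skipCast(const ExprPtr &e) {
  return e->kind == Kind::Cast ? e->lhs : e;
}

// Mirrors foldValue(): fold one operation level when its operands are constant.
ExprPtr foldOnce(const ExprPtr &e) {
  if (e->kind == Kind::Mul && e->lhs->kind == Kind::Constant &&
      e->rhs->kind == Kind::Constant) {
    auto c = std::make_shared<Expr>();
    c->kind = Kind::Constant;
    c->value = e->lhs->value * e->rhs->value;
    return c;
  }
  return e;
}

// Mirrors getFoldedConstantValue(Value v, int depth = 8): iterate until a
// constant is reached, no progress is made, or the depth budget runs out.
std::optional<int64_t> getFoldedConstant(ExprPtr e, int depth = 8) {
  for (int i = 0; i < depth; ++i) {
    if (e->kind == Kind::Constant)
      return e->value;
    ExprPtr next = foldOnce(skipCast(e));
    if (next == e)
      break; // no progress; give up, as the real helper does
    e = next;
  }
  return std::nullopt;
}

With this model, a value shaped like cast(cast(mul(const, const))) resolves after a few iterations, whereas the previous helper folded only a single operation and relied on the call site to skip truncations by hand.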

File tree

- test/TritonIntelGPU/blockptr_load.mlir
- third_party/intel/include/Utils/Utility.h
- third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
- third_party/intel/lib/Utils/Utility.cpp

4 files changed: +56 -31 lines

test/TritonIntelGPU/blockptr_load.mlir

Lines changed: 2 additions & 2 deletions

@@ -136,7 +136,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
 // CHECK: %[[OFFSET_0:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[OFFSET_1:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[WIDTH_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
-// CHECK: %[[HEIGHT_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+// CHECK: %[[HEIGHT_i64:.*]] = llvm.extractvalue %[[VAL_11]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[ROW_STRIDE_i64:.*]] = llvm.extractvalue %[[VAL_12]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[COL_STRIDE_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
@@ -199,7 +199,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
 // CHECK: %[[OFFSET_0:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[OFFSET_1:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[WIDTH_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
-// CHECK: %[[HEIGHT_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
+// CHECK: %[[HEIGHT_i64:.*]] = llvm.extractvalue %[[VAL_10]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[ROW_STRIDE_i64:.*]] = llvm.extractvalue %[[VAL_11]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[COL_STRIDE_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
 // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>

third_party/intel/include/Utils/Utility.h

Lines changed: 2 additions & 2 deletions

@@ -19,9 +19,9 @@ Value findOrCreateIntConstant(Location loc, int val, unsigned bitWidth,
 std::optional<mlir::triton::MakeTensorPtrOp>
 findDefiningMakeTensorPtrOp(Value val);

-// This function folds the `op` operation and returns the constant value if it
+// This function folds the `v` value and returns the constant value if it
 // has successfully folded to a constant. Otherwise, it returns `std::nullopt`.
-std::optional<int64_t> getFoldedConstantValue(Operation *op);
+std::optional<int64_t> getFoldedConstantValue(Value v, int depth = 8);

 // Return true if the `val` value is a constant containing a value equal to
 // expected.
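For context, here is a minimal caller-side sketch of the revised interface. The wrapper checkFoldsToAtLeast and its arguments are hypothetical; only getFoldedConstantValue, its Value parameter, and the default depth of 8 come from the header above.

#include "intel/include/Utils/Utility.h"

#include <cstdint>
#include <optional>

// Hypothetical helper: true when `v` folds to a constant >= `minVal`;
// a value that does not fold to a constant yields false.
static bool checkFoldsToAtLeast(mlir::Value v, int64_t minVal) {
  std::optional<int64_t> c =
      mlir::triton::intel::getFoldedConstantValue(v); // default depth = 8
  return c && *c >= minVal;
}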

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 9 additions & 12 deletions

@@ -1629,23 +1629,20 @@ struct LoadOpToBlockIOConversion
       std::swap(baseWidth, baseHeight);
     }
     // HW requires the pitch to be at least 64 bytes.
-    std::function<Value(Value)> skipTrunc = [&](Value v) {
-      if (dyn_cast_or_null<LLVM::TruncOp>(v.getDefiningOp()))
-        return skipTrunc(v.getDefiningOp()->getOperand(0));
-      return v;
-    };
-    if (Operation *op = skipTrunc(pitch).getDefiningOp()) {
-      std::optional<int64_t> pitchConst =
-          mlir::triton::intel::getFoldedConstantValue(op);
-      if (pitchConst.has_value()) {
-        if ((*pitchConst * elemSizeInBits / 8) < 64)
-          return failure();
-      }
+    if (auto pitchConst = mlir::triton::intel::getFoldedConstantValue(pitch)) {
+      if ((*pitchConst * elemSizeInBits / 8) < 64)
+        return failure();
     }

     baseWidth = b.trunc(i32_ty, baseWidth);
     baseHeight = b.trunc(i32_ty, baseHeight);

+    if (auto widthConst =
+            mlir::triton::intel::getFoldedConstantValue(baseWidth)) {
+      if ((*widthConst * elemSizeInBits / 8) < 64)
+        return failure();
+    }
+
     const unsigned originalElemBits = elemSizeInBits;
     if (isTransposeRequired) {
       // adjust the block io parameter to align HW's limitations on
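As a worked example of the 64-byte bound enforced above (the formula and threshold come straight from this hunk): with 16-bit elements, elemSizeInBits = 16, so a folded pitch or width of 32 elements is 32 * 16 / 8 = 64 bytes and passes, while 16 elements is only 32 bytes and makes the pattern return failure(). When the pitch or width does not fold to a constant, the corresponding check is simply skipped.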

third_party/intel/lib/Utils/Utility.cpp

Lines changed: 43 additions & 15 deletions

@@ -1,5 +1,6 @@
 #include "intel/include/Utils/Utility.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
@@ -104,28 +105,55 @@ std::optional<tt::MakeTensorPtrOp> findDefiningMakeTensorPtrOp(Value val) {
   return std::nullopt;
 }

-std::optional<int64_t> getFoldedConstantValue(Operation *op) {
-  SmallVector<OpFoldResult> results;
-  if (failed(op->fold(results)))
-    return std::nullopt;
+static Value skipCasts(Value v) {
+  Operation *def = v.getDefiningOp();
+  if (def &&
+      isa<LLVM::TruncOp, LLVM::SExtOp, LLVM::ZExtOp, LLVM::BitcastOp>(def))
+    return def->getOperand(0);
+  return v;
+}
+
+static Value foldValue(Value v) {
+  if (Operation *def = v.getDefiningOp()) {
+    SmallVector<OpFoldResult> results;
+
+    if (failed(def->fold(results)))
+      return v;

-  // If fold succeeded but `results` is empty, we give a second try, after the
-  // operands have been switched during the first call to `fold()`.
-  if (results.empty()) {
-    if (failed(op->fold(results)))
-      return std::nullopt;
+    // If fold succeeded but `results` is empty, we give a second try, after the
+    // operands have been switched during the first call to `fold()`.
+    if (results.empty()) {
+      if (failed(def->fold(results)))
+        return v;
+    }
+
+    if (results.size() == 1) {
+      if (auto val = dyn_cast_or_null<Value>(results[0]))
+        return val;
+    }
   }
+  return v;
+}

-  if (results.size() != 1)
-    return std::nullopt;
+std::optional<int64_t> getFoldedConstantValue(Value v, int depth) {
+  for (int i = 0; i < depth; ++i) {
+    if (auto res = getConstantIntValue(v))
+      return res;
+
+    Value newV = skipCasts(v);
+    newV = foldValue(newV);
+
+    if (newV == v)
+      break;

-  return getConstantIntValue(results[0]);
+    v = newV;
+  }
+
+  return std::nullopt;
 }

 bool isConstant(Value val, int64_t expected) {
-  if (auto defOp = val.getDefiningOp())
-    return (getFoldedConstantValue(defOp) == expected);
-  return false;
+  return (getFoldedConstantValue(val) == expected);
 }

 Value getFinalValue(Value value) {
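Two behavioral notes on the rewritten helper, both visible in the hunk above: the walk is bounded by depth (default 8 per the header) and also exits early once an iteration makes no progress (newV == v), returning std::nullopt in either case; and the simplified isConstant keeps its old behavior for values with no defining op, since the loop then stops on the first iteration and the comparison against expected fails.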
