Skip to content

Commit 7f33969

Browse files
committed
[Helion]: Remove boundaryCheks on load operation using a block ptr/tensor descriptor
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 8e84b4e commit 7f33969

File tree

6 files changed

+231
-0
lines changed

6 files changed

+231
-0
lines changed

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
9595
mlir::test::registerTestTritonAMDGPURangeAnalysis();
9696
mlir::triton::registerConvertTritonToTritonGPUPass();
9797
mlir::triton::intel::registerTritonIntelFuseReshape();
98+
mlir::triton::intel::registerTritonIntelRemoveBoundaryChecks();
9899
mlir::triton::intel::registerTritonIntelRemoveMasks();
99100
mlir::triton::intel::registerTritonIntelTensorDescToBlockPointer();
100101
mlir::triton::registerRelayoutTritonGPUPass();

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def make_ttir(mod, metadata, opt):
200200
passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
201201
passes.common.add_cse(pm)
202202
passes.common.add_licm(pm)
203+
intel.passes.ttir.add_remove_boundary_checks(pm)
203204
intel.passes.ttir.add_remove_masks(pm)
204205
intel.passes.ttir.add_fuse_reshape(pm)
205206
passes.common.add_canonicalizer(pm)

third_party/intel/include/Dialect/Triton/Transforms/Passes.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,41 @@ def TritonIntelFuseReshape
7070
];
7171
}
7272

73+
def TritonIntelRemoveBoundaryChecks
74+
: Pass<"triton-intel-remove-boundary-checks", "mlir::ModuleOp"> {
75+
let summary = "Remove unnecessary boundary checks from load operations (block pointers only)";
76+
77+
let description = [{
78+
This pass attempts to remove boundary checks that aren't necessary in a tt.load operation.
79+
For example, given:
80+
%lb = arith.bitcast %c0_i32 : i32 to i32
81+
%ub = arith.bitcast %c1024_i32 : i32 to i32
82+
%st = arith.bitcast %c64_i32 : i32 to i32
83+
scf.for %i = %lb to %ub step %st : i32 {
84+
%s0 = arith.constant 512 : i64
85+
%s1 = arith.constant 64 : i64
86+
%s2 = arith.constant 1024 : i64
87+
%a = arith.constant 65536 : i64
88+
%b = arith.constant 1 : i64
89+
%b = arith.constant 64 : i64
90+
%y = arith.constant 0 : i32
91+
%ptr = tt.make_tensor_ptr %base, [%s0, %s1, %s2], [%a, %b, %c], [%x, %y, %iv]
92+
{order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
93+
%load = tt.load %ptr {boundaryCheck = array<i32: 2>} : !tt.ptr<tensor<1x512x64xf16>>
94+
...
95+
// here %ptr is never updated.
96+
}
97+
98+
The transformation would drop the boundary check on the load operation because:
99+
- `%ptr` is never advanced in the loop
100+
- `%iv` has values [0, 64, 128, ..., 960], max(%iv) = 960
101+
- `%s2` is qual to 1014
102+
- the boundary check expression `%iv` < `%s2` is always true
103+
}];
104+
105+
let dependentDialects = [
106+
"mlir::triton::TritonDialect"
107+
];
108+
}
109+
73110
#endif // TRITON_DIALECT_TRITON_INTEL_TRANSFORMS_PASSES

third_party/intel/lib/Dialect/Triton/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_triton_library(TritonIntelTransforms
22
FuseReshape.cpp
3+
RemoveBoundaryChecks.cpp
34
RemoveMasks.cpp
45
TensorDescToBlockPointer.cpp
56

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
2+
#include "intel/include/Dialect/Triton/Transforms/Passes.h"
3+
#include "intel/include/Utils/Utility.h"
4+
#include "mlir/Dialect/Arith/IR/Arith.h"
5+
#include "mlir/Dialect/SCF/IR/SCF.h"
6+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
7+
#include "mlir/IR/Verifier.h"
8+
#include "mlir/Interfaces/InferIntRangeInterface.h"
9+
#include "mlir/Support/WalkResult.h"
10+
#include "triton/Dialect/Triton/IR/Dialect.h"
11+
#include "llvm/ADT/APInt.h"
12+
#include "llvm/Support/Debug.h"
13+
#include "llvm/Support/raw_ostream.h"
14+
#include <cmath>
15+
#include <optional>
16+
17+
#define DEBUG_TYPE "triton-intel-remove-boundary-checks"
18+
19+
using namespace mlir;
20+
namespace tt = mlir::triton;
21+
22+
namespace mlir::triton::intel {
23+
#define GEN_PASS_DEF_TRITONINTELREMOVEBOUNDARYCHECKS
24+
#include "intel/include/Dialect/Triton/Transforms/Passes.h.inc"
25+
} // namespace mlir::triton::intel
26+
27+
namespace {
28+
class BoundaryChecksRemover {
29+
public:
30+
void run(ModuleOp moduleOp) {
31+
moduleOp.walk([&](tt::LoadOp loadOp) {
32+
if (!isCandidate(loadOp))
33+
return WalkResult::skip();
34+
35+
tt::MakeTensorPtrOp makeTensorPtrOp =
36+
*tt::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr());
37+
LLVM_DEBUG(llvm::dbgs()
38+
<< "Analyzing boundaryCheck for: " << loadOp << "\n");
39+
40+
SmallVector<int> newBoundaryCheck;
41+
for (int boundIdx : loadOp.getBoundaryCheck()) {
42+
ArrayRef<int> order = makeTensorPtrOp.getOrder();
43+
int idx = order.size() - order[boundIdx] - 1;
44+
Value offset = makeTensorPtrOp.getOffsets()[idx];
45+
Value shape = makeTensorPtrOp.getShape()[idx];
46+
std::optional<int64_t> offsetVal = getConstantIntValue(offset),
47+
shapeVal = getConstantIntValue(shape);
48+
49+
// If the shape is not known at compile time we cannot determine whether
50+
// the bound check is unnecessary.
51+
if (!shapeVal) {
52+
LLVM_DEBUG(llvm::dbgs().indent(2)
53+
<< "Check at index " << boundIdx << " is necessary\n");
54+
newBoundaryCheck.push_back(boundIdx);
55+
continue;
56+
}
57+
58+
// Case 1: offset and shape are constant.
59+
if (offsetVal && *offsetVal < *shapeVal) {
60+
LLVM_DEBUG(llvm::dbgs().indent(2)
61+
<< "Check at index " << boundIdx << " is unnecessary\n");
62+
continue;
63+
}
64+
65+
// Case 2: analyze boundary check in loops.
66+
if (auto forOp = makeTensorPtrOp->getParentOfType<scf::ForOp>()) {
67+
assert(forOp.getSingleInductionVar() && "Single IV expected");
68+
Value iv = *forOp.getSingleInductionVar();
69+
if (offset != iv) {
70+
LLVM_DEBUG(llvm::dbgs().indent(2)
71+
<< "Check at index " << boundIdx << " is necessary\n");
72+
newBoundaryCheck.push_back(boundIdx);
73+
continue;
74+
}
75+
76+
OpFoldResult lb = *forOp.getSingleLowerBound();
77+
OpFoldResult ub = *forOp.getSingleUpperBound();
78+
OpFoldResult step = *forOp.getSingleStep();
79+
80+
auto computeLoopIVRange =
81+
[&](OpFoldResult lb, OpFoldResult ub,
82+
OpFoldResult step) -> std::optional<ConstantIntRanges> {
83+
auto getBoundValue =
84+
[](OpFoldResult bound) -> std::optional<int64_t> {
85+
if (std::optional<int64_t> opVal = getConstantIntValue(bound))
86+
return *opVal;
87+
88+
Value val = tt::intel::getFinalValue(cast<Value>(bound));
89+
if (auto cst = dyn_cast<arith::BitcastOp>(val.getDefiningOp()))
90+
val = cst.getIn();
91+
92+
return getConstantIntValue(getAsOpFoldResult(val));
93+
};
94+
95+
auto areLoopBoundKnown = [&](OpFoldResult lb, OpFoldResult ub,
96+
OpFoldResult step) {
97+
return (getBoundValue(lb) && getBoundValue(ub) &&
98+
getBoundValue(step));
99+
};
100+
101+
if (!areLoopBoundKnown(lb, ub, step))
102+
return std::nullopt;
103+
104+
int64_t lbVal = *getBoundValue(lb);
105+
int64_t ubVal = *getBoundValue(ub);
106+
int64_t stepVal = *getBoundValue(step);
107+
int64_t lastIVVal =
108+
lbVal + ((ubVal - lbVal - 1) / stepVal) * stepVal;
109+
llvm::APInt start(64, lbVal, true);
110+
llvm::APInt end(64, lastIVVal, true);
111+
112+
return ConstantIntRanges::range(start, end, true);
113+
};
114+
115+
std::optional<ConstantIntRanges> optRange =
116+
computeLoopIVRange(lb, ub, step);
117+
if (!optRange) {
118+
LLVM_DEBUG(llvm::dbgs().indent(2)
119+
<< "Check at index " << boundIdx << " is necessary\n");
120+
newBoundaryCheck.push_back(boundIdx);
121+
continue;
122+
}
123+
124+
// Compare the max value of the loop IV to the offset.
125+
APInt max = (*optRange).smax();
126+
if (max.getSExtValue() < shapeVal) {
127+
LLVM_DEBUG(llvm::dbgs().indent(2)
128+
<< "Check at index " << boundIdx << " is unnecessary\n");
129+
continue;
130+
}
131+
}
132+
133+
LLVM_DEBUG(llvm::dbgs().indent(2)
134+
<< "Check at index " << boundIdx << " is necessary\n");
135+
newBoundaryCheck.push_back(boundIdx);
136+
}
137+
138+
if (newBoundaryCheck.size() != loadOp.getBoundaryCheck().size()) {
139+
loadOp.setBoundaryCheck(newBoundaryCheck);
140+
LLVM_DEBUG(llvm::dbgs().indent(2)
141+
<< "Rewritten load is: " << loadOp << "\n");
142+
}
143+
144+
return WalkResult::advance();
145+
});
146+
}
147+
148+
private:
149+
// A candidate load operation is one that:
150+
// - has the boundary check attribute
151+
// - uses a block pointer defined by a `make_tensor_ptr` that is not
152+
// advanced
153+
bool isCandidate(tt::LoadOp loadOp) const {
154+
assert(loadOp && "Expecting a valid load operation");
155+
156+
ArrayRef<int> boundaryCheck = loadOp.getBoundaryCheck();
157+
if (boundaryCheck.empty())
158+
return false;
159+
160+
Type ptrType = loadOp.getPtr().getType();
161+
if (!tt::isTensorPointerType(ptrType))
162+
return false;
163+
164+
std::optional<tt::MakeTensorPtrOp> makeTensorPtrOp =
165+
tt::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr());
166+
if (!makeTensorPtrOp)
167+
return false;
168+
169+
if (llvm::any_of((*makeTensorPtrOp)->getUsers(),
170+
[](Operation *user) { return isa<tt::AdvanceOp>(user); }))
171+
return false;
172+
173+
return true;
174+
}
175+
};
176+
177+
} // namespace
178+
179+
struct TritonIntelRemoveBoundaryChecks
180+
: tt::intel::impl::TritonIntelRemoveBoundaryChecksBase<
181+
TritonIntelRemoveBoundaryChecks> {
182+
public:
183+
void runOnOperation() final {
184+
ModuleOp moduleOp = getOperation();
185+
BoundaryChecksRemover remover;
186+
remover.run(moduleOp);
187+
assert(succeeded(verify(moduleOp)) && "Module verification failed");
188+
}
189+
};

third_party/intel/triton_xpu.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ static uint32_t findKernels(llvm::Module &M,
5858
void init_triton_intel_passes_ttir(py::module &&m) {
5959
ADD_PASS_WRAPPER_0("add_convert_tdesc_to_block_pointer",
6060
intel::createTritonIntelTensorDescToBlockPointer);
61+
ADD_PASS_WRAPPER_0("add_remove_boundary_checks",
62+
intel::createTritonIntelRemoveBoundaryChecks);
6163
ADD_PASS_WRAPPER_0("add_remove_masks", intel::createTritonIntelRemoveMasks);
6264
ADD_PASS_WRAPPER_0("add_fuse_reshape", intel::createTritonIntelFuseReshape);
6365
}

0 commit comments

Comments
 (0)