Skip to content

Commit 8e5c989

Browse files
authored
[Helion]: Remove boundaryChecks on load operation using a block ptr/te… (#5363)
We have implemented feature #5272 to collapse 3-dim loads on block ptrs into 2-dim loads when the tensor loaded has outermost dimension equal to 1. This feature complements feature #5272. The goal here is to remove boundaryCheck indexes from load operations when they are unnecessary. --------- Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 9e23713 commit 8e5c989

File tree

7 files changed

+290
-0
lines changed

7 files changed

+290
-0
lines changed

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
9595
mlir::test::registerTestTritonAMDGPURangeAnalysis();
9696
mlir::triton::registerConvertTritonToTritonGPUPass();
9797
mlir::triton::intel::registerTritonIntelFuseReshape();
98+
mlir::triton::intel::registerTritonIntelRemoveBoundaryChecks();
9899
mlir::triton::intel::registerTritonIntelRemoveMasks();
99100
mlir::triton::intel::registerTritonIntelTensorDescToBlockPointer();
100101
mlir::triton::registerRelayoutTritonGPUPass();
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// RUN: triton-opt %s -split-input-file -triton-intel-remove-boundary-checks | FileCheck %s
2+
3+
module {
4+
tt.func public @simple_load(%load_ptr: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %store_ptr: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
5+
%c1_i64 = arith.constant 1 : i64
6+
%c64_i64 = arith.constant 64 : i64
7+
%c512_i64 = arith.constant 512 : i64
8+
%c1024_i64 = arith.constant 1024 : i64
9+
%c0_i32 = arith.constant 0 : i32
10+
%x = arith.constant 10 : i32
11+
%in = tt.make_tensor_ptr %load_ptr, [%c1_i64, %c64_i64, %c1024_i64], [%c512_i64, %c64_i64, %c1_i64], [%c0_i32, %c0_i32, %x] {order = array<i32: 2, 1, 0>} : <tensor<1x64x64xf16>>
12+
// boundaryCheck is unnecessary because %x + loadResType.shape[2] - 1 = 10 + 64 - 1 = 73 < 1024
13+
%load = tt.load %in {boundaryCheck = array<i32: 2>} : !tt.ptr<tensor<1x64x64xf16>>
14+
tt.return
15+
}
16+
// CHECK-LABEL: simple_load
17+
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr
18+
// CHECK: tt.load [[PTR]] : !tt.ptr<tensor<1x64x64xf16>>
19+
}
20+
21+
// -----
22+
23+
module {
24+
tt.func public @load_in_for_loop(%load_ptr0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %load_ptr1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %store_ptr: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
25+
%c0_i32 = arith.constant 0 : i32
26+
%c1_i32 = arith.constant 1 : i32
27+
%c20_i32 = arith.constant 20 : i32
28+
%c64_i32 = arith.constant 64 : i32
29+
%c1024_i32 = arith.constant 1024 : i32
30+
scf.for %x = %c0_i32 to %c20_i32 step %c1_i32 : i32 {
31+
%pid = tt.get_program_id x : i32
32+
%c0_i64 = arith.constant 0 : i64
33+
%c1_i64 = arith.constant 1 : i64
34+
%c512_i64 = arith.constant 512 : i64
35+
%c1024_i64 = arith.constant 1024 : i64
36+
%c64_i64 = arith.constant 64 : i64
37+
%c65536_i64 = arith.constant 65536 : i64
38+
%ptr0 = tt.make_tensor_ptr %load_ptr0, [%c512_i64, %c1024_i64, %c64_i64], [%c65536_i64, %c64_i64, %c1_i64], [%x, %pid, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
39+
%load0 = tt.load %ptr0 {boundaryCheck = array<i32: 1, 2>, padding = 1 : i32} : !tt.ptr<tensor<1x512x64xf16>>
40+
%9 = arith.bitcast %c0_i32 : i32 to i32
41+
%10 = arith.bitcast %c1024_i32 : i32 to i32
42+
%11 = arith.bitcast %c64_i32 : i32 to i32
43+
scf.for %z = %9 to %10 step %11 iter_args() -> () : i32 {
44+
%ptr1 = tt.make_tensor_ptr %load_ptr1, [%c512_i64, %c64_i64, %c1024_i64], [%c65536_i64, %c1_i64, %c64_i64], [%x, %c0_i32, %z] {order = array<i32: 2, 0, 1>} : <tensor<1x64x64xf16>>
45+
// a. boundaryCheck = 1 checks the block ptr offset at index 2 (%z)
46+
// b. boundaryCheck = 2 checks the block ptr offset at index 1 (%c0_i32)
47+
// Check (a) is unnecessary because max(%z) + loadResType.shape[2] - 1 = 960 + 64 - 1 = 1023, which is less than 1024.
48+
// Check (b) is unnecessary because max(0) + loadResType.shape[1] - 1 = 0 + 64 - 1 = 63, which is less than 64.
49+
%load1 = tt.load %ptr1 {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x64x64xf16>>
50+
}
51+
}
52+
tt.return
53+
}
54+
// CHECK-LABEL: load_in_for_loop
55+
// CHECK-COUNT-2: scf.for
56+
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr
57+
// CHECK: tt.load [[PTR]] : !tt.ptr<tensor<1x64x64xf16>>
58+
}

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ def make_ttir(mod, metadata, opt):
192192
passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
193193
passes.common.add_cse(pm)
194194
passes.common.add_licm(pm)
195+
intel.passes.ttir.add_remove_boundary_checks(pm)
195196
intel.passes.ttir.add_remove_masks(pm)
196197
intel.passes.ttir.add_fuse_reshape(pm)
197198
passes.common.add_canonicalizer(pm)

third_party/intel/include/Dialect/Triton/Transforms/Passes.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,41 @@ def TritonIntelFuseReshape
7070
];
7171
}
7272

73+
def TritonIntelRemoveBoundaryChecks
74+
: Pass<"triton-intel-remove-boundary-checks", "mlir::ModuleOp"> {
75+
let summary = "Remove unnecessary boundary checks from load operations (block pointers only)";
76+
77+
let description = [{
78+
This pass attempts to remove boundary checks that aren't necessary in a tt.load operation.
79+
For example, given:
80+
%lb = arith.bitcast %c0_i32 : i32 to i32
81+
%ub = arith.bitcast %c1024_i32 : i32 to i32
82+
%st = arith.bitcast %c64_i32 : i32 to i32
83+
scf.for %iv = %lb to %ub step %st : i32 {
84+
%s0 = arith.constant 512 : i64
85+
%s1 = arith.constant 64 : i64
86+
%s2 = arith.constant 1024 : i64
87+
%a = arith.constant 65536 : i64
88+
%b = arith.constant 1 : i64
89+
%c = arith.constant 64 : i64
90+
%y = arith.constant 0 : i32
91+
%ptr = tt.make_tensor_ptr %base, [%s0, %s1, %s2], [%a, %b, %c], [%x, %y, %iv]
92+
{order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
93+
%load = tt.load %ptr {boundaryCheck = array<i32: 2>} : !tt.ptr<tensor<1x512x64xf16>>
94+
...
95+
// here %ptr is never updated.
96+
}
97+
98+
The transformation would drop the boundary check on the load operation because:
99+
- `%ptr` is never advanced in the loop
100+
- `%iv` has values [0, 64, 128, ..., 960], max(%iv) = 960
101+
- `%s2` is equal to 1024
102+
- the boundary check expression `max(%iv) + load_res.shape_in_dim -1` < `%s2` is true.
103+
}];
104+
105+
let dependentDialects = [
106+
"mlir::triton::TritonDialect"
107+
];
108+
}
109+
73110
#endif // TRITON_DIALECT_TRITON_INTEL_TRANSFORMS_PASSES

third_party/intel/lib/Dialect/Triton/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_triton_library(TritonIntelTransforms
22
FuseReshape.cpp
3+
RemoveBoundaryChecks.cpp
34
RemoveMasks.cpp
45
TensorDescToBlockPointer.cpp
56

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
#include "intel/include/Dialect/Triton/Transforms/Passes.h"
2+
#include "intel/include/Utils/Utility.h"
3+
#include "mlir/Dialect/Arith/IR/Arith.h"
4+
#include "mlir/Dialect/SCF/IR/SCF.h"
5+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
6+
#include "mlir/IR/BuiltinTypes.h"
7+
#include "mlir/IR/Verifier.h"
8+
#include "mlir/Interfaces/InferIntRangeInterface.h"
9+
#include "mlir/Support/WalkResult.h"
10+
#include "triton/Dialect/Triton/IR/Dialect.h"
11+
#include "llvm/ADT/APInt.h"
12+
#include "llvm/Support/Debug.h"
13+
#include "llvm/Support/raw_ostream.h"
14+
#include <cmath>
15+
#include <optional>
16+
17+
#define DEBUG_TYPE "triton-intel-remove-boundary-checks"
18+
19+
using namespace mlir;
20+
namespace tt = mlir::triton;
21+
22+
namespace mlir::triton::intel {
23+
#define GEN_PASS_DEF_TRITONINTELREMOVEBOUNDARYCHECKS
24+
#include "intel/include/Dialect/Triton/Transforms/Passes.h.inc"
25+
} // namespace mlir::triton::intel
26+
27+
namespace {
28+
class BoundaryChecksRemover {
29+
public:
30+
void run(ModuleOp moduleOp) {
31+
moduleOp.walk([&](tt::LoadOp loadOp) {
32+
if (!isCandidate(loadOp))
33+
return WalkResult::skip();
34+
35+
tt::MakeTensorPtrOp makeTensorPtrOp =
36+
*tt::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr());
37+
LLVM_DEBUG(llvm::dbgs()
38+
<< "Analyzing boundaryCheck for: " << loadOp << "\n");
39+
40+
SmallVector<int> newBoundaryCheck;
41+
for (int boundIdx : loadOp.getBoundaryCheck()) {
42+
ArrayRef<int> order = makeTensorPtrOp.getOrder();
43+
int idx = order.size() - order[boundIdx] - 1;
44+
Value offset = makeTensorPtrOp.getOffsets()[idx];
45+
Value shape = makeTensorPtrOp.getShape()[idx];
46+
auto resType = cast<RankedTensorType>(loadOp.getResult().getType());
47+
ArrayRef<int64_t> resShape = resType.getShape();
48+
std::optional<int64_t> offsetVal = getConstantIntValue(offset),
49+
shapeVal = getConstantIntValue(shape);
50+
51+
// If the shape is not known at compile time we cannot determine whether
52+
// the bound check is unnecessary.
53+
if (!shapeVal) {
54+
LLVM_DEBUG(llvm::dbgs().indent(2)
55+
<< "Check at index " << boundIdx << " is necessary\n");
56+
newBoundaryCheck.push_back(boundIdx);
57+
continue;
58+
}
59+
60+
// Case 1: offset and shape are constant.
61+
if (offsetVal && ((*offsetVal + resShape[idx]) <= *shapeVal)) {
62+
LLVM_DEBUG(llvm::dbgs().indent(2)
63+
<< "Check at index " << boundIdx << " is unnecessary\n");
64+
continue;
65+
}
66+
67+
// Case 2: analyze boundary check in loops.
68+
if (auto forOp = makeTensorPtrOp->getParentOfType<scf::ForOp>()) {
69+
assert(forOp.getSingleInductionVar() && "Single IV expected");
70+
Value iv = *forOp.getSingleInductionVar();
71+
if (offset != iv) {
72+
LLVM_DEBUG(llvm::dbgs().indent(2)
73+
<< "Check at index " << boundIdx << " is necessary\n");
74+
newBoundaryCheck.push_back(boundIdx);
75+
continue;
76+
}
77+
78+
OpFoldResult lb = *forOp.getSingleLowerBound();
79+
OpFoldResult ub = *forOp.getSingleUpperBound();
80+
OpFoldResult step = *forOp.getSingleStep();
81+
82+
auto computeLoopIVRange =
83+
[&](OpFoldResult lb, OpFoldResult ub,
84+
OpFoldResult step) -> std::optional<ConstantIntRanges> {
85+
auto getBoundValue =
86+
[](OpFoldResult bound) -> std::optional<int64_t> {
87+
if (std::optional<int64_t> opVal = getConstantIntValue(bound))
88+
return *opVal;
89+
90+
Value val = tt::intel::getFinalValue(cast<Value>(bound));
91+
if (auto cst = dyn_cast<arith::BitcastOp>(val.getDefiningOp()))
92+
val = cst.getIn();
93+
94+
return getConstantIntValue(getAsOpFoldResult(val));
95+
};
96+
97+
auto areLoopBoundKnown = [&](OpFoldResult lb, OpFoldResult ub,
98+
OpFoldResult step) {
99+
return (getBoundValue(lb) && getBoundValue(ub) &&
100+
getBoundValue(step));
101+
};
102+
103+
if (!areLoopBoundKnown(lb, ub, step))
104+
return std::nullopt;
105+
106+
int64_t lbVal = *getBoundValue(lb);
107+
int64_t ubVal = *getBoundValue(ub);
108+
int64_t stepVal = *getBoundValue(step);
109+
int64_t lastIVVal =
110+
lbVal + ((ubVal - lbVal - 1) / stepVal) * stepVal;
111+
llvm::APInt start(64, lbVal, true);
112+
llvm::APInt end(64, lastIVVal, true);
113+
114+
return ConstantIntRanges::range(start, end, true);
115+
};
116+
117+
std::optional<ConstantIntRanges> optRange =
118+
computeLoopIVRange(lb, ub, step);
119+
if (!optRange) {
120+
LLVM_DEBUG(llvm::dbgs().indent(2)
121+
<< "Check at index " << boundIdx << " is necessary\n");
122+
newBoundaryCheck.push_back(boundIdx);
123+
continue;
124+
}
125+
126+
APInt maxIV = (*optRange).smax();
127+
if (maxIV.getSExtValue() + resShape[idx] <= shapeVal) {
128+
LLVM_DEBUG(llvm::dbgs().indent(2)
129+
<< "Check at index " << boundIdx << " is unnecessary\n");
130+
continue;
131+
}
132+
}
133+
134+
LLVM_DEBUG(llvm::dbgs().indent(2)
135+
<< "Check at index " << boundIdx << " is necessary\n");
136+
newBoundaryCheck.push_back(boundIdx);
137+
}
138+
139+
if (newBoundaryCheck.size() != loadOp.getBoundaryCheck().size()) {
140+
loadOp.setBoundaryCheck(newBoundaryCheck);
141+
LLVM_DEBUG(llvm::dbgs().indent(2)
142+
<< "Rewritten load is: " << loadOp << "\n");
143+
}
144+
145+
return WalkResult::advance();
146+
});
147+
}
148+
149+
private:
150+
// A candidate load operation is one that:
151+
// - has the boundary check attribute
152+
// - uses a block pointer defined by a `make_tensor_ptr` that is not
153+
// advanced
154+
bool isCandidate(tt::LoadOp loadOp) const {
155+
assert(loadOp && "Expecting a valid load operation");
156+
157+
ArrayRef<int> boundaryCheck = loadOp.getBoundaryCheck();
158+
if (boundaryCheck.empty())
159+
return false;
160+
161+
Type ptrType = loadOp.getPtr().getType();
162+
if (!tt::isTensorPointerType(ptrType))
163+
return false;
164+
165+
std::optional<tt::MakeTensorPtrOp> makeTensorPtrOp =
166+
tt::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr());
167+
if (!makeTensorPtrOp)
168+
return false;
169+
170+
if (llvm::any_of((*makeTensorPtrOp)->getUsers(),
171+
[](Operation *user) { return isa<tt::AdvanceOp>(user); }))
172+
return false;
173+
174+
return true;
175+
}
176+
};
177+
178+
} // namespace
179+
180+
// Pass wrapper: runs the boundary-check removal over a whole module.
struct TritonIntelRemoveBoundaryChecks
    : tt::intel::impl::TritonIntelRemoveBoundaryChecksBase<
          TritonIntelRemoveBoundaryChecks> {
public:
  void runOnOperation() final {
    ModuleOp moduleOp = getOperation();
    BoundaryChecksRemover remover;
    remover.run(moduleOp);
    // Sanity check the rewritten IR. NOTE: this runs only in builds with
    // asserts enabled; release builds skip the verification entirely.
    assert(succeeded(verify(moduleOp)) && "Module verification failed");
  }
};

third_party/intel/triton_xpu.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ static uint32_t findKernels(llvm::Module &M,
5858
// Registers the Intel TTIR pass constructors as Python bindings; they are
// invoked from the compiler pipeline as intel.passes.ttir.add_*.
void init_triton_intel_passes_ttir(py::module &&m) {
  ADD_PASS_WRAPPER_0("add_convert_tdesc_to_block_pointer",
                     intel::createTritonIntelTensorDescToBlockPointer);
  ADD_PASS_WRAPPER_0("add_remove_boundary_checks",
                     intel::createTritonIntelRemoveBoundaryChecks);
  ADD_PASS_WRAPPER_0("add_remove_masks", intel::createTritonIntelRemoveMasks);
  ADD_PASS_WRAPPER_0("add_fuse_reshape", intel::createTritonIntelFuseReshape);
}

0 commit comments

Comments
 (0)