Skip to content

Commit ca49589

Browse files
committed
Address code review comments
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent b059f2c commit ca49589

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

third_party/intel/include/Dialect/Triton/Transforms/Passes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ def TritonIntelRemoveBoundaryChecks
9898
The transformation would drop the boundary check on the load operation because:
9999
- `%ptr` is never advanced in the loop
100100
- `%iv` has values [0, 64, 128, ..., 960], max(%iv) = 960
101-
- `%s2` is equal to 1014
102-
- the boundary check expression `%iv` < `%s2` is always true
101+
- `%s2` is equal to 1024
102+
- the boundary check expression `max(%iv) + load_res.shape_in_dim - 1` < `%s2` is true.
103103
}];
104104

105105
let dependentDialects = [

third_party/intel/lib/Dialect/Triton/Transforms/RemoveBoundaryChecks.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "mlir/Dialect/Arith/IR/Arith.h"
44
#include "mlir/Dialect/SCF/IR/SCF.h"
55
#include "mlir/Dialect/Utils/StaticValueUtils.h"
6+
#include "mlir/IR/BuiltinTypes.h"
67
#include "mlir/IR/Verifier.h"
78
#include "mlir/Interfaces/InferIntRangeInterface.h"
89
#include "mlir/Support/WalkResult.h"
@@ -42,6 +43,8 @@ class BoundaryChecksRemover {
4243
int idx = order.size() - order[boundIdx] - 1;
4344
Value offset = makeTensorPtrOp.getOffsets()[idx];
4445
Value shape = makeTensorPtrOp.getShape()[idx];
46+
auto resType = cast<RankedTensorType>(loadOp.getResult().getType());
47+
ArrayRef<int64_t> resShape = resType.getShape();
4548
std::optional<int64_t> offsetVal = getConstantIntValue(offset),
4649
shapeVal = getConstantIntValue(shape);
4750

@@ -55,7 +58,7 @@ class BoundaryChecksRemover {
5558
}
5659

5760
// Case 1: offset and shape are constant.
58-
if (offsetVal && *offsetVal < *shapeVal) {
61+
if (offsetVal && ((*offsetVal + resShape[idx]) <= *shapeVal)) {
5962
LLVM_DEBUG(llvm::dbgs().indent(2)
6063
<< "Check at index " << boundIdx << " is unnecessary\n");
6164
continue;
@@ -120,9 +123,8 @@ class BoundaryChecksRemover {
120123
continue;
121124
}
122125

123-
// Compare the max value of the loop IV to the offset.
124-
APInt max = (*optRange).smax();
125-
if (max.getSExtValue() < shapeVal) {
126+
APInt maxIV = (*optRange).smax();
127+
if (maxIV.getSExtValue() + resShape[idx] <= shapeVal) {
126128
LLVM_DEBUG(llvm::dbgs().indent(2)
127129
<< "Check at index " << boundIdx << " is unnecessary\n");
128130
continue;

0 commit comments

Comments
 (0)