33#include " mlir/Dialect/Arith/IR/Arith.h"
44#include " mlir/Dialect/SCF/IR/SCF.h"
55#include " mlir/Dialect/Utils/StaticValueUtils.h"
6+ #include " mlir/IR/BuiltinTypes.h"
67#include " mlir/IR/Verifier.h"
78#include " mlir/Interfaces/InferIntRangeInterface.h"
89#include " mlir/Support/WalkResult.h"
@@ -42,6 +43,8 @@ class BoundaryChecksRemover {
4243 int idx = order.size () - order[boundIdx] - 1 ;
4344 Value offset = makeTensorPtrOp.getOffsets ()[idx];
4445 Value shape = makeTensorPtrOp.getShape ()[idx];
46+ auto resType = cast<RankedTensorType>(loadOp.getResult ().getType ());
47+ ArrayRef<int64_t > resShape = resType.getShape ();
4548 std::optional<int64_t > offsetVal = getConstantIntValue (offset),
4649 shapeVal = getConstantIntValue (shape);
4750
@@ -55,7 +58,7 @@ class BoundaryChecksRemover {
5558 }
5659
5760 // Case 1: offset and shape are constant.
58- if (offsetVal && *offsetVal < *shapeVal) {
61+ if (offsetVal && (( *offsetVal + resShape[idx]) <= *shapeVal) ) {
5962 LLVM_DEBUG (llvm::dbgs ().indent (2 )
6063 << " Check at index " << boundIdx << " is unnecessary\n " );
6164 continue ;
@@ -120,9 +123,8 @@ class BoundaryChecksRemover {
120123 continue ;
121124 }
122125
123- // Compare the max value of the loop IV to the offset.
124- APInt max = (*optRange).smax ();
125- if (max.getSExtValue () < shapeVal) {
126+ APInt maxIV = (*optRange).smax ();
127+ if (maxIV.getSExtValue () + resShape[idx] <= shapeVal) {
126128 LLVM_DEBUG (llvm::dbgs ().indent (2 )
127129 << " Check at index " << boundIdx << " is unnecessary\n " );
128130 continue ;
0 commit comments