-
Couldn't load subscription status.
- Fork 75
[Helion]: Remove boundaryChecks on load operation using a block ptr/te… #5363
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
etiotto
wants to merge
4
commits into
main
Choose a base branch
from
etiotto.remove_boundary_checks
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 1 commit
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
189 changes: 189 additions & 0 deletions
189
third_party/intel/lib/Dialect/Triton/Transforms/RemoveBoundaryChecks.cpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,189 @@ | ||
|
|
||
etiotto marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| #include "intel/include/Dialect/Triton/Transforms/Passes.h" | ||
| #include "intel/include/Utils/Utility.h" | ||
| #include "mlir/Dialect/Arith/IR/Arith.h" | ||
| #include "mlir/Dialect/SCF/IR/SCF.h" | ||
| #include "mlir/Dialect/Utils/StaticValueUtils.h" | ||
| #include "mlir/IR/Verifier.h" | ||
| #include "mlir/Interfaces/InferIntRangeInterface.h" | ||
| #include "mlir/Support/WalkResult.h" | ||
| #include "triton/Dialect/Triton/IR/Dialect.h" | ||
| #include "llvm/ADT/APInt.h" | ||
| #include "llvm/Support/Debug.h" | ||
| #include "llvm/Support/raw_ostream.h" | ||
| #include <cmath> | ||
| #include <optional> | ||
|
|
||
| #define DEBUG_TYPE "triton-intel-remove-boundary-checks" | ||
|
|
||
| using namespace mlir; | ||
| namespace tt = mlir::triton; | ||
|
|
||
| namespace mlir::triton::intel { | ||
| #define GEN_PASS_DEF_TRITONINTELREMOVEBOUNDARYCHECKS | ||
| #include "intel/include/Dialect/Triton/Transforms/Passes.h.inc" | ||
| } // namespace mlir::triton::intel | ||
|
|
||
| namespace { | ||
| class BoundaryChecksRemover { | ||
| public: | ||
| void run(ModuleOp moduleOp) { | ||
| moduleOp.walk([&](tt::LoadOp loadOp) { | ||
| if (!isCandidate(loadOp)) | ||
| return WalkResult::skip(); | ||
|
|
||
| tt::MakeTensorPtrOp makeTensorPtrOp = | ||
| *tt::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr()); | ||
| LLVM_DEBUG(llvm::dbgs() | ||
| << "Analyzing boundaryCheck for: " << loadOp << "\n"); | ||
|
|
||
| SmallVector<int> newBoundaryCheck; | ||
| for (int boundIdx : loadOp.getBoundaryCheck()) { | ||
| ArrayRef<int> order = makeTensorPtrOp.getOrder(); | ||
| int idx = order.size() - order[boundIdx] - 1; | ||
| Value offset = makeTensorPtrOp.getOffsets()[idx]; | ||
| Value shape = makeTensorPtrOp.getShape()[idx]; | ||
| std::optional<int64_t> offsetVal = getConstantIntValue(offset), | ||
| shapeVal = getConstantIntValue(shape); | ||
|
|
||
| // If the shape is not known at compile time we cannot determine whether | ||
| // the bound check is unnecessary. | ||
| if (!shapeVal) { | ||
| LLVM_DEBUG(llvm::dbgs().indent(2) | ||
| << "Check at index " << boundIdx << " is necessary\n"); | ||
| newBoundaryCheck.push_back(boundIdx); | ||
| continue; | ||
| } | ||
|
|
||
| // Case 1: offset and shape are constant. | ||
| if (offsetVal && *offsetVal < *shapeVal) { | ||
| LLVM_DEBUG(llvm::dbgs().indent(2) | ||
| << "Check at index " << boundIdx << " is unnecessary\n"); | ||
| continue; | ||
| } | ||
|
|
||
| // Case 2: analyze boundary check in loops. | ||
| if (auto forOp = makeTensorPtrOp->getParentOfType<scf::ForOp>()) { | ||
| assert(forOp.getSingleInductionVar() && "Single IV expected"); | ||
| Value iv = *forOp.getSingleInductionVar(); | ||
| if (offset != iv) { | ||
| LLVM_DEBUG(llvm::dbgs().indent(2) | ||
| << "Check at index " << boundIdx << " is necessary\n"); | ||
| newBoundaryCheck.push_back(boundIdx); | ||
| continue; | ||
| } | ||
|
|
||
| OpFoldResult lb = *forOp.getSingleLowerBound(); | ||
| OpFoldResult ub = *forOp.getSingleUpperBound(); | ||
| OpFoldResult step = *forOp.getSingleStep(); | ||
|
|
||
| auto computeLoopIVRange = | ||
| [&](OpFoldResult lb, OpFoldResult ub, | ||
| OpFoldResult step) -> std::optional<ConstantIntRanges> { | ||
| auto getBoundValue = | ||
| [](OpFoldResult bound) -> std::optional<int64_t> { | ||
| if (std::optional<int64_t> opVal = getConstantIntValue(bound)) | ||
| return *opVal; | ||
|
|
||
| Value val = tt::intel::getFinalValue(cast<Value>(bound)); | ||
| if (auto cst = dyn_cast<arith::BitcastOp>(val.getDefiningOp())) | ||
| val = cst.getIn(); | ||
|
|
||
| return getConstantIntValue(getAsOpFoldResult(val)); | ||
| }; | ||
|
|
||
| auto areLoopBoundKnown = [&](OpFoldResult lb, OpFoldResult ub, | ||
| OpFoldResult step) { | ||
| return (getBoundValue(lb) && getBoundValue(ub) && | ||
| getBoundValue(step)); | ||
| }; | ||
|
|
||
| if (!areLoopBoundKnown(lb, ub, step)) | ||
| return std::nullopt; | ||
|
|
||
| int64_t lbVal = *getBoundValue(lb); | ||
| int64_t ubVal = *getBoundValue(ub); | ||
| int64_t stepVal = *getBoundValue(step); | ||
| int64_t lastIVVal = | ||
| lbVal + ((ubVal - lbVal - 1) / stepVal) * stepVal; | ||
| llvm::APInt start(64, lbVal, true); | ||
| llvm::APInt end(64, lastIVVal, true); | ||
|
|
||
| return ConstantIntRanges::range(start, end, true); | ||
| }; | ||
|
|
||
| std::optional<ConstantIntRanges> optRange = | ||
| computeLoopIVRange(lb, ub, step); | ||
| if (!optRange) { | ||
| LLVM_DEBUG(llvm::dbgs().indent(2) | ||
| << "Check at index " << boundIdx << " is necessary\n"); | ||
| newBoundaryCheck.push_back(boundIdx); | ||
| continue; | ||
| } | ||
|
|
||
| // Compare the max value of the loop IV to the offset. | ||
| APInt max = (*optRange).smax(); | ||
| if (max.getSExtValue() < shapeVal) { | ||
| LLVM_DEBUG(llvm::dbgs().indent(2) | ||
| << "Check at index " << boundIdx << " is unnecessary\n"); | ||
| continue; | ||
| } | ||
| } | ||
|
|
||
| LLVM_DEBUG(llvm::dbgs().indent(2) | ||
| << "Check at index " << boundIdx << " is necessary\n"); | ||
| newBoundaryCheck.push_back(boundIdx); | ||
| } | ||
|
|
||
| if (newBoundaryCheck.size() != loadOp.getBoundaryCheck().size()) { | ||
| loadOp.setBoundaryCheck(newBoundaryCheck); | ||
| LLVM_DEBUG(llvm::dbgs().indent(2) | ||
| << "Rewritten load is: " << loadOp << "\n"); | ||
| } | ||
|
|
||
| return WalkResult::advance(); | ||
| }); | ||
| } | ||
|
|
||
| private: | ||
| // A candidate load operation is one that: | ||
| // - has the boundary check attribute | ||
| // - uses a block pointer defined by a `make_tensor_ptr` that is not | ||
| // advanced | ||
| bool isCandidate(tt::LoadOp loadOp) const { | ||
| assert(loadOp && "Expecting a valid load operation"); | ||
|
|
||
| ArrayRef<int> boundaryCheck = loadOp.getBoundaryCheck(); | ||
| if (boundaryCheck.empty()) | ||
| return false; | ||
|
|
||
| Type ptrType = loadOp.getPtr().getType(); | ||
| if (!tt::isTensorPointerType(ptrType)) | ||
| return false; | ||
|
|
||
| std::optional<tt::MakeTensorPtrOp> makeTensorPtrOp = | ||
| tt::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr()); | ||
| if (!makeTensorPtrOp) | ||
| return false; | ||
|
|
||
| if (llvm::any_of((*makeTensorPtrOp)->getUsers(), | ||
etiotto marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| [](Operation *user) { return isa<tt::AdvanceOp>(user); })) | ||
| return false; | ||
|
|
||
| return true; | ||
| } | ||
| }; | ||
|
|
||
| } // namespace | ||
|
|
||
| struct TritonIntelRemoveBoundaryChecks | ||
| : tt::intel::impl::TritonIntelRemoveBoundaryChecksBase< | ||
| TritonIntelRemoveBoundaryChecks> { | ||
| public: | ||
| void runOnOperation() final { | ||
| ModuleOp moduleOp = getOperation(); | ||
| BoundaryChecksRemover remover; | ||
| remover.run(moduleOp); | ||
| assert(succeeded(verify(moduleOp)) && "Module verification failed"); | ||
| } | ||
| }; | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.