[Helion]: Remove boundaryChecks on load operation using a block ptr/te… #5363
Open
etiotto wants to merge 4 commits into main from etiotto.remove_boundary_checks
test/Triton/Intel/RemoveBoundaryChecks/remove-boundary-checks.mlir (58 additions & 0 deletions)
```mlir
// RUN: triton-opt %s -split-input-file -triton-intel-remove-boundary-checks | FileCheck %s

module {
  tt.func public @simple_load(%load_ptr: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %store_ptr: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
    %c1_i64 = arith.constant 1 : i64
    %c64_i64 = arith.constant 64 : i64
    %c512_i64 = arith.constant 512 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %c0_i32 = arith.constant 0 : i32
    %x = arith.constant 10 : i32
    %in = tt.make_tensor_ptr %load_ptr, [%c1_i64, %c64_i64, %c1024_i64], [%c512_i64, %c64_i64, %c1_i64], [%c0_i32, %c0_i32, %x] {order = array<i32: 2, 1, 0>} : <tensor<1x64x64xf16>>
    // boundaryCheck is unnecessary because %x + loadResType.shape[2] - 1 = 10 + 64 - 1 = 73 < 1024
    %load = tt.load %in {boundaryCheck = array<i32: 2>} : !tt.ptr<tensor<1x64x64xf16>>
    tt.return
  }
  // CHECK-LABEL: simple_load
  // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr
  // CHECK: tt.load [[PTR]] : !tt.ptr<tensor<1x64x64xf16>>
}

// -----

module {
  tt.func public @load_in_for_loop(%load_ptr0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %load_ptr1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %store_ptr: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c20_i32 = arith.constant 20 : i32
    %c64_i32 = arith.constant 64 : i32
    %c1024_i32 = arith.constant 1024 : i32
    scf.for %x = %c0_i32 to %c20_i32 step %c1_i32 : i32 {
      %pid = tt.get_program_id x : i32
      %c0_i64 = arith.constant 0 : i64
      %c1_i64 = arith.constant 1 : i64
      %c512_i64 = arith.constant 512 : i64
      %c1024_i64 = arith.constant 1024 : i64
      %c64_i64 = arith.constant 64 : i64
      %c65536_i64 = arith.constant 65536 : i64
      %ptr0 = tt.make_tensor_ptr %load_ptr0, [%c512_i64, %c1024_i64, %c64_i64], [%c65536_i64, %c64_i64, %c1_i64], [%x, %pid, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
      %load0 = tt.load %ptr0 {boundaryCheck = array<i32: 1, 2>, padding = 1 : i32} : !tt.ptr<tensor<1x512x64xf16>>
      %9 = arith.bitcast %c0_i32 : i32 to i32
      %10 = arith.bitcast %c1024_i32 : i32 to i32
      %11 = arith.bitcast %c64_i32 : i32 to i32
      scf.for %z = %9 to %10 step %11 iter_args() -> () : i32 {
        %ptr1 = tt.make_tensor_ptr %load_ptr1, [%c512_i64, %c64_i64, %c1024_i64], [%c65536_i64, %c1_i64, %c64_i64], [%x, %c0_i32, %z] {order = array<i32: 2, 0, 1>} : <tensor<1x64x64xf16>>
        // a. boundaryCheck = 1 checks the block ptr offset at index 2 (%z)
        // b. boundaryCheck = 2 checks the block ptr offset at index 1 (the constant %c0_i32)
        // Check (a) is unnecessary because max(%z) + loadResType.shape[2] - 1 = 960 + 64 - 1 = 1023, which is less than 1024.
        // Check (b) is unnecessary because 0 + loadResType.shape[1] - 1 = 0 + 64 - 1 = 63, which is less than 64.
        %load1 = tt.load %ptr1 {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x64x64xf16>>
      }
    }
    tt.return
  }
  // CHECK-LABEL: load_in_for_loop
  // CHECK-COUNT-2: scf.for
  // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr
  // CHECK: tt.load [[PTR]] : !tt.ptr<tensor<1x64x64xf16>>
}
```
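The arithmetic in the test comments reduces to one comparison per checked dimension: a boundary check can be dropped when the largest possible offset plus the tile extent does not exceed the tensor shape along that dimension, where the largest offset inside a loop is the last value the induction variable takes. A minimal standalone sketch of that arithmetic (plain C++, no MLIR dependencies; the helper names `lastInductionValue` and `boundCheckRemovable` are illustrative, not part of this PR):

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>

// Largest value a loop induction variable reaches for (lb, ub, step);
// same closed form the pass uses: lb + floor((ub - lb - 1) / step) * step.
int64_t lastInductionValue(int64_t lb, int64_t ub, int64_t step) {
  return lb + ((ub - lb - 1) / step) * step;
}

// A boundary check on one dimension is removable when the largest offset
// plus the tile extent stays within the tensor shape along that dimension.
bool boundCheckRemovable(int64_t maxOffset, int64_t tileExtent, int64_t shape) {
  return maxOffset + tileExtent <= shape;
}

int main() {
  // simple_load: offset 10, tile extent 64, shape 1024 -> 10 + 64 - 1 = 73 < 1024.
  assert(boundCheckRemovable(10, 64, 1024));

  // load_in_for_loop, check (a): %z runs from 0 to 1024 with step 64,
  // so max(%z) = 960 and 960 + 64 - 1 = 1023 < 1024.
  int64_t maxZ = lastInductionValue(0, 1024, 64);
  std::cout << "max(%z) = " << maxZ << "\n"; // prints 960
  assert(boundCheckRemovable(maxZ, 64, 1024));

  // A check that must stay: offset 1000, tile extent 64, shape 1024 -> 1063 >= 1024.
  assert(!boundCheckRemovable(1000, 64, 1024));
  return 0;
}
```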
third_party/intel/lib/Dialect/Triton/Transforms/RemoveBoundaryChecks.cpp (190 additions & 0 deletions)
```cpp
#include "intel/include/Dialect/Triton/Transforms/Passes.h"
#include "intel/include/Utils/Utility.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Support/WalkResult.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cmath>
#include <optional>

#define DEBUG_TYPE "triton-intel-remove-boundary-checks"

using namespace mlir;
namespace tt = mlir::triton;

namespace mlir::triton::intel {
#define GEN_PASS_DEF_TRITONINTELREMOVEBOUNDARYCHECKS
#include "intel/include/Dialect/Triton/Transforms/Passes.h.inc"
} // namespace mlir::triton::intel

namespace {
class BoundaryChecksRemover {
public:
  void run(ModuleOp moduleOp) {
    moduleOp.walk([&](tt::LoadOp loadOp) {
      if (!isCandidate(loadOp))
        return WalkResult::skip();

      tt::MakeTensorPtrOp makeTensorPtrOp =
          *tt::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr());
      LLVM_DEBUG(llvm::dbgs()
                 << "Analyzing boundaryCheck for: " << loadOp << "\n");

      SmallVector<int> newBoundaryCheck;
      for (int boundIdx : loadOp.getBoundaryCheck()) {
        ArrayRef<int> order = makeTensorPtrOp.getOrder();
        int idx = order.size() - order[boundIdx] - 1;
        Value offset = makeTensorPtrOp.getOffsets()[idx];
        Value shape = makeTensorPtrOp.getShape()[idx];
        auto resType = cast<RankedTensorType>(loadOp.getResult().getType());
        ArrayRef<int64_t> resShape = resType.getShape();
        std::optional<int64_t> offsetVal = getConstantIntValue(offset),
                               shapeVal = getConstantIntValue(shape);

        // If the shape is not known at compile time we cannot determine
        // whether the bound check is unnecessary.
        if (!shapeVal) {
          LLVM_DEBUG(llvm::dbgs().indent(2)
                     << "Check at index " << boundIdx << " is necessary\n");
          newBoundaryCheck.push_back(boundIdx);
          continue;
        }

        // Case 1: offset and shape are constant.
        if (offsetVal && ((*offsetVal + resShape[idx]) <= *shapeVal)) {
          LLVM_DEBUG(llvm::dbgs().indent(2)
                     << "Check at index " << boundIdx << " is unnecessary\n");
          continue;
        }

        // Case 2: analyze boundary check in loops.
        if (auto forOp = makeTensorPtrOp->getParentOfType<scf::ForOp>()) {
          assert(forOp.getSingleInductionVar() && "Single IV expected");
          Value iv = *forOp.getSingleInductionVar();
          if (offset != iv) {
            LLVM_DEBUG(llvm::dbgs().indent(2)
                       << "Check at index " << boundIdx << " is necessary\n");
            newBoundaryCheck.push_back(boundIdx);
            continue;
          }

          OpFoldResult lb = *forOp.getSingleLowerBound();
          OpFoldResult ub = *forOp.getSingleUpperBound();
          OpFoldResult step = *forOp.getSingleStep();

          auto computeLoopIVRange =
              [&](OpFoldResult lb, OpFoldResult ub,
                  OpFoldResult step) -> std::optional<ConstantIntRanges> {
            auto getBoundValue =
                [](OpFoldResult bound) -> std::optional<int64_t> {
              if (std::optional<int64_t> opVal = getConstantIntValue(bound))
                return *opVal;

              Value val = tt::intel::getFinalValue(cast<Value>(bound));
              if (auto cst = dyn_cast<arith::BitcastOp>(val.getDefiningOp()))
                val = cst.getIn();

              return getConstantIntValue(getAsOpFoldResult(val));
            };

            auto areLoopBoundKnown = [&](OpFoldResult lb, OpFoldResult ub,
                                         OpFoldResult step) {
              return (getBoundValue(lb) && getBoundValue(ub) &&
                      getBoundValue(step));
            };

            if (!areLoopBoundKnown(lb, ub, step))
              return std::nullopt;

            int64_t lbVal = *getBoundValue(lb);
            int64_t ubVal = *getBoundValue(ub);
            int64_t stepVal = *getBoundValue(step);
            int64_t lastIVVal =
                lbVal + ((ubVal - lbVal - 1) / stepVal) * stepVal;
            llvm::APInt start(64, lbVal, true);
            llvm::APInt end(64, lastIVVal, true);

            return ConstantIntRanges::range(start, end, true);
          };

          std::optional<ConstantIntRanges> optRange =
              computeLoopIVRange(lb, ub, step);
          if (!optRange) {
            LLVM_DEBUG(llvm::dbgs().indent(2)
                       << "Check at index " << boundIdx << " is necessary\n");
            newBoundaryCheck.push_back(boundIdx);
            continue;
          }

          APInt maxIV = (*optRange).smax();
          if (maxIV.getSExtValue() + resShape[idx] <= *shapeVal) {
            LLVM_DEBUG(llvm::dbgs().indent(2)
                       << "Check at index " << boundIdx << " is unnecessary\n");
            continue;
          }
        }

        LLVM_DEBUG(llvm::dbgs().indent(2)
                   << "Check at index " << boundIdx << " is necessary\n");
        newBoundaryCheck.push_back(boundIdx);
      }

      if (newBoundaryCheck.size() != loadOp.getBoundaryCheck().size()) {
        loadOp.setBoundaryCheck(newBoundaryCheck);
        LLVM_DEBUG(llvm::dbgs().indent(2)
                   << "Rewritten load is: " << loadOp << "\n");
      }

      return WalkResult::advance();
    });
  }

private:
  // A candidate load operation is one that:
  // - has the boundary check attribute
  // - uses a block pointer defined by a `make_tensor_ptr` that is not
  //   advanced
  bool isCandidate(tt::LoadOp loadOp) const {
    assert(loadOp && "Expecting a valid load operation");

    ArrayRef<int> boundaryCheck = loadOp.getBoundaryCheck();
    if (boundaryCheck.empty())
      return false;

    Type ptrType = loadOp.getPtr().getType();
    if (!tt::isTensorPointerType(ptrType))
      return false;

    std::optional<tt::MakeTensorPtrOp> makeTensorPtrOp =
        tt::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr());
    if (!makeTensorPtrOp)
      return false;

    if (llvm::any_of((*makeTensorPtrOp)->getUsers(),
                     [](Operation *user) { return isa<tt::AdvanceOp>(user); }))
      return false;

    return true;
  }
};

} // namespace

struct TritonIntelRemoveBoundaryChecks
    : tt::intel::impl::TritonIntelRemoveBoundaryChecksBase<
          TritonIntelRemoveBoundaryChecks> {
public:
  void runOnOperation() final {
    ModuleOp moduleOp = getOperation();
    BoundaryChecksRemover remover;
    remover.run(moduleOp);
    assert(succeeded(verify(moduleOp)) && "Module verification failed");
  }
};
```
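The least obvious step in `BoundaryChecksRemover::run` is how a `boundaryCheck` entry is translated into an offset/shape index of the defining `make_tensor_ptr`: the pass goes through the block pointer's `order` attribute with `idx = order.size() - order[boundIdx] - 1`. A small standalone check of that mapping against the `load_in_for_loop` test above (plain C++, no MLIR dependencies; `offsetIndexFor` is an illustrative helper, not part of this PR):

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Mirror of the index mapping used in BoundaryChecksRemover::run():
// a boundaryCheck entry names a dimension of the loaded tile, and the block
// pointer's `order` attribute translates it into an offset/shape index.
std::size_t offsetIndexFor(int boundIdx, const std::array<int, 3> &order) {
  return order.size() - order[boundIdx] - 1;
}

int main() {
  // order = array<i32: 2, 0, 1> from the inner make_tensor_ptr in load_in_for_loop.
  const std::array<int, 3> order = {2, 0, 1};

  // boundaryCheck = 1 -> offset index 2, i.e. the inner-loop IV %z.
  assert(offsetIndexFor(1, order) == 2);
  // boundaryCheck = 2 -> offset index 1, i.e. the constant 0 offset.
  assert(offsetIndexFor(2, order) == 1);
  return 0;
}
```

With `order = [2, 0, 1]` the two checks land on offset indices 2 and 1 respectively, which is what the comments in the MLIR test describe.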