1 change: 1 addition & 0 deletions bin/RegisterTritonDialects.h
@@ -95,6 +95,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::test::registerTestTritonAMDGPURangeAnalysis();
mlir::triton::registerConvertTritonToTritonGPUPass();
mlir::triton::intel::registerTritonIntelFuseReshape();
mlir::triton::intel::registerTritonIntelRemoveBoundaryChecks();
mlir::triton::intel::registerTritonIntelRemoveMasks();
mlir::triton::intel::registerTritonIntelTensorDescToBlockPointer();
mlir::triton::registerRelayoutTritonGPUPass();
1 change: 1 addition & 0 deletions third_party/intel/backend/compiler.py
@@ -200,6 +200,7 @@ def make_ttir(mod, metadata, opt):
passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
passes.common.add_cse(pm)
passes.common.add_licm(pm)
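# Drop boundary checks proven unnecessary before the mask-removal pass.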
intel.passes.ttir.add_remove_boundary_checks(pm)
intel.passes.ttir.add_remove_masks(pm)
intel.passes.ttir.add_fuse_reshape(pm)
passes.common.add_canonicalizer(pm)
37 changes: 37 additions & 0 deletions third_party/intel/include/Dialect/Triton/Transforms/Passes.td
@@ -70,4 +70,41 @@ def TritonIntelFuseReshape
];
}

def TritonIntelRemoveBoundaryChecks
: Pass<"triton-intel-remove-boundary-checks", "mlir::ModuleOp"> {
let summary = "Remove unnecessary boundary checks from load operations (block pointers only)";

let description = [{
This pass attempts to remove boundary checks that are not necessary on a
tt.load operation. For example, given:
%lb = arith.bitcast %c0_i32 : i32 to i32
%ub = arith.bitcast %c1024_i32 : i32 to i32
%st = arith.bitcast %c64_i32 : i32 to i32
scf.for %iv = %lb to %ub step %st : i32 {
%s0 = arith.constant 512 : i64
%s1 = arith.constant 64 : i64
%s2 = arith.constant 1024 : i64
%a = arith.constant 65536 : i64
%b = arith.constant 1 : i64
%c = arith.constant 64 : i64
%y = arith.constant 0 : i32
%ptr = tt.make_tensor_ptr %base, [%s0, %s1, %s2], [%a, %b, %c], [%x, %y, %iv]
{order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
%load = tt.load %ptr {boundaryCheck = array<i32: 2>} : !tt.ptr<tensor<1x512x64xf16>>
...
// here %ptr is never updated.
}

The transformation drops the boundary check on the load operation because:
- `%ptr` is never advanced in the loop
- `%iv` takes the values [0, 64, 128, ..., 960], so max(%iv) = 960
- `%s2` is equal to 1024
- the boundary check condition `%iv < %s2` is therefore always true
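
After the transformation the load would no longer carry the attribute, e.g.:
%load = tt.load %ptr : !tt.ptr<tensor<1x512x64xf16>>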
}];

let dependentDialects = [
"mlir::triton::TritonDialect"
];
}

#endif // TRITON_DIALECT_TRITON_INTEL_TRANSFORMS_PASSES
@@ -1,5 +1,6 @@
add_triton_library(TritonIntelTransforms
FuseReshape.cpp
RemoveBoundaryChecks.cpp
RemoveMasks.cpp
TensorDescToBlockPointer.cpp

@@ -0,0 +1,189 @@

#include "intel/include/Dialect/Triton/Transforms/Passes.h"
#include "intel/include/Utils/Utility.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Support/WalkResult.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cmath>
#include <optional>

#define DEBUG_TYPE "triton-intel-remove-boundary-checks"

using namespace mlir;
namespace tt = mlir::triton;

namespace mlir::triton::intel {
#define GEN_PASS_DEF_TRITONINTELREMOVEBOUNDARYCHECKS
#include "intel/include/Dialect/Triton/Transforms/Passes.h.inc"
} // namespace mlir::triton::intel

namespace {
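// Walks every tt.load in the module and drops boundary-check indices that
// can be proven in-bounds, either because the offset and shape are
// compile-time constants or because the enclosing loop's induction variable
// never reaches the shape.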
class BoundaryChecksRemover {
public:
void run(ModuleOp moduleOp) {
moduleOp.walk([&](tt::LoadOp loadOp) {
if (!isCandidate(loadOp))
return WalkResult::skip();

tt::MakeTensorPtrOp makeTensorPtrOp =
*tt::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr());
LLVM_DEBUG(llvm::dbgs()
<< "Analyzing boundaryCheck for: " << loadOp << "\n");

SmallVector<int> newBoundaryCheck;
for (int boundIdx : loadOp.getBoundaryCheck()) {
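// Map the boundary-check dimension through `order` to the matching index
// into the shape/offset operands (e.g. with order = [2, 1, 0] and
// boundIdx = 2, idx = 3 - order[2] - 1 = 2).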
ArrayRef<int> order = makeTensorPtrOp.getOrder();
int idx = order.size() - order[boundIdx] - 1;
Value offset = makeTensorPtrOp.getOffsets()[idx];
Value shape = makeTensorPtrOp.getShape()[idx];
std::optional<int64_t> offsetVal = getConstantIntValue(offset),
shapeVal = getConstantIntValue(shape);

// If the shape is not known at compile time we cannot determine whether
// the bound check is unnecessary.
if (!shapeVal) {
LLVM_DEBUG(llvm::dbgs().indent(2)
<< "Check at index " << boundIdx << " is necessary\n");
newBoundaryCheck.push_back(boundIdx);
continue;
}

// Case 1: offset and shape are constant.
if (offsetVal && *offsetVal < *shapeVal) {
LLVM_DEBUG(llvm::dbgs().indent(2)
<< "Check at index " << boundIdx << " is unnecessary\n");
continue;
}

// Case 2: analyze boundary check in loops.
if (auto forOp = makeTensorPtrOp->getParentOfType<scf::ForOp>()) {
assert(forOp.getSingleInductionVar() && "Single IV expected");
Value iv = *forOp.getSingleInductionVar();
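// The loop-range analysis below only applies when the offset is the
// loop induction variable itself.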
if (offset != iv) {
LLVM_DEBUG(llvm::dbgs().indent(2)
<< "Check at index " << boundIdx << " is necessary\n");
newBoundaryCheck.push_back(boundIdx);
continue;
}

OpFoldResult lb = *forOp.getSingleLowerBound();
OpFoldResult ub = *forOp.getSingleUpperBound();
OpFoldResult step = *forOp.getSingleStep();

auto computeLoopIVRange =
[&](OpFoldResult lb, OpFoldResult ub,
OpFoldResult step) -> std::optional<ConstantIntRanges> {
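// Resolve a loop bound to a constant, looking through an arith.bitcast
// feeding the bound (as in the example in the pass description).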
auto getBoundValue =
[](OpFoldResult bound) -> std::optional<int64_t> {
if (std::optional<int64_t> opVal = getConstantIntValue(bound))
return *opVal;

Value val = tt::intel::getFinalValue(cast<Value>(bound));
// getDefiningOp<OpTy>() is null-safe when `val` is a block argument.
if (auto cst = val.getDefiningOp<arith::BitcastOp>())
val = cst.getIn();

return getConstantIntValue(getAsOpFoldResult(val));
};

auto areLoopBoundKnown = [&](OpFoldResult lb, OpFoldResult ub,
OpFoldResult step) {
return (getBoundValue(lb) && getBoundValue(ub) &&
getBoundValue(step));
};

if (!areLoopBoundKnown(lb, ub, step))
return std::nullopt;

int64_t lbVal = *getBoundValue(lb);
int64_t ubVal = *getBoundValue(ub);
int64_t stepVal = *getBoundValue(step);
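// Last value the IV takes: lb + floor((ub - lb - 1) / step) * step
// (the loop upper bound is exclusive).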
int64_t lastIVVal =
lbVal + ((ubVal - lbVal - 1) / stepVal) * stepVal;
llvm::APInt start(64, lbVal, true);
llvm::APInt end(64, lastIVVal, true);

return ConstantIntRanges::range(start, end, true);
};

std::optional<ConstantIntRanges> optRange =
computeLoopIVRange(lb, ub, step);
if (!optRange) {
LLVM_DEBUG(llvm::dbgs().indent(2)
<< "Check at index " << boundIdx << " is necessary\n");
newBoundaryCheck.push_back(boundIdx);
continue;
}

// Compare the max value of the loop IV to the shape.
APInt max = (*optRange).smax();
if (max.getSExtValue() < *shapeVal) {
LLVM_DEBUG(llvm::dbgs().indent(2)
<< "Check at index " << boundIdx << " is unnecessary\n");
continue;
}
}

LLVM_DEBUG(llvm::dbgs().indent(2)
<< "Check at index " << boundIdx << " is necessary\n");
newBoundaryCheck.push_back(boundIdx);
}

if (newBoundaryCheck.size() != loadOp.getBoundaryCheck().size()) {
loadOp.setBoundaryCheck(newBoundaryCheck);
LLVM_DEBUG(llvm::dbgs().indent(2)
<< "Rewritten load is: " << loadOp << "\n");
}

return WalkResult::advance();
});
}

private:
// A candidate load operation is one that:
// - has the boundary check attribute
// - uses a block pointer defined by a `make_tensor_ptr` that is not
// advanced
bool isCandidate(tt::LoadOp loadOp) const {
assert(loadOp && "Expecting a valid load operation");

ArrayRef<int> boundaryCheck = loadOp.getBoundaryCheck();
if (boundaryCheck.empty())
return false;

Type ptrType = loadOp.getPtr().getType();
if (!tt::isTensorPointerType(ptrType))
return false;

std::optional<tt::MakeTensorPtrOp> makeTensorPtrOp =
tt::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr());
if (!makeTensorPtrOp)
return false;

if (llvm::any_of((*makeTensorPtrOp)->getUsers(),
[](Operation *user) { return isa<tt::AdvanceOp>(user); }))
return false;

return true;
}
};

} // namespace

struct TritonIntelRemoveBoundaryChecks
: tt::intel::impl::TritonIntelRemoveBoundaryChecksBase<
TritonIntelRemoveBoundaryChecks> {
public:
void runOnOperation() final {
ModuleOp moduleOp = getOperation();
BoundaryChecksRemover remover;
remover.run(moduleOp);
assert(succeeded(verify(moduleOp)) && "Module verification failed");
}
};
2 changes: 2 additions & 0 deletions third_party/intel/triton_xpu.cc
@@ -58,6 +58,8 @@ static uint32_t findKernels(llvm::Module &M,
void init_triton_intel_passes_ttir(py::module &&m) {
ADD_PASS_WRAPPER_0("add_convert_tdesc_to_block_pointer",
intel::createTritonIntelTensorDescToBlockPointer);
ADD_PASS_WRAPPER_0("add_remove_boundary_checks",
intel::createTritonIntelRemoveBoundaryChecks);
ADD_PASS_WRAPPER_0("add_remove_masks", intel::createTritonIntelRemoveMasks);
ADD_PASS_WRAPPER_0("add_fuse_reshape", intel::createTritonIntelFuseReshape);
}