1 change: 1 addition & 0 deletions bin/RegisterTritonDialects.h
@@ -95,6 +95,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::test::registerTestTritonAMDGPURangeAnalysis();
mlir::triton::registerConvertTritonToTritonGPUPass();
mlir::triton::intel::registerTritonIntelFuseReshape();
mlir::triton::intel::registerTritonIntelRemoveBoundaryChecks();
mlir::triton::intel::registerTritonIntelRemoveMasks();
mlir::triton::intel::registerTritonIntelTensorDescToBlockPointer();
mlir::triton::registerRelayoutTritonGPUPass();
58 changes: 58 additions & 0 deletions test/Triton/Intel/RemoveBoundaryChecks/remove-boundary-checks.mlir
@@ -0,0 +1,58 @@
// RUN: triton-opt %s -split-input-file -triton-intel-remove-boundary-checks | FileCheck %s

module {
tt.func public @simple_load(%load_ptr: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %store_ptr: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
%c1_i64 = arith.constant 1 : i64
%c64_i64 = arith.constant 64 : i64
%c512_i64 = arith.constant 512 : i64
%c1024_i64 = arith.constant 1024 : i64
%c0_i32 = arith.constant 0 : i32
%x = arith.constant 10 : i32
%in = tt.make_tensor_ptr %load_ptr, [%c1_i64, %c64_i64, %c1024_i64], [%c512_i64, %c64_i64, %c1_i64], [%c0_i32, %c0_i32, %x] {order = array<i32: 2, 1, 0>} : <tensor<1x64x64xf16>>
// boundaryCheck is unnecessary because %x + loadResType.shape[2] - 1 = 10 + 64 - 1 = 73 < 1024
%load = tt.load %in {boundaryCheck = array<i32: 2>} : !tt.ptr<tensor<1x64x64xf16>>
tt.return
}
// CHECK-LABEL: simple_load
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr
// CHECK: tt.load [[PTR]] : !tt.ptr<tensor<1x64x64xf16>>
}

// -----

module {
tt.func public @load_in_for_loop(%load_ptr0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %load_ptr1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %store_ptr: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c20_i32 = arith.constant 20 : i32
%c64_i32 = arith.constant 64 : i32
%c1024_i32 = arith.constant 1024 : i32
scf.for %x = %c0_i32 to %c20_i32 step %c1_i32 : i32 {
%pid = tt.get_program_id x : i32
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%c512_i64 = arith.constant 512 : i64
%c1024_i64 = arith.constant 1024 : i64
%c64_i64 = arith.constant 64 : i64
%c65536_i64 = arith.constant 65536 : i64
%ptr0 = tt.make_tensor_ptr %load_ptr0, [%c512_i64, %c1024_i64, %c64_i64], [%c65536_i64, %c64_i64, %c1_i64], [%x, %pid, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
%load0 = tt.load %ptr0 {boundaryCheck = array<i32: 1, 2>, padding = 1 : i32} : !tt.ptr<tensor<1x512x64xf16>>
%9 = arith.bitcast %c0_i32 : i32 to i32
%10 = arith.bitcast %c1024_i32 : i32 to i32
%11 = arith.bitcast %c64_i32 : i32 to i32
scf.for %z = %9 to %10 step %11 iter_args() -> () : i32 {
%ptr1 = tt.make_tensor_ptr %load_ptr1, [%c512_i64, %c64_i64, %c1024_i64], [%c65536_i64, %c1_i64, %c64_i64], [%x, %c0_i32, %z] {order = array<i32: 2, 0, 1>} : <tensor<1x64x64xf16>>
// (a) boundaryCheck = 1 checks the block ptr offset at index 2 (%z).
// (b) boundaryCheck = 2 checks the block ptr offset at index 1 (%c0_i32).
// Check (a) is unnecessary because max(%z) + loadResType.shape[2] - 1 = 960 + 64 - 1 = 1023, which is less than 1024.
// Check (b) is unnecessary because 0 + loadResType.shape[1] - 1 = 0 + 64 - 1 = 63, which is less than 64.
%load1 = tt.load %ptr1 {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x64x64xf16>>
}
}
tt.return
}
// CHECK-LABEL: load_in_for_loop
// CHECK-COUNT-2: scf.for
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr
// CHECK: tt.load [[PTR]] : !tt.ptr<tensor<1x64x64xf16>>
}
1 change: 1 addition & 0 deletions third_party/intel/backend/compiler.py
@@ -200,6 +200,7 @@ def make_ttir(mod, metadata, opt):
passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
passes.common.add_cse(pm)
passes.common.add_licm(pm)
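# Drop boundary checks that are provably unnecessary (block pointers only).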
intel.passes.ttir.add_remove_boundary_checks(pm)
intel.passes.ttir.add_remove_masks(pm)
intel.passes.ttir.add_fuse_reshape(pm)
passes.common.add_canonicalizer(pm)
37 changes: 37 additions & 0 deletions third_party/intel/include/Dialect/Triton/Transforms/Passes.td
@@ -70,4 +70,41 @@ def TritonIntelFuseReshape
];
}

def TritonIntelRemoveBoundaryChecks
: Pass<"triton-intel-remove-boundary-checks", "mlir::ModuleOp"> {
let summary = "Remove unnecessary boundary checks from load operations (block pointers only)";

let description = [{
This pass attempts to remove boundary checks on `tt.load` operations that are provably unnecessary.
For example, given:
%lb = arith.bitcast %c0_i32 : i32 to i32
%ub = arith.bitcast %c1024_i32 : i32 to i32
%st = arith.bitcast %c64_i32 : i32 to i32
scf.for %iv = %lb to %ub step %st : i32 {
%s0 = arith.constant 512 : i64
%s1 = arith.constant 64 : i64
%s2 = arith.constant 1024 : i64
%a = arith.constant 65536 : i64
%b = arith.constant 1 : i64
%c = arith.constant 64 : i64
%x = arith.constant 0 : i32
%y = arith.constant 0 : i32
%ptr = tt.make_tensor_ptr %base, [%s0, %s1, %s2], [%a, %b, %c], [%x, %y, %iv]
{order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
%load = tt.load %ptr {boundaryCheck = array<i32: 2>} : !tt.ptr<tensor<1x512x64xf16>>
...
// Note that %ptr is never advanced inside the loop.
}

The transformation drops the boundary check on the load operation because:
- `%ptr` is never advanced in the loop
- `%iv` takes the values [0, 64, 128, ..., 960], so max(%iv) = 960
- `%s2` is equal to 1024
- hence `max(%iv) + load_res.shape[2] - 1 = 960 + 64 - 1 = 1023 < %s2` holds.
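
After the transformation the load no longer carries the attribute:
%load = tt.load %ptr : !tt.ptr<tensor<1x512x64xf16>>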
}];

let dependentDialects = [
"mlir::triton::TritonDialect"
];
}

#endif // TRITON_DIALECT_TRITON_INTEL_TRANSFORMS_PASSES
@@ -1,5 +1,6 @@
add_triton_library(TritonIntelTransforms
FuseReshape.cpp
RemoveBoundaryChecks.cpp
RemoveMasks.cpp
TensorDescToBlockPointer.cpp

@@ -0,0 +1,190 @@
#include "intel/include/Dialect/Triton/Transforms/Passes.h"
#include "intel/include/Utils/Utility.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Support/WalkResult.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cmath>
#include <optional>

#define DEBUG_TYPE "triton-intel-remove-boundary-checks"

using namespace mlir;
namespace tt = mlir::triton;

namespace mlir::triton::intel {
#define GEN_PASS_DEF_TRITONINTELREMOVEBOUNDARYCHECKS
#include "intel/include/Dialect/Triton/Transforms/Passes.h.inc"
} // namespace mlir::triton::intel

namespace {
class BoundaryChecksRemover {
public:
void run(ModuleOp moduleOp) {
moduleOp.walk([&](tt::LoadOp loadOp) {
if (!isCandidate(loadOp))
return WalkResult::skip();

tt::MakeTensorPtrOp makeTensorPtrOp =
*tt::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr());
LLVM_DEBUG(llvm::dbgs()
<< "Analyzing boundaryCheck for: " << loadOp << "\n");

SmallVector<int> newBoundaryCheck;
for (int boundIdx : loadOp.getBoundaryCheck()) {
ArrayRef<int> order = makeTensorPtrOp.getOrder();
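// Map each boundaryCheck index through the `order` attribute to the
// index of the corresponding offset/shape operand on the defining
// make_tensor_ptr.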
int idx = order.size() - order[boundIdx] - 1;
Value offset = makeTensorPtrOp.getOffsets()[idx];
Value shape = makeTensorPtrOp.getShape()[idx];
auto resType = cast<RankedTensorType>(loadOp.getResult().getType());
ArrayRef<int64_t> resShape = resType.getShape();
std::optional<int64_t> offsetVal = getConstantIntValue(offset),
shapeVal = getConstantIntValue(shape);

// If the shape is not known at compile time, we cannot prove that the
// boundary check is unnecessary.
if (!shapeVal) {
LLVM_DEBUG(llvm::dbgs().indent(2)
<< "Check at index " << boundIdx << " is necessary\n");
newBoundaryCheck.push_back(boundIdx);
continue;
}

// Case 1: offset and shape are constant.
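// The furthest element accessed in dimension idx is offset + resShape[idx] - 1,
// so the check can never fail when offset + resShape[idx] <= shape
// (e.g. 10 + 64 = 74 <= 1024).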
if (offsetVal && ((*offsetVal + resShape[idx]) <= *shapeVal)) {
LLVM_DEBUG(llvm::dbgs().indent(2)
<< "Check at index " << boundIdx << " is unnecessary\n");
continue;
}

// Case 2: the block pointer is created inside a loop; the check can be
// dropped if its offset is the loop induction variable and the IV range
// is statically known.
if (auto forOp = makeTensorPtrOp->getParentOfType<scf::ForOp>()) {
assert(forOp.getSingleInductionVar() && "Single IV expected");
Value iv = *forOp.getSingleInductionVar();
if (offset != iv) {
LLVM_DEBUG(llvm::dbgs().indent(2)
<< "Check at index " << boundIdx << " is necessary\n");
newBoundaryCheck.push_back(boundIdx);
continue;
}

OpFoldResult lb = *forOp.getSingleLowerBound();
OpFoldResult ub = *forOp.getSingleUpperBound();
OpFoldResult step = *forOp.getSingleStep();

auto computeLoopIVRange =
[&](OpFoldResult lb, OpFoldResult ub,
OpFoldResult step) -> std::optional<ConstantIntRanges> {
auto getBoundValue =
[](OpFoldResult bound) -> std::optional<int64_t> {
if (std::optional<int64_t> opVal = getConstantIntValue(bound))
return *opVal;

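// Loop bounds may be hidden behind no-op arith.bitcast operations
// (as in the tests above); look through them to reach the underlying
// constant.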
Value val = tt::intel::getFinalValue(cast<Value>(bound));
if (auto cst = dyn_cast<arith::BitcastOp>(val.getDefiningOp()))
val = cst.getIn();

return getConstantIntValue(getAsOpFoldResult(val));
};

auto areLoopBoundKnown = [&](OpFoldResult lb, OpFoldResult ub,
OpFoldResult step) {
return (getBoundValue(lb) && getBoundValue(ub) &&
getBoundValue(step));
};

if (!areLoopBoundKnown(lb, ub, step))
return std::nullopt;

int64_t lbVal = *getBoundValue(lb);
int64_t ubVal = *getBoundValue(ub);
int64_t stepVal = *getBoundValue(step);
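// Largest value the IV takes before reaching the exclusive upper
// bound, e.g. lb = 0, ub = 1024, step = 64 yields 960.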
int64_t lastIVVal =
lbVal + ((ubVal - lbVal - 1) / stepVal) * stepVal;
llvm::APInt start(64, lbVal, true);
llvm::APInt end(64, lastIVVal, true);

return ConstantIntRanges::range(start, end, true);
};

std::optional<ConstantIntRanges> optRange =
computeLoopIVRange(lb, ub, step);
if (!optRange) {
LLVM_DEBUG(llvm::dbgs().indent(2)
<< "Check at index " << boundIdx << " is necessary\n");
newBoundaryCheck.push_back(boundIdx);
continue;
}

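// Mirror the constant-offset case, with the offset replaced by the
// largest value the IV can take.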
APInt maxIV = (*optRange).smax();
if (maxIV.getSExtValue() + resShape[idx] <= shapeVal) {
LLVM_DEBUG(llvm::dbgs().indent(2)
<< "Check at index " << boundIdx << " is unnecessary\n");
continue;
}
}

LLVM_DEBUG(llvm::dbgs().indent(2)
<< "Check at index " << boundIdx << " is necessary\n");
newBoundaryCheck.push_back(boundIdx);
}

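// Update the operation only if at least one check was removed.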
if (newBoundaryCheck.size() != loadOp.getBoundaryCheck().size()) {
loadOp.setBoundaryCheck(newBoundaryCheck);
LLVM_DEBUG(llvm::dbgs().indent(2)
<< "Rewritten load is: " << loadOp << "\n");
}

return WalkResult::advance();
});
}

private:
// A candidate load operation is one that:
// - has the boundary check attribute
// - uses a block pointer defined by a `make_tensor_ptr` that is not
// advanced
bool isCandidate(tt::LoadOp loadOp) const {
assert(loadOp && "Expecting a valid load operation");

ArrayRef<int> boundaryCheck = loadOp.getBoundaryCheck();
if (boundaryCheck.empty())
return false;

Type ptrType = loadOp.getPtr().getType();
if (!tt::isTensorPointerType(ptrType))
return false;

std::optional<tt::MakeTensorPtrOp> makeTensorPtrOp =
tt::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr());
if (!makeTensorPtrOp)
return false;

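// An advanced block pointer changes its offsets after creation, so the
// static analysis above would not be sound for it.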
if (llvm::any_of((*makeTensorPtrOp)->getUsers(),
[](Operation *user) { return isa<tt::AdvanceOp>(user); }))
return false;

return true;
}
};

} // namespace

struct TritonIntelRemoveBoundaryChecks
: tt::intel::impl::TritonIntelRemoveBoundaryChecksBase<
TritonIntelRemoveBoundaryChecks> {
public:
void runOnOperation() final {
ModuleOp moduleOp = getOperation();
BoundaryChecksRemover remover;
remover.run(moduleOp);
assert(succeeded(verify(moduleOp)) && "Module verification failed");
}
};
2 changes: 2 additions & 0 deletions third_party/intel/triton_xpu.cc
@@ -58,6 +58,8 @@ static uint32_t findKernels(llvm::Module &M,
void init_triton_intel_passes_ttir(py::module &&m) {
ADD_PASS_WRAPPER_0("add_convert_tdesc_to_block_pointer",
intel::createTritonIntelTensorDescToBlockPointer);
ADD_PASS_WRAPPER_0("add_remove_boundary_checks",
intel::createTritonIntelRemoveBoundaryChecks);
ADD_PASS_WRAPPER_0("add_remove_masks", intel::createTritonIntelRemoveMasks);
ADD_PASS_WRAPPER_0("add_fuse_reshape", intel::createTritonIntelFuseReshape);
}