diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h
index 4ccfaf9157..1e06829064 100644
--- a/bin/RegisterTritonDialects.h
+++ b/bin/RegisterTritonDialects.h
@@ -95,6 +95,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::test::registerTestTritonAMDGPURangeAnalysis();
   mlir::triton::registerConvertTritonToTritonGPUPass();
   mlir::triton::intel::registerTritonIntelFuseReshape();
+  mlir::triton::intel::registerTritonIntelRemoveBoundaryChecks();
   mlir::triton::intel::registerTritonIntelRemoveMasks();
   mlir::triton::intel::registerTritonIntelTensorDescToBlockPointer();
   mlir::triton::registerRelayoutTritonGPUPass();
diff --git a/test/Triton/Intel/RemoveBoundaryChecks/remove-boundary-checks.mlir b/test/Triton/Intel/RemoveBoundaryChecks/remove-boundary-checks.mlir
new file mode 100644
index 0000000000..c5fdde2f8f
--- /dev/null
+++ b/test/Triton/Intel/RemoveBoundaryChecks/remove-boundary-checks.mlir
@@ -0,0 +1,58 @@
+// RUN: triton-opt %s -split-input-file -triton-intel-remove-boundary-checks | FileCheck %s
+
+module {
+tt.func public @simple_load(%load_ptr: !tt.ptr {tt.divisibility = 16 : i32}, %store_ptr: !tt.ptr {tt.divisibility = 16 : i32}) {
+  %c1_i64 = arith.constant 1 : i64
+  %c64_i64 = arith.constant 64 : i64
+  %c512_i64 = arith.constant 512 : i64
+  %c1024_i64 = arith.constant 1024 : i64
+  %c0_i32 = arith.constant 0 : i32
+  %x = arith.constant 10 : i32
+  %in = tt.make_tensor_ptr %load_ptr, [%c1_i64, %c64_i64, %c1024_i64], [%c512_i64, %c64_i64, %c1_i64], [%c0_i32, %c0_i32, %x] {order = array} : >
+  // The boundaryCheck is unnecessary because %x + loadResType.shape[2] - 1 = 10 + 64 - 1 = 73 < 1024.
+  %load = tt.load %in {boundaryCheck = array} : !tt.ptr>
+  tt.return
+}
+// CHECK-LABEL: simple_load
+// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr
+// CHECK: tt.load [[PTR]] : !tt.ptr>
+}
+
+// -----
+
+module {
+tt.func public @load_in_for_loop(%load_ptr0: !tt.ptr {tt.divisibility = 16 : i32}, %load_ptr1: !tt.ptr {tt.divisibility = 16 : i32}, %store_ptr: !tt.ptr {tt.divisibility = 16 : i32}) {
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c20_i32 = arith.constant 20 : i32
+  %c64_i32 = arith.constant 64 : i32
+  %c1024_i32 = arith.constant 1024 : i32
+  scf.for %x = %c0_i32 to %c20_i32 step %c1_i32 : i32 {
+    %pid = tt.get_program_id x : i32
+    %c0_i64 = arith.constant 0 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c512_i64 = arith.constant 512 : i64
+    %c1024_i64 = arith.constant 1024 : i64
+    %c64_i64 = arith.constant 64 : i64
+    %c65536_i64 = arith.constant 65536 : i64
+    %ptr0 = tt.make_tensor_ptr %load_ptr0, [%c512_i64, %c1024_i64, %c64_i64], [%c65536_i64, %c64_i64, %c1_i64], [%x, %pid, %c0_i32] {order = array} : >
+    %load0 = tt.load %ptr0 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr>
+    %9 = arith.bitcast %c0_i32 : i32 to i32
+    %10 = arith.bitcast %c1024_i32 : i32 to i32
+    %11 = arith.bitcast %c64_i32 : i32 to i32
+    scf.for %z = %9 to %10 step %11 iter_args() -> () : i32 {
+      %ptr1 = tt.make_tensor_ptr %load_ptr1, [%c512_i64, %c64_i64, %c1024_i64], [%c65536_i64, %c1_i64, %c64_i64], [%x, %c0_i32, %z] {order = array} : >
+      // a. boundaryCheck = 1 checks the block ptr offset at index 2 (%z).
+      // b. boundaryCheck = 2 checks the block ptr offset at index 1 (%c0_i32).
+      // Check (a) is unnecessary because max(%z) + loadResType.shape[2] - 1 = 960 + 64 - 1 = 1023, which is less than 1024.
+      // Check (b) is unnecessary because max(%c0_i32) + loadResType.shape[1] - 1 = 0 + 64 - 1 = 63, which is less than 64.
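+      // max(%z) = 960 because the last value taken by the inner loop IV is
+      // lb + ((ub - lb - 1) / step) * step = 0 + ((1024 - 0 - 1) / 64) * 64 = 960.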
+      %load1 = tt.load %ptr1 {boundaryCheck = array} : !tt.ptr>
+    }
+  }
+  tt.return
+}
+// CHECK-LABEL: load_in_for_loop
+// CHECK-COUNT-2: scf.for
+// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr
+// CHECK: tt.load [[PTR]] : !tt.ptr>
+}
diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py
index eb4c15cb32..8ac4e47d68 100644
--- a/third_party/intel/backend/compiler.py
+++ b/third_party/intel/backend/compiler.py
@@ -200,6 +200,7 @@ def make_ttir(mod, metadata, opt):
         passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
         passes.common.add_cse(pm)
         passes.common.add_licm(pm)
+        intel.passes.ttir.add_remove_boundary_checks(pm)
         intel.passes.ttir.add_remove_masks(pm)
         intel.passes.ttir.add_fuse_reshape(pm)
         passes.common.add_canonicalizer(pm)
diff --git a/third_party/intel/include/Dialect/Triton/Transforms/Passes.td b/third_party/intel/include/Dialect/Triton/Transforms/Passes.td
index f83c21dc73..5135416b9d 100644
--- a/third_party/intel/include/Dialect/Triton/Transforms/Passes.td
+++ b/third_party/intel/include/Dialect/Triton/Transforms/Passes.td
@@ -70,4 +70,41 @@ def TritonIntelFuseReshape
   ];
 }
 
+def TritonIntelRemoveBoundaryChecks
+    : Pass<"triton-intel-remove-boundary-checks", "mlir::ModuleOp"> {
+  let summary = "Remove unnecessary boundary checks from load operations (block pointers only)";
+
+  let description = [{
+    This pass attempts to remove boundary checks that are not necessary on a
+    tt.load operation. For example, given:
+
+      %lb = arith.bitcast %c0_i32 : i32 to i32
+      %ub = arith.bitcast %c1024_i32 : i32 to i32
+      %st = arith.bitcast %c64_i32 : i32 to i32
+      scf.for %iv = %lb to %ub step %st : i32 {
+        %s0 = arith.constant 512 : i64
+        %s1 = arith.constant 64 : i64
+        %s2 = arith.constant 1024 : i64
+        %a = arith.constant 65536 : i64
+        %b = arith.constant 1 : i64
+        %c = arith.constant 64 : i64
+        %y = arith.constant 0 : i32
+        %ptr = tt.make_tensor_ptr %base, [%s0, %s1, %s2], [%a, %b, %c], [%x, %y, %iv]
+               {order = array} : >
+        %load = tt.load %ptr {boundaryCheck = array} : !tt.ptr>
+        ...
+        // Here %ptr is never updated.
+      }
+
+    The transformation drops the boundary check on the load operation because:
+      - `%ptr` is never advanced in the loop
+      - `%iv` takes the values [0, 64, 128, ..., 960], so max(%iv) = 960
+      - `%s2` is equal to 1024
+      - the in-bounds condition `max(%iv) + load_res.shape_in_dim - 1 < %s2`
+        therefore holds.
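+
+    Concretely, with a block size of 64 in the checked dimension:
+    960 + 64 - 1 = 1023 < 1024, so the check can never fail and is removed.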
+ }]; + + let dependentDialects = [ + "mlir::triton::TritonDialect" + ]; +} + #endif // TRITON_DIALECT_TRITON_INTEL_TRANSFORMS_PASSES diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/CMakeLists.txt b/third_party/intel/lib/Dialect/Triton/Transforms/CMakeLists.txt index 58061bf4ce..edf14803f8 100644 --- a/third_party/intel/lib/Dialect/Triton/Transforms/CMakeLists.txt +++ b/third_party/intel/lib/Dialect/Triton/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_triton_library(TritonIntelTransforms FuseReshape.cpp + RemoveBoundaryChecks.cpp RemoveMasks.cpp TensorDescToBlockPointer.cpp diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/RemoveBoundaryChecks.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/RemoveBoundaryChecks.cpp new file mode 100644 index 0000000000..23404f9778 --- /dev/null +++ b/third_party/intel/lib/Dialect/Triton/Transforms/RemoveBoundaryChecks.cpp @@ -0,0 +1,190 @@ +#include "intel/include/Dialect/Triton/Transforms/Passes.h" +#include "intel/include/Utils/Utility.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Support/WalkResult.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/APInt.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include +#include + +#define DEBUG_TYPE "triton-intel-remove-boundary-checks" + +using namespace mlir; +namespace tt = mlir::triton; + +namespace mlir::triton::intel { +#define GEN_PASS_DEF_TRITONINTELREMOVEBOUNDARYCHECKS +#include "intel/include/Dialect/Triton/Transforms/Passes.h.inc" +} // namespace mlir::triton::intel + +namespace { +class BoundaryChecksRemover { +public: + void run(ModuleOp moduleOp) { + moduleOp.walk([&](tt::LoadOp loadOp) { + if (!isCandidate(loadOp)) + return WalkResult::skip(); + + tt::MakeTensorPtrOp makeTensorPtrOp = + *tt::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr()); + LLVM_DEBUG(llvm::dbgs() + << "Analyzing boundaryCheck for: " << loadOp << "\n"); + + SmallVector newBoundaryCheck; + for (int boundIdx : loadOp.getBoundaryCheck()) { + ArrayRef order = makeTensorPtrOp.getOrder(); + int idx = order.size() - order[boundIdx] - 1; + Value offset = makeTensorPtrOp.getOffsets()[idx]; + Value shape = makeTensorPtrOp.getShape()[idx]; + auto resType = cast(loadOp.getResult().getType()); + ArrayRef resShape = resType.getShape(); + std::optional offsetVal = getConstantIntValue(offset), + shapeVal = getConstantIntValue(shape); + + // If the shape is not known at compile time we cannot determine whether + // the bound check is unnecessary. + if (!shapeVal) { + LLVM_DEBUG(llvm::dbgs().indent(2) + << "Check at index " << boundIdx << " is necessary\n"); + newBoundaryCheck.push_back(boundIdx); + continue; + } + + // Case 1: offset and shape are constant. + if (offsetVal && ((*offsetVal + resShape[idx]) <= *shapeVal)) { + LLVM_DEBUG(llvm::dbgs().indent(2) + << "Check at index " << boundIdx << " is unnecessary\n"); + continue; + } + + // Case 2: analyze boundary check in loops. 
+        if (auto forOp = makeTensorPtrOp->getParentOfType<scf::ForOp>()) {
+          assert(forOp.getSingleInductionVar() && "Single IV expected");
+          Value iv = *forOp.getSingleInductionVar();
+          if (offset != iv) {
+            LLVM_DEBUG(llvm::dbgs().indent(2)
+                       << "Check at index " << boundIdx << " is necessary\n");
+            newBoundaryCheck.push_back(boundIdx);
+            continue;
+          }
+
+          OpFoldResult lb = *forOp.getSingleLowerBound();
+          OpFoldResult ub = *forOp.getSingleUpperBound();
+          OpFoldResult step = *forOp.getSingleStep();
+
+          auto computeLoopIVRange =
+              [&](OpFoldResult lb, OpFoldResult ub,
+                  OpFoldResult step) -> std::optional<ConstantIntRanges> {
+            auto getBoundValue =
+                [](OpFoldResult bound) -> std::optional<int64_t> {
+              if (std::optional<int64_t> opVal = getConstantIntValue(bound))
+                return *opVal;
+
+              Value val = tt::intel::getFinalValue(cast<Value>(bound));
+              if (auto cst = dyn_cast<arith::BitcastOp>(val.getDefiningOp()))
+                val = cst.getIn();
+
+              return getConstantIntValue(getAsOpFoldResult(val));
+            };
+
+            auto areLoopBoundKnown = [&](OpFoldResult lb, OpFoldResult ub,
+                                         OpFoldResult step) {
+              return (getBoundValue(lb) && getBoundValue(ub) &&
+                      getBoundValue(step));
+            };
+
+            if (!areLoopBoundKnown(lb, ub, step))
+              return std::nullopt;
+
+            int64_t lbVal = *getBoundValue(lb);
+            int64_t ubVal = *getBoundValue(ub);
+            int64_t stepVal = *getBoundValue(step);
+            int64_t lastIVVal =
+                lbVal + ((ubVal - lbVal - 1) / stepVal) * stepVal;
+            llvm::APInt start(64, lbVal, true);
+            llvm::APInt end(64, lastIVVal, true);
+
+            return ConstantIntRanges::range(start, end, true);
+          };
+
+          std::optional<ConstantIntRanges> optRange =
+              computeLoopIVRange(lb, ub, step);
+          if (!optRange) {
+            LLVM_DEBUG(llvm::dbgs().indent(2)
+                       << "Check at index " << boundIdx << " is necessary\n");
+            newBoundaryCheck.push_back(boundIdx);
+            continue;
+          }
+
+          APInt maxIV = (*optRange).smax();
+          if (maxIV.getSExtValue() + resShape[idx] <= *shapeVal) {
+            LLVM_DEBUG(llvm::dbgs().indent(2)
+                       << "Check at index " << boundIdx << " is unnecessary\n");
+            continue;
+          }
+        }
+
+        LLVM_DEBUG(llvm::dbgs().indent(2)
+                   << "Check at index " << boundIdx << " is necessary\n");
+        newBoundaryCheck.push_back(boundIdx);
+      }
+
+      if (newBoundaryCheck.size() != loadOp.getBoundaryCheck().size()) {
+        loadOp.setBoundaryCheck(newBoundaryCheck);
+        LLVM_DEBUG(llvm::dbgs().indent(2)
+                   << "Rewritten load is: " << loadOp << "\n");
+      }
+
+      return WalkResult::advance();
+    });
+  }
+
+private:
+  // A candidate load operation is one that:
+  //  - has the boundary check attribute
+  //  - uses a block pointer defined by a `make_tensor_ptr` that is not
+  //    advanced
+  bool isCandidate(tt::LoadOp loadOp) const {
+    assert(loadOp && "Expecting a valid load operation");
+
+    ArrayRef<int32_t> boundaryCheck = loadOp.getBoundaryCheck();
+    if (boundaryCheck.empty())
+      return false;
+
+    Type ptrType = loadOp.getPtr().getType();
+    if (!tt::isTensorPointerType(ptrType))
+      return false;
+
+    std::optional<tt::MakeTensorPtrOp> makeTensorPtrOp =
+        tt::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr());
+    if (!makeTensorPtrOp)
+      return false;
+
+    if (llvm::any_of((*makeTensorPtrOp)->getUsers(),
+                     [](Operation *user) { return isa<tt::AdvanceOp>(user); }))
+      return false;
+
+    return true;
+  }
+};
+
+} // namespace
+
+struct TritonIntelRemoveBoundaryChecks
+    : tt::intel::impl::TritonIntelRemoveBoundaryChecksBase<
+          TritonIntelRemoveBoundaryChecks> {
+public:
+  void runOnOperation() final {
+    ModuleOp moduleOp = getOperation();
+    BoundaryChecksRemover remover;
+    remover.run(moduleOp);
+    assert(succeeded(verify(moduleOp)) && "Module verification failed");
+  }
+};
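
The new pass can also be exercised in isolation with triton-opt, exactly as the
lit test above does:

    triton-opt remove-boundary-checks.mlir -split-input-file -triton-intel-remove-boundary-checks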
diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc
index 53d6b885a9..628736a9ed 100644
--- a/third_party/intel/triton_xpu.cc
+++ b/third_party/intel/triton_xpu.cc
@@ -58,6 +58,8 @@ static uint32_t findKernels(llvm::Module &M,
 void init_triton_intel_passes_ttir(py::module &&m) {
   ADD_PASS_WRAPPER_0("add_convert_tdesc_to_block_pointer",
                      intel::createTritonIntelTensorDescToBlockPointer);
+  ADD_PASS_WRAPPER_0("add_remove_boundary_checks",
+                     intel::createTritonIntelRemoveBoundaryChecks);
   ADD_PASS_WRAPPER_0("add_remove_masks", intel::createTritonIntelRemoveMasks);
   ADD_PASS_WRAPPER_0("add_fuse_reshape", intel::createTritonIntelFuseReshape);
 }
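
A minimal Python sketch of the decision rule the pass implements (illustrative
only; `last_iv` and `is_check_removable` are hypothetical names, not part of
this patch):

    def last_iv(lb: int, ub: int, step: int) -> int:
        # Last induction-variable value of range(lb, ub, step), mirroring
        # lbVal + ((ubVal - lbVal - 1) / stepVal) * stepVal above.
        return lb + ((ub - lb - 1) // step) * step

    def is_check_removable(max_offset: int, block_dim: int, shape: int) -> bool:
        # The check is dead when the largest accessed index stays in bounds.
        return max_offset + block_dim - 1 < shape

    assert last_iv(0, 1024, 64) == 960
    assert is_check_removable(960, 64, 1024)  # 1023 < 1024: drop check (a)
    assert is_check_removable(0, 64, 64)      # 63 < 64: drop check (b)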