diff --git a/test/Triton/Intel/RemoveMasks/unnecessary-masks.mlir b/test/Triton/Intel/RemoveMasks/unnecessary-masks.mlir new file mode 100644 index 0000000000..0709d50edf --- /dev/null +++ b/test/Triton/Intel/RemoveMasks/unnecessary-masks.mlir @@ -0,0 +1,63 @@ +// RUN: triton-opt %s -triton-intel-remove-masks | FileCheck %s + +module { + tt.func public @test1(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32}, %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf16> + %c576_i32 = arith.constant 576 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_4 = arith.constant dense<9216> : tensor<32x1xi32> + %cst_5 = arith.constant dense<16> : tensor<1x32xi32> + %cst_6 = arith.constant dense<576> : tensor<1x32xi32> + %cst_7 = arith.constant dense<0xFF800000> : tensor<32x32xf32> + %cst_8 = arith.constant dense<16> : tensor<32x1xi32> + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> + %4 = tt.splat %1 : i32 -> tensor<32x1xi32> + %5 = arith.addi %4, %3 : tensor<32x1xi32> + %6 = tt.expand_dims %2 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32> + %7 = arith.remsi %5, %cst_8 : tensor<32x1xi32> + %8 = arith.divsi %5, %cst_8 : tensor<32x1xi32> + %9 = tt.splat %in_ptr1 : !tt.ptr -> tensor<32x1x!tt.ptr> + %10 = tt.addptr %9, %7 : tensor<32x1x!tt.ptr>, tensor<32x1xi32> + %11 = tt.load %10 evictionPolicy = evict_last : tensor<32x1x!tt.ptr> + %12 = arith.extf %11 : tensor<32x1xf16> to tensor<32x1xf32> + %13 = tt.broadcast %7 : tensor<32x1xi32> -> tensor<32x32xi32> + %14 = arith.muli %8, %cst_4 : tensor<32x1xi32> + %15 = tt.broadcast %14 : tensor<32x1xi32> -> tensor<32x32xi32> + %16 = tt.splat %in_ptr0 : !tt.ptr -> tensor<32x32x!tt.ptr> + %17 = tt.broadcast %12 : tensor<32x1xf32> -> tensor<32x32xf32> + %_tmp5 = scf.for 
%r0_offset = %c0_i32 to %c576_i32 step %c32_i32 iter_args(%_tmp5_9 = %cst_7) -> (tensor<32x32xf32>) : i32 { + %44 = tt.splat %r0_offset : i32 -> tensor<1x32xi32> + %45 = arith.addi %44, %6 : tensor<1x32xi32> + %46 = arith.cmpi slt, %45, %cst_6 : tensor<1x32xi32> + %47 = arith.muli %45, %cst_5 : tensor<1x32xi32> + %48 = tt.broadcast %47 : tensor<1x32xi32> -> tensor<32x32xi32> + %49 = arith.addi %13, %48 : tensor<32x32xi32> + %50 = arith.addi %49, %15 : tensor<32x32xi32> + %51 = tt.addptr %16, %50 : tensor<32x32x!tt.ptr>, tensor<32x32xi32> + %52 = tt.broadcast %46 : tensor<1x32xi1> -> tensor<32x32xi1> + %53 = tt.load %51, %52, %cst evictionPolicy = evict_last : tensor<32x32x!tt.ptr> + %54 = arith.extf %53 : tensor<32x32xf16> to tensor<32x32xf32> + %55 = arith.addf %54, %17 : tensor<32x32xf32> + %mask = arith.cmpf ogt, %_tmp5_9, %55 : tensor<32x32xf32> + %56 = arith.cmpf une, %_tmp5_9, %_tmp5_9 : tensor<32x32xf32> + %mask_10 = arith.ori %mask, %56 : tensor<32x32xi1> + %57 = arith.select %mask_10, %_tmp5_9, %55 : tensor<32x32xi1>, tensor<32x32xf32> + %58 = arith.select %52, %57, %_tmp5_9 : tensor<32x32xi1>, tensor<32x32xf32> + scf.yield %58 : tensor<32x32xf32> + } + tt.return + } + // CHECK: tt.func public @test1([[PARAM_0_:%.+]]: !tt.ptr {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: !tt.ptr {tt.divisibility = 16 : i32}) { + // CHECK: scf.for + // CHECK: [[PTR:%.+]] = tt.addptr {{.*}} : tensor<32x32x!tt.ptr>, tensor<32x32xi32> + // CHECK: [[LOAD:%.+]] = tt.load [[PTR]] evictionPolicy = evict_last : tensor<32x32x!tt.ptr> + // CHECK: arith.extf [[LOAD]] : tensor<32x32xf16> to tensor<32x32xf32> + // CHECK: [[ORI:%.+]] = arith.ori {{.*}} : tensor<32x32xi1> + // CHECK: [[SEL:%.+]] = arith.select [[ORI]], {{.*}}, {{.*}} : tensor<32x32xi1>, tensor<32x32xf32> + // CHECK: scf.yield [[SEL]] : tensor<32x32xf32> + // CHECK: } +} diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp index 
f2572d1751..78a868e02e 100644 --- a/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp +++ b/third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp @@ -2,10 +2,17 @@ #include "intel/include/Utils/Utility.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Support/LLVM.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/Instructions.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include #define DEBUG_TYPE "triton-intel-remove-masks" @@ -19,12 +26,41 @@ namespace mlir::triton::intel { namespace { +static Operation *dropMask(Operation *op, bool maskVal) { + assert(op && "Expecting a valid operation"); + + OpBuilder builder(op); + Location loc = op->getLoc(); + TypeSwitch(op) + .Case([&](auto loadOp) { + if (maskVal) { + tt::LoadOp newLoadOp = builder.create( + loc, loadOp.getPtr(), loadOp.getCache(), loadOp.getEvict(), + loadOp.getIsVolatile()); + loadOp->replaceAllUsesWith(newLoadOp); + } else { + Value other = loadOp.getOther(); + Operation *cstOp = builder.create(loc, other); + loadOp->replaceAllUsesWith(cstOp); + } + }) + .Case([&](auto selectOp) { + selectOp->replaceAllUsesWith( + (maskVal ? selectOp.getTrueValue() : selectOp.getFalseValue()) + .getDefiningOp()); + }) + .Default([](auto) { + return nullptr; + llvm_unreachable("Unexpected operation"); + }); + + return nullptr; +} + // Abstract base class for mask validators. // Mask validators are used to check whether a given mask has an expected form. // Concrete subclasses provide a member function used to select masked // operations that have a mask in a particular (e.g. desired) form. 
-// Furthermore concrete mask validators classes might also provide a member -// function class MaskValidatorBase { public: virtual ~MaskValidatorBase() = default; @@ -38,6 +74,112 @@ class MaskValidatorBase { virtual std::string getName() const = 0; }; +// A mask validator which ensures the mask is not necessary. +class RemovableMaskValidator final : public MaskValidatorBase { +public: + virtual bool isValidMask(scf::ForOp &forOp, Value mask) const { + Value finalVal = tt::intel::getFinalValue(mask); + assert(finalVal && "Expecting a valid mask"); + + // Ensure the loop range is known. + std::optional optRange = computeLoopIVRange(forOp); + if (!optRange) + return false; + + if (!finalVal.getDefiningOp() || + !isa(finalVal.getDefiningOp())) + return false; + + auto cmpOp = cast(finalVal.getDefiningOp()); + arith::CmpIPredicate pred = cmpOp.getPredicate(); + if (pred != arith::CmpIPredicate::slt) + return false; + + Value lhs = tt::intel::getFinalValue(cmpOp.getLhs()); + Value rhs = tt::intel::getFinalValue(cmpOp.getRhs()); + Operation *lhsOp = tt::intel::getFinalValue(lhs).getDefiningOp(); + Operation *rhsOp = tt::intel::getFinalValue(rhs).getDefiningOp(); + if (!lhsOp || !rhsOp) + return false; + + auto getIntConstantValue = [](Operation *op) -> std::optional { + DenseElementsAttr constAttr; + if (matchPattern(op, m_Constant(&constAttr))) { + auto attr = constAttr.getSplatValue(); + if (auto intAttr = dyn_cast_or_null(attr)) + return intAttr.getValue(); + } + return std::nullopt; + }; + + // TODO: consider the case where the constant is lhs. 
+ std::optional constIntVal = getIntConstantValue(rhsOp); + if (!constIntVal) + return false; + + if (auto addOp = dyn_cast(lhsOp)) { + Value lhs = tt::intel::getFinalValue(addOp.getLhs()); + Value rhs = tt::intel::getFinalValue(addOp.getRhs()); + if (lhs != forOp.getSingleInductionVar()) + return false; + if (auto makeRangeOp = dyn_cast(rhs.getDefiningOp())) { + APInt maxIV = (*optRange).smax(); + int64_t maxVal = maxIV.getSExtValue() + makeRangeOp.getEnd(); + + auto registerMaskValue = [this](Value mask, bool maskVal) { + for (Operation *user : mask.getUsers()) + opToMaskValue.insert({user, maskVal}); + }; + + if (maxVal <= constIntVal->getSExtValue()) { + registerMaskValue(mask, true); + return true; + } + APInt minIV = (*optRange).smin(); + int64_t minVal = minIV.getSExtValue() + makeRangeOp.getStart(); + if (minVal >= constIntVal->getSExtValue()) { + registerMaskValue(mask, false); + return true; + } + } + } + + return false; + } + + virtual Value getVersioningCond(scf::ForOp &forOp, Value mask) const { + return {}; + } + + virtual std::string getName() const { return "RemovableMaskValidator"; } + + bool getMaskValue(Operation *op) const { + assert(opToMaskValue.find(op) != opToMaskValue.end() && "mask not present"); + return opToMaskValue[op]; + } + +private: + mutable std::map opToMaskValue; + + std::optional computeLoopIVRange(scf::ForOp forOp) const { + if (!forOp.getSingleInductionVar()) + return std::nullopt; + + if (std::optional tripCount = forOp.getStaticTripCount()) { + OpFoldResult lb = *forOp.getSingleLowerBound(); + OpFoldResult step = *forOp.getSingleStep(); + int64_t lbVal = *getConstantIntValue(lb); + int64_t stepVal = *getConstantIntValue(step); + int64_t lastIVVal = stepVal * (tripCount->getSExtValue() - 1); + llvm::APInt start(64, lbVal, true); + llvm::APInt end(64, lastIVVal, true); + return ConstantIntRanges::range(start, end, true); + } + + return std::nullopt; + } +}; + // A mask validator which ensures that the mask can be reduced to 
the form: // `END-1 < N-i*END` class CanonicalMaskValidator final : public MaskValidatorBase { @@ -51,12 +193,14 @@ class CanonicalMaskValidator final : public MaskValidatorBase { // Check whether the mask is equivalent to the form: `END-1 < N-i*END`. virtual bool isValidMask(scf::ForOp &forOp, Value mask) const { - assert(mask && "Expecting a valid mask"); + Value finalVal = tt::intel::getFinalValue(mask); + assert(finalVal && "Expecting a valid mask"); - if (!mask.getDefiningOp() || !isa(mask.getDefiningOp())) + if (!finalVal.getDefiningOp() || + !isa(finalVal.getDefiningOp())) return false; - auto cmpOp = cast(mask.getDefiningOp()); + auto cmpOp = cast(finalVal.getDefiningOp()); arith::CmpIPredicate pred = cmpOp.getPredicate(); if (pred != arith::CmpIPredicate::slt) return false; @@ -103,7 +247,10 @@ class CanonicalMaskValidator final : public MaskValidatorBase { // `(N+END-1)/END` (possibly folded), the versioning condition will be: // `(N+END-1)%END > 0 && N > END`. virtual Value getVersioningCond(scf::ForOp &forOp, Value mask) const { - MaskInfo maskInfo = getMaskInfo(forOp, mask); + Value finalVal = tt::intel::getFinalValue(mask); + assert(finalVal && "Expecting a valid mask"); + + MaskInfo maskInfo = getMaskInfo(forOp, finalVal); if (!hasCanonicalUpperBound(forOp, maskInfo)) return nullptr; @@ -203,11 +350,14 @@ class InvariantMaskValidator final : public MaskValidatorBase { // - [0..END] < splat(N) // - splat(N) < [0..END] virtual bool isValidMask(scf::ForOp &forOp, Value mask) const { - assert(mask && "Expecting a valid mask"); - if (!mask.getDefiningOp() || !isa(mask.getDefiningOp())) + Value finalVal = tt::intel::getFinalValue(mask); + assert(finalVal && "Expecting a valid mask"); + + if (!finalVal.getDefiningOp() || + !isa(finalVal.getDefiningOp())) return false; - auto cmpOp = cast(mask.getDefiningOp()); + auto cmpOp = cast(finalVal.getDefiningOp()); arith::CmpIPredicate pred = cmpOp.getPredicate(); if (pred != arith::CmpIPredicate::slt) return false; 
@@ -242,7 +392,7 @@ class InvariantMaskValidator final : public MaskValidatorBase { } return false; - } + } virtual Value getVersioningCond(scf::ForOp &forOp, Value mask) const { assert(isValidMask(forOp, mask) && "Invalid mask"); @@ -301,11 +451,8 @@ template class MaskedOpsCollector { bool collectMaskedOps() { auto collectMaskedOps = [&](auto ops, MaskedOperations &maskedOps) { for (Operation *op : ops) { - Value mask = isa(op) ? cast(op).getMask() - : isa(op) ? cast(op).getMask() - : nullptr; - if (mask && - maskValidator.isValidMask(forOp, tt::intel::getFinalValue(mask))) { + Value mask = getMask(op); + if (mask && maskValidator.isValidMask(forOp, mask)) { maskedOps.insert(op); LLVM_DEBUG(llvm::dbgs() << maskValidator.getName() @@ -316,12 +463,23 @@ collectMaskedOps(forOp.getOps(), maskedOps); collectMaskedOps(forOp.getOps(), maskedOps); + collectMaskedOps(forOp.getOps(), maskedOps); return maskedOps.size(); } const MaskedOperations &getMaskedOps() const { return maskedOps; }; const MaskValidator &getMaskValidator() const { return maskValidator; } + Value getMask(Operation *op) const { + assert(op && "Expecting a valid operation"); + return TypeSwitch(op) + .Case( + [](auto maskedOp) { return maskedOp.getMask(); }) + .template Case( + [](auto selectOp) { return selectOp.getCondition(); }) + .Default([](auto) { return nullptr; }); + } + private: scf::ForOp &forOp; MaskValidator &maskValidator; @@ -515,6 +673,28 @@ struct TritonIntelRemoveMasksBase void runOnOperation() final { ModuleOp moduleOp = getOperation(); + // Remove masks if they are not necessary + moduleOp->walk([&](Operation *op) { + if (scf::ForOp forOp = dyn_cast(op)) { + // Nested loops aren't currently handled.
+ if (forOp->template getParentOfType()) + return WalkResult::advance(); + + if (!forOp.getSingleInductionVar()) + return WalkResult::advance(); + + RemovableMaskValidator maskValidator; + MaskedOpsCollector collector(forOp, maskValidator); + if (collector.collectMaskedOps()) { + for (Operation *op : collector.getMaskedOps()) { + bool maskVal = maskValidator.getMaskValue(op); + dropMask(op, maskVal); + } + } + } + return WalkResult::advance(); + }); + // Version loops containing masked operation in canonical form. moduleOp->walk([&](Operation *op) { if (scf::ForOp forOp = dyn_cast(op)) {