Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions test/Triton/Intel/RemoveMasks/unnecessary-masks.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// RUN: triton-opt %s -triton-intel-remove-masks | FileCheck %s

// Checks that the pass removes masks that are statically provable over the
// whole loop iteration space:
//   - the loop IV takes the values 0, 32, ..., 544 (ub = 576, step = 32);
//   - the mask %46 is (iv + make_range[0..32)) slt 576, whose largest
//     element is 544 + 31 = 575, so the mask is always true;
//   - hence the masked tt.load becomes an unmasked load and the
//     arith.select guarded by the (broadcast) mask %52 is replaced by its
//     true operand.
module {
  tt.func public @test1(%in_ptr0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %in_ptr1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf16>
    %c576_i32 = arith.constant 576 : i32
    %c0_i32 = arith.constant 0 : i32
    %cst_4 = arith.constant dense<9216> : tensor<32x1xi32>
    %cst_5 = arith.constant dense<16> : tensor<1x32xi32>
    %cst_6 = arith.constant dense<576> : tensor<1x32xi32>
    %cst_7 = arith.constant dense<0xFF800000> : tensor<32x32xf32>
    %cst_8 = arith.constant dense<16> : tensor<32x1xi32>
    %c32_i32 = arith.constant 32 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c32_i32 : i32
    %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32>
    %4 = tt.splat %1 : i32 -> tensor<32x1xi32>
    %5 = arith.addi %4, %3 : tensor<32x1xi32>
    %6 = tt.expand_dims %2 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
    %7 = arith.remsi %5, %cst_8 : tensor<32x1xi32>
    %8 = arith.divsi %5, %cst_8 : tensor<32x1xi32>
    %9 = tt.splat %in_ptr1 : !tt.ptr<f16> -> tensor<32x1x!tt.ptr<f16>>
    %10 = tt.addptr %9, %7 : tensor<32x1x!tt.ptr<f16>>, tensor<32x1xi32>
    %11 = tt.load %10 evictionPolicy = evict_last : tensor<32x1x!tt.ptr<f16>>
    %12 = arith.extf %11 : tensor<32x1xf16> to tensor<32x1xf32>
    %13 = tt.broadcast %7 : tensor<32x1xi32> -> tensor<32x32xi32>
    %14 = arith.muli %8, %cst_4 : tensor<32x1xi32>
    %15 = tt.broadcast %14 : tensor<32x1xi32> -> tensor<32x32xi32>
    %16 = tt.splat %in_ptr0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>>
    %17 = tt.broadcast %12 : tensor<32x1xf32> -> tensor<32x32xf32>
    %_tmp5 = scf.for %r0_offset = %c0_i32 to %c576_i32 step %c32_i32 iter_args(%_tmp5_9 = %cst_7) -> (tensor<32x32xf32>) : i32 {
      %44 = tt.splat %r0_offset : i32 -> tensor<1x32xi32>
      %45 = arith.addi %44, %6 : tensor<1x32xi32>
      // Always true: %45 < 576 for every loop iteration (max is 575).
      %46 = arith.cmpi slt, %45, %cst_6 : tensor<1x32xi32>
      %47 = arith.muli %45, %cst_5 : tensor<1x32xi32>
      %48 = tt.broadcast %47 : tensor<1x32xi32> -> tensor<32x32xi32>
      %49 = arith.addi %13, %48 : tensor<32x32xi32>
      %50 = arith.addi %49, %15 : tensor<32x32xi32>
      %51 = tt.addptr %16, %50 : tensor<32x32x!tt.ptr<f16>>, tensor<32x32xi32>
      %52 = tt.broadcast %46 : tensor<1x32xi1> -> tensor<32x32xi1>
      %53 = tt.load %51, %52, %cst evictionPolicy = evict_last : tensor<32x32x!tt.ptr<f16>>
      %54 = arith.extf %53 : tensor<32x32xf16> to tensor<32x32xf32>
      %55 = arith.addf %54, %17 : tensor<32x32xf32>
      %mask = arith.cmpf ogt, %_tmp5_9, %55 : tensor<32x32xf32>
      %56 = arith.cmpf une, %_tmp5_9, %_tmp5_9 : tensor<32x32xf32>
      %mask_10 = arith.ori %mask, %56 : tensor<32x32xi1>
      %57 = arith.select %mask_10, %_tmp5_9, %55 : tensor<32x32xi1>, tensor<32x32xf32>
      // Guarded by the always-true mask: expected to fold to %57.
      %58 = arith.select %52, %57, %_tmp5_9 : tensor<32x32xi1>, tensor<32x32xf32>
      scf.yield %58 : tensor<32x32xf32>
    }
    tt.return
  }
  // CHECK: tt.func public @test1([[PARAM_0_:%.+]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
  // CHECK: scf.for
  // CHECK: [[PTR:%.+]] = tt.addptr {{.*}} : tensor<32x32x!tt.ptr<f16>>, tensor<32x32xi32>
  // CHECK: [[LOAD:%.+]] = tt.load [[PTR]] evictionPolicy = evict_last : tensor<32x32x!tt.ptr<f16>>
  // CHECK: arith.extf [[LOAD]] : tensor<32x32xf16> to tensor<32x32xf32>
  // CHECK: [[ORI:%.+]] = arith.ori {{.*}} : tensor<32x32xi1>
  // CHECK: [[SEL:%.+]] = arith.select [[ORI]], {{.*}}, {{.*}} : tensor<32x32xi1>, tensor<32x32xf32>
  // CHECK: scf.yield [[SEL]] : tensor<32x32xf32>
  // CHECK: }
}
210 changes: 195 additions & 15 deletions third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@
#include "intel/include/Utils/Utility.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Support/LLVM.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "triton-intel-remove-masks"

Expand All @@ -19,12 +26,41 @@ namespace mlir::triton::intel {

namespace {

// Rewrite \p op so that it no longer depends on its mask, whose value is
// statically known to be \p maskVal for every execution.
// The original operation is left in place without users; it is expected to be
// removed by a later DCE/canonicalization.
// Returns nullptr (kept for signature compatibility with existing callers).
static Operation *dropMask(Operation *op, bool maskVal) {
  assert(op && "Expecting a valid operation");

  OpBuilder builder(op);
  Location loc = op->getLoc();
  TypeSwitch<Operation *>(op)
      .Case<tt::LoadOp>([&](auto loadOp) {
        if (maskVal) {
          // All-true mask: replace with an unconditional load.
          tt::LoadOp newLoadOp = builder.create<tt::LoadOp>(
              loc, loadOp.getPtr(), loadOp.getCache(), loadOp.getEvict(),
              loadOp.getIsVolatile());
          loadOp->replaceAllUsesWith(newLoadOp);
        } else {
          // All-false mask: a masked load yields its 'other' operand, so
          // forward that value directly. This also works when 'other' is not
          // produced by a constant op (the original code tried to build an
          // arith::ConstantOp from a Value).
          Value other = loadOp.getOther();
          assert(other && "expected masked load with an 'other' value");
          loadOp.getResult().replaceAllUsesWith(other);
        }
      })
      .Case<arith::SelectOp>([&](auto selectOp) {
        // Forward the statically selected operand. Replace the Value (not the
        // defining op): the operand may be a block argument (e.g. a loop
        // iter_arg), in which case it has no defining operation.
        Value replacement =
            maskVal ? selectOp.getTrueValue() : selectOp.getFalseValue();
        selectOp.getResult().replaceAllUsesWith(replacement);
      })
      .Default([](auto) {
        // NOTE(review): operations the collector may select but that are not
        // handled here (e.g. tt::StoreOp) are deliberately left untouched.
        // The original code had an unreachable llvm_unreachable after a
        // 'return', i.e. it was a silent no-op as well.
      });

  return nullptr;
}

// Abstract base class for mask validators.
// Mask validators are used to check whether a given mask has an expected form.
// Concrete subclasses provide a member function used to select masked
// operations that have a mask in a particular (e.g. desired) form.
// Furthermore, concrete mask validator classes may also provide a member
// function computing a loop versioning condition (see getVersioningCond).
class MaskValidatorBase {
public:
virtual ~MaskValidatorBase() = default;
Expand All @@ -38,6 +74,112 @@ class MaskValidatorBase {
virtual std::string getName() const = 0;
};

// A mask validator which ensures the mask is not necessary.
class RemovableMaskValidator final : public MaskValidatorBase {
public:
virtual bool isValidMask(scf::ForOp &forOp, Value mask) const {
Value finalVal = tt::intel::getFinalValue(mask);
assert(finalVal && "Expecting a valid mask");

// Ensure the loop range is known.
std::optional<ConstantIntRanges> optRange = computeLoopIVRange(forOp);
if (!optRange)
return false;

if (!finalVal.getDefiningOp() ||
!isa<arith::CmpIOp>(finalVal.getDefiningOp()))
return false;

auto cmpOp = cast<arith::CmpIOp>(finalVal.getDefiningOp());
arith::CmpIPredicate pred = cmpOp.getPredicate();
if (pred != arith::CmpIPredicate::slt)
return false;

Value lhs = tt::intel::getFinalValue(cmpOp.getLhs());
Value rhs = tt::intel::getFinalValue(cmpOp.getRhs());
Operation *lhsOp = tt::intel::getFinalValue(lhs).getDefiningOp();
Operation *rhsOp = tt::intel::getFinalValue(rhs).getDefiningOp();
if (!lhsOp || !rhsOp)
return false;

auto getIntConstantValue = [](Operation *op) -> std::optional<APInt> {
DenseElementsAttr constAttr;
if (matchPattern(op, m_Constant(&constAttr))) {
auto attr = constAttr.getSplatValue<Attribute>();
if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
return intAttr.getValue();
}
return std::nullopt;
};

// TODO: consider the case where the constant is lhs.
std::optional<APInt> constIntVal = getIntConstantValue(rhsOp);
if (!constIntVal)
return false;

if (auto addOp = dyn_cast<arith::AddIOp>(lhsOp)) {
Value lhs = tt::intel::getFinalValue(addOp.getLhs());
Value rhs = tt::intel::getFinalValue(addOp.getRhs());
if (lhs != forOp.getSingleInductionVar())
return false;
if (auto makeRangeOp = dyn_cast<tt::MakeRangeOp>(rhs.getDefiningOp())) {
APInt maxIV = (*optRange).smax();
int64_t maxVal = maxIV.getSExtValue() + makeRangeOp.getEnd();

auto registerMaskValue = [this](Value mask, bool maskVal) {
for (Operation *user : mask.getUsers())
opToMaskValue.insert({user, maskVal});
};

if (maxVal <= constIntVal->getSExtValue()) {
registerMaskValue(mask, true);
return true;
}
APInt minIV = (*optRange).smin();
int64_t minVal = minIV.getSExtValue() + makeRangeOp.getStart();
if (minVal >= constIntVal->getSExtValue()) {
registerMaskValue(mask, false);
return true;
}
}
}

return false;
}

virtual Value getVersioningCond(scf::ForOp &forOp, Value mask) const {
return {};
}

virtual std::string getName() const { return "RemovableMaskValidator"; }

bool getMaskValue(Operation *op) const {
assert(opToMaskValue.find(op) != opToMaskValue.end() && "mask not present");
return opToMaskValue[op];
}

private:
mutable std::map<Operation *, bool> opToMaskValue;

std::optional<ConstantIntRanges> computeLoopIVRange(scf::ForOp forOp) const {
if (!forOp.getSingleInductionVar())
return std::nullopt;

if (std::optional<APInt> tripCount = forOp.getStaticTripCount()) {
OpFoldResult lb = *forOp.getSingleLowerBound();
OpFoldResult step = *forOp.getSingleStep();
int64_t lbVal = *getConstantIntValue(lb);
int64_t stepVal = *getConstantIntValue(step);
int64_t lastIVVal = stepVal * (tripCount->getSExtValue() - 1);
llvm::APInt start(64, lbVal, true);
llvm::APInt end(64, lastIVVal, true);
return ConstantIntRanges::range(start, end, true);
}

return std::nullopt;
}
};

// A mask validator which ensures that the mask can be reduced to the form:
// `END-1 < N-i*END`
class CanonicalMaskValidator final : public MaskValidatorBase {
Expand All @@ -51,12 +193,14 @@ class CanonicalMaskValidator final : public MaskValidatorBase {

// Check whether the mask is equivalent to the form: `END-1 < N-i*END`.
virtual bool isValidMask(scf::ForOp &forOp, Value mask) const {
assert(mask && "Expecting a valid mask");
Value finalVal = tt::intel::getFinalValue(mask);
assert(finalVal && "Expecting a valid mask");

if (!mask.getDefiningOp() || !isa<arith::CmpIOp>(mask.getDefiningOp()))
if (!finalVal.getDefiningOp() ||
!isa<arith::CmpIOp>(finalVal.getDefiningOp()))
return false;

auto cmpOp = cast<arith::CmpIOp>(mask.getDefiningOp());
auto cmpOp = cast<arith::CmpIOp>(finalVal.getDefiningOp());
arith::CmpIPredicate pred = cmpOp.getPredicate();
if (pred != arith::CmpIPredicate::slt)
return false;
Expand Down Expand Up @@ -103,7 +247,10 @@ class CanonicalMaskValidator final : public MaskValidatorBase {
// `(N+END-1)/END` (possibly folded), the versioning condition will be:
// `(N+END-1)%END > 0 && N > END`.
virtual Value getVersioningCond(scf::ForOp &forOp, Value mask) const {
MaskInfo maskInfo = getMaskInfo(forOp, mask);
Value finalVal = tt::intel::getFinalValue(mask);
assert(finalVal && "Expecting a valid mask");

MaskInfo maskInfo = getMaskInfo(forOp, finalVal);
if (!hasCanonicalUpperBound(forOp, maskInfo))
return nullptr;

Expand Down Expand Up @@ -203,11 +350,14 @@ class InvariantMaskValidator final : public MaskValidatorBase {
// - [0..END] < splat(N)
// - splat(N) < [0..END]
virtual bool isValidMask(scf::ForOp &forOp, Value mask) const {
assert(mask && "Expecting a valid mask");
if (!mask.getDefiningOp() || !isa<arith::CmpIOp>(mask.getDefiningOp()))
Value finalVal = tt::intel::getFinalValue(mask);
assert(finalVal && "Expecting a valid mask");

if (!finalVal.getDefiningOp() ||
!isa<arith::CmpIOp>(finalVal.getDefiningOp()))
return false;

auto cmpOp = cast<arith::CmpIOp>(mask.getDefiningOp());
auto cmpOp = cast<arith::CmpIOp>(finalVal.getDefiningOp());
arith::CmpIPredicate pred = cmpOp.getPredicate();
if (pred != arith::CmpIPredicate::slt)
return false;
Expand Down Expand Up @@ -242,7 +392,7 @@ class InvariantMaskValidator final : public MaskValidatorBase {
}

return false;
}
} // namespace

virtual Value getVersioningCond(scf::ForOp &forOp, Value mask) const {
assert(isValidMask(forOp, mask) && "Invalid mask");
Expand Down Expand Up @@ -301,11 +451,8 @@ template <typename MaskValidator> class MaskedOpsCollector {
bool collectMaskedOps() {
auto collectMaskedOps = [&](auto ops, MaskedOperations &maskedOps) {
for (Operation *op : ops) {
Value mask = isa<tt::LoadOp>(op) ? cast<tt::LoadOp>(op).getMask()
: isa<tt::StoreOp>(op) ? cast<tt::StoreOp>(op).getMask()
: nullptr;
if (mask &&
maskValidator.isValidMask(forOp, tt::intel::getFinalValue(mask))) {
Value mask = getMask(op);
if (mask && maskValidator.isValidMask(forOp, mask)) {
maskedOps.insert(op);
LLVM_DEBUG(llvm::dbgs()
<< maskValidator.getName()
Expand All @@ -316,12 +463,23 @@ template <typename MaskValidator> class MaskedOpsCollector {

collectMaskedOps(forOp.getOps<tt::LoadOp>(), maskedOps);
collectMaskedOps(forOp.getOps<tt::StoreOp>(), maskedOps);
collectMaskedOps(forOp.getOps<arith::SelectOp>(), maskedOps);
return maskedOps.size();
}

const MaskedOperations &getMaskedOps() const { return maskedOps; };
const MaskValidator &getMaskValidator() const { return maskValidator; }

  // Return the mask associated with \p op: the mask operand of a tt.load or
  // tt.store, or the condition of an arith.select (which plays the role of a
  // mask for the select). Returns a null Value for any other operation.
  Value getMask(Operation *op) const {
    assert(op && "Expecting a valid operation");
    return TypeSwitch<Operation *, Value>(op)
        .Case<tt::LoadOp, tt::StoreOp>(
            [](auto maskedOp) { return maskedOp.getMask(); })
        .template Case<arith::SelectOp>(
            [](auto selectOp) { return selectOp.getCondition(); })
        .Default([](auto) { return nullptr; });
  }

private:
scf::ForOp &forOp;
MaskValidator &maskValidator;
Expand Down Expand Up @@ -515,6 +673,28 @@ struct TritonIntelRemoveMasksBase
void runOnOperation() final {
ModuleOp moduleOp = getOperation();

// Remove masks that are statically known to be unnecessary.
moduleOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (scf::ForOp forOp = dyn_cast<scf::ForOp>(op)) {
// Nested loops aren't currently handled.
if (forOp->template getParentOfType<scf::ForOp>())
return WalkResult::advance();

if (!forOp.getSingleInductionVar())
return WalkResult::advance();

RemovableMaskValidator maskValidator;
MaskedOpsCollector collector(forOp, maskValidator);
if (collector.collectMaskedOps()) {
for (Operation *op : collector.getMaskedOps()) {
bool maskVal = maskValidator.getMaskValue(op);
dropMask(op, maskVal);
}
}
}
return WalkResult::advance();
});

// Version loops containing masked operation in canonical form.
moduleOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (scf::ForOp forOp = dyn_cast<scf::ForOp>(op)) {
Expand Down
Loading