Skip to content

Commit af1fd37

Browse files
committed
[RemoveMasks}: Remove unnecessary masks
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 6000ece commit af1fd37

File tree

2 files changed

+258
-15
lines changed

2 files changed

+258
-15
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// RUN: triton-opt %s -triton-intel-remove-masks | FileCheck %s
2+
3+
module {
4+
tt.func public @test1(%in_ptr0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %in_ptr1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
5+
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf16>
6+
%c576_i32 = arith.constant 576 : i32
7+
%c0_i32 = arith.constant 0 : i32
8+
%cst_4 = arith.constant dense<9216> : tensor<32x1xi32>
9+
%cst_5 = arith.constant dense<16> : tensor<1x32xi32>
10+
%cst_6 = arith.constant dense<576> : tensor<1x32xi32>
11+
%cst_7 = arith.constant dense<0xFF800000> : tensor<32x32xf32>
12+
%cst_8 = arith.constant dense<16> : tensor<32x1xi32>
13+
%c32_i32 = arith.constant 32 : i32
14+
%0 = tt.get_program_id x : i32
15+
%1 = arith.muli %0, %c32_i32 : i32
16+
%2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
17+
%3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32>
18+
%4 = tt.splat %1 : i32 -> tensor<32x1xi32>
19+
%5 = arith.addi %4, %3 : tensor<32x1xi32>
20+
%6 = tt.expand_dims %2 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
21+
%7 = arith.remsi %5, %cst_8 : tensor<32x1xi32>
22+
%8 = arith.divsi %5, %cst_8 : tensor<32x1xi32>
23+
%9 = tt.splat %in_ptr1 : !tt.ptr<f16> -> tensor<32x1x!tt.ptr<f16>>
24+
%10 = tt.addptr %9, %7 : tensor<32x1x!tt.ptr<f16>>, tensor<32x1xi32>
25+
%11 = tt.load %10 evictionPolicy = evict_last : tensor<32x1x!tt.ptr<f16>>
26+
%12 = arith.extf %11 : tensor<32x1xf16> to tensor<32x1xf32>
27+
%13 = tt.broadcast %7 : tensor<32x1xi32> -> tensor<32x32xi32>
28+
%14 = arith.muli %8, %cst_4 : tensor<32x1xi32>
29+
%15 = tt.broadcast %14 : tensor<32x1xi32> -> tensor<32x32xi32>
30+
%16 = tt.splat %in_ptr0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>>
31+
%17 = tt.broadcast %12 : tensor<32x1xf32> -> tensor<32x32xf32>
32+
%_tmp5 = scf.for %r0_offset = %c0_i32 to %c576_i32 step %c32_i32 iter_args(%_tmp5_9 = %cst_7) -> (tensor<32x32xf32>) : i32 {
33+
%44 = tt.splat %r0_offset : i32 -> tensor<1x32xi32>
34+
%45 = arith.addi %44, %6 : tensor<1x32xi32>
35+
%46 = arith.cmpi slt, %45, %cst_6 : tensor<1x32xi32>
36+
%47 = arith.muli %45, %cst_5 : tensor<1x32xi32>
37+
%48 = tt.broadcast %47 : tensor<1x32xi32> -> tensor<32x32xi32>
38+
%49 = arith.addi %13, %48 : tensor<32x32xi32>
39+
%50 = arith.addi %49, %15 : tensor<32x32xi32>
40+
%51 = tt.addptr %16, %50 : tensor<32x32x!tt.ptr<f16>>, tensor<32x32xi32>
41+
%52 = tt.broadcast %46 : tensor<1x32xi1> -> tensor<32x32xi1>
42+
%53 = tt.load %51, %52, %cst evictionPolicy = evict_last : tensor<32x32x!tt.ptr<f16>>
43+
%54 = arith.extf %53 : tensor<32x32xf16> to tensor<32x32xf32>
44+
%55 = arith.addf %54, %17 : tensor<32x32xf32>
45+
%mask = arith.cmpf ogt, %_tmp5_9, %55 : tensor<32x32xf32>
46+
%56 = arith.cmpf une, %_tmp5_9, %_tmp5_9 : tensor<32x32xf32>
47+
%mask_10 = arith.ori %mask, %56 : tensor<32x32xi1>
48+
%57 = arith.select %mask_10, %_tmp5_9, %55 : tensor<32x32xi1>, tensor<32x32xf32>
49+
%58 = arith.select %52, %57, %_tmp5_9 : tensor<32x32xi1>, tensor<32x32xf32>
50+
scf.yield %58 : tensor<32x32xf32>
51+
}
52+
tt.return
53+
}
54+
// CHECK: tt.func public @test1([[PARAM_0_:%.+]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
55+
// CHECK: scf.for
56+
// CHECK: [[PTR:%.+]] = tt.addptr {{.*}} : tensor<32x32x!tt.ptr<f16>>, tensor<32x32xi32>
57+
// CHECK: [[LOAD:%.+]] = tt.load [[PTR]] evictionPolicy = evict_last : tensor<32x32x!tt.ptr<f16>>
58+
// CHECK: arith.extf [[LOAD]] : tensor<32x32xf16> to tensor<32x32xf32>
59+
// CHECK: [[ORI:%.+]] = arith.ori {{.*}} : tensor<32x32xi1>
60+
// CHECK: [[SEL:%.+]] = arith.select [[ORI]], {{.*}}, {{.*}} : tensor<32x32xi1>, tensor<32x32xf32>
61+
// CHECK: scf.yield [[SEL]] : tensor<32x32xf32>
62+
// CHECK: }
63+
}

third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp

Lines changed: 195 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,17 @@
22
#include "intel/include/Utils/Utility.h"
33
#include "mlir/Dialect/Arith/IR/Arith.h"
44
#include "mlir/Dialect/SCF/IR/SCF.h"
5+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
6+
#include "mlir/IR/OpDefinition.h"
57
#include "mlir/IR/Verifier.h"
8+
#include "mlir/Interfaces/InferIntRangeInterface.h"
9+
#include "mlir/Support/LLVM.h"
610
#include "triton/Dialect/Triton/IR/Dialect.h"
11+
#include "llvm/ADT/TypeSwitch.h"
12+
#include "llvm/IR/Instructions.h"
713
#include "llvm/Support/Debug.h"
814
#include "llvm/Support/raw_ostream.h"
15+
#include <optional>
916

1017
#define DEBUG_TYPE "triton-intel-remove-masks"
1118

@@ -19,12 +26,41 @@ namespace mlir::triton::intel {
1926

2027
namespace {
2128

29+
static Operation *dropMask(Operation *op, bool maskVal) {
30+
assert(op && "Expecting a valid operation");
31+
32+
OpBuilder builder(op);
33+
Location loc = op->getLoc();
34+
TypeSwitch<Operation *>(op)
35+
.Case<tt::LoadOp>([&](auto loadOp) {
36+
if (maskVal) {
37+
tt::LoadOp newLoadOp = builder.create<tt::LoadOp>(
38+
loc, loadOp.getPtr(), loadOp.getCache(), loadOp.getEvict(),
39+
loadOp.getIsVolatile());
40+
loadOp->replaceAllUsesWith(newLoadOp);
41+
} else {
42+
Value other = loadOp.getOther();
43+
Operation *cstOp = builder.create<arith::ConstantOp>(loc, other);
44+
loadOp->replaceAllUsesWith(cstOp);
45+
}
46+
})
47+
.Case<arith::SelectOp>([&](auto selectOp) {
48+
selectOp->replaceAllUsesWith(
49+
(maskVal ? selectOp.getTrueValue() : selectOp.getFalseValue())
50+
.getDefiningOp());
51+
})
52+
.Default([](auto) {
53+
return nullptr;
54+
llvm_unreachable("Unexpected operation");
55+
});
56+
57+
return nullptr;
58+
}
59+
2260
// Abstract base class for mask validators.
2361
// Mask validators are used to check whether a given mask has an expected form.
2462
// Concrete subclasses provide a member function used to select masked
2563
// operations that have a mask in a particular (e.g. desired) form.
26-
// Furthermore concrete mask validators classes might also provide a member
27-
// function
2864
class MaskValidatorBase {
2965
public:
3066
virtual ~MaskValidatorBase() = default;
@@ -38,6 +74,112 @@ class MaskValidatorBase {
3874
virtual std::string getName() const = 0;
3975
};
4076

77+
// A mask validator which ensures the mask is not necessary.
78+
class RemovableMaskValidator final : public MaskValidatorBase {
79+
public:
80+
virtual bool isValidMask(scf::ForOp &forOp, Value mask) const {
81+
Value finalVal = tt::intel::getFinalValue(mask);
82+
assert(finalVal && "Expecting a valid mask");
83+
84+
// Ensure the loop range is known.
85+
std::optional<ConstantIntRanges> optRange = computeLoopIVRange(forOp);
86+
if (!optRange)
87+
return false;
88+
89+
if (!finalVal.getDefiningOp() ||
90+
!isa<arith::CmpIOp>(finalVal.getDefiningOp()))
91+
return false;
92+
93+
auto cmpOp = cast<arith::CmpIOp>(finalVal.getDefiningOp());
94+
arith::CmpIPredicate pred = cmpOp.getPredicate();
95+
if (pred != arith::CmpIPredicate::slt)
96+
return false;
97+
98+
Value lhs = tt::intel::getFinalValue(cmpOp.getLhs());
99+
Value rhs = tt::intel::getFinalValue(cmpOp.getRhs());
100+
Operation *lhsOp = tt::intel::getFinalValue(lhs).getDefiningOp();
101+
Operation *rhsOp = tt::intel::getFinalValue(rhs).getDefiningOp();
102+
if (!lhsOp || !rhsOp)
103+
return false;
104+
105+
auto getIntConstantValue = [](Operation *op) -> std::optional<APInt> {
106+
DenseElementsAttr constAttr;
107+
if (matchPattern(op, m_Constant(&constAttr))) {
108+
auto attr = constAttr.getSplatValue<Attribute>();
109+
if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
110+
return intAttr.getValue();
111+
}
112+
return std::nullopt;
113+
};
114+
115+
// TODO: consider the case where the constant is lhs.
116+
std::optional<APInt> constIntVal = getIntConstantValue(rhsOp);
117+
if (!constIntVal)
118+
return false;
119+
120+
if (auto addOp = dyn_cast<arith::AddIOp>(lhsOp)) {
121+
Value lhs = tt::intel::getFinalValue(addOp.getLhs());
122+
Value rhs = tt::intel::getFinalValue(addOp.getRhs());
123+
if (lhs != forOp.getSingleInductionVar())
124+
return false;
125+
if (auto makeRangeOp = dyn_cast<tt::MakeRangeOp>(rhs.getDefiningOp())) {
126+
APInt maxIV = (*optRange).smax();
127+
int64_t maxVal = maxIV.getSExtValue() + makeRangeOp.getEnd();
128+
129+
auto registerMaskValue = [this](Value mask, bool maskVal) {
130+
for (Operation *user : mask.getUsers())
131+
opToMaskValue.insert({user, maskVal});
132+
};
133+
134+
if (maxVal <= constIntVal->getSExtValue()) {
135+
registerMaskValue(mask, true);
136+
return true;
137+
}
138+
APInt minIV = (*optRange).smin();
139+
int64_t minVal = minIV.getSExtValue() + makeRangeOp.getStart();
140+
if (minVal >= constIntVal->getSExtValue()) {
141+
registerMaskValue(mask, false);
142+
return true;
143+
}
144+
}
145+
}
146+
147+
return false;
148+
}
149+
150+
virtual Value getVersioningCond(scf::ForOp &forOp, Value mask) const {
151+
return {};
152+
}
153+
154+
virtual std::string getName() const { return "RemovableMaskValidator"; }
155+
156+
bool getMaskValue(Operation *op) const {
157+
assert(opToMaskValue.find(op) != opToMaskValue.end() && "mask not present");
158+
return opToMaskValue[op];
159+
}
160+
161+
private:
162+
mutable std::map<Operation *, bool> opToMaskValue;
163+
164+
std::optional<ConstantIntRanges> computeLoopIVRange(scf::ForOp forOp) const {
165+
if (!forOp.getSingleInductionVar())
166+
return std::nullopt;
167+
168+
if (std::optional<APInt> tripCount = forOp.getStaticTripCount()) {
169+
OpFoldResult lb = *forOp.getSingleLowerBound();
170+
OpFoldResult step = *forOp.getSingleStep();
171+
int64_t lbVal = *getConstantIntValue(lb);
172+
int64_t stepVal = *getConstantIntValue(step);
173+
int64_t lastIVVal = stepVal * (tripCount->getSExtValue() - 1);
174+
llvm::APInt start(64, lbVal, true);
175+
llvm::APInt end(64, lastIVVal, true);
176+
return ConstantIntRanges::range(start, end, true);
177+
}
178+
179+
return std::nullopt;
180+
}
181+
};
182+
41183
// A mask validator which ensures that the mask can be reduced to the form:
42184
// `END-1 < N-i*END`
43185
class CanonicalMaskValidator final : public MaskValidatorBase {
@@ -51,12 +193,14 @@ class CanonicalMaskValidator final : public MaskValidatorBase {
51193

52194
// Check whether the mask is equivalent to the form: `END-1 < N-i*END`.
53195
virtual bool isValidMask(scf::ForOp &forOp, Value mask) const {
54-
assert(mask && "Expecting a valid mask");
196+
Value finalVal = tt::intel::getFinalValue(mask);
197+
assert(finalVal && "Expecting a valid mask");
55198

56-
if (!mask.getDefiningOp() || !isa<arith::CmpIOp>(mask.getDefiningOp()))
199+
if (!finalVal.getDefiningOp() ||
200+
!isa<arith::CmpIOp>(finalVal.getDefiningOp()))
57201
return false;
58202

59-
auto cmpOp = cast<arith::CmpIOp>(mask.getDefiningOp());
203+
auto cmpOp = cast<arith::CmpIOp>(finalVal.getDefiningOp());
60204
arith::CmpIPredicate pred = cmpOp.getPredicate();
61205
if (pred != arith::CmpIPredicate::slt)
62206
return false;
@@ -103,7 +247,10 @@ class CanonicalMaskValidator final : public MaskValidatorBase {
103247
// `(N+END-1)/END` (possibly folded), the versioning condition will be:
104248
// `(N+END-1)%END > 0 && N > END`.
105249
virtual Value getVersioningCond(scf::ForOp &forOp, Value mask) const {
106-
MaskInfo maskInfo = getMaskInfo(forOp, mask);
250+
Value finalVal = tt::intel::getFinalValue(mask);
251+
assert(finalVal && "Expecting a valid mask");
252+
253+
MaskInfo maskInfo = getMaskInfo(forOp, finalVal);
107254
if (!hasCanonicalUpperBound(forOp, maskInfo))
108255
return nullptr;
109256

@@ -203,11 +350,14 @@ class InvariantMaskValidator final : public MaskValidatorBase {
203350
// - [0..END] < splat(N)
204351
// - splat(N) < [0..END]
205352
virtual bool isValidMask(scf::ForOp &forOp, Value mask) const {
206-
assert(mask && "Expecting a valid mask");
207-
if (!mask.getDefiningOp() || !isa<arith::CmpIOp>(mask.getDefiningOp()))
353+
Value finalVal = tt::intel::getFinalValue(mask);
354+
assert(finalVal && "Expecting a valid mask");
355+
356+
if (!finalVal.getDefiningOp() ||
357+
!isa<arith::CmpIOp>(finalVal.getDefiningOp()))
208358
return false;
209359

210-
auto cmpOp = cast<arith::CmpIOp>(mask.getDefiningOp());
360+
auto cmpOp = cast<arith::CmpIOp>(finalVal.getDefiningOp());
211361
arith::CmpIPredicate pred = cmpOp.getPredicate();
212362
if (pred != arith::CmpIPredicate::slt)
213363
return false;
@@ -242,7 +392,7 @@ class InvariantMaskValidator final : public MaskValidatorBase {
242392
}
243393

244394
return false;
245-
}
395+
} // namespace
246396

247397
virtual Value getVersioningCond(scf::ForOp &forOp, Value mask) const {
248398
assert(isValidMask(forOp, mask) && "Invalid mask");
@@ -301,11 +451,8 @@ template <typename MaskValidator> class MaskedOpsCollector {
301451
bool collectMaskedOps() {
302452
auto collectMaskedOps = [&](auto ops, MaskedOperations &maskedOps) {
303453
for (Operation *op : ops) {
304-
Value mask = isa<tt::LoadOp>(op) ? cast<tt::LoadOp>(op).getMask()
305-
: isa<tt::StoreOp>(op) ? cast<tt::StoreOp>(op).getMask()
306-
: nullptr;
307-
if (mask &&
308-
maskValidator.isValidMask(forOp, tt::intel::getFinalValue(mask))) {
454+
Value mask = getMask(op);
455+
if (mask && maskValidator.isValidMask(forOp, mask)) {
309456
maskedOps.insert(op);
310457
LLVM_DEBUG(llvm::dbgs()
311458
<< maskValidator.getName()
@@ -316,12 +463,23 @@ template <typename MaskValidator> class MaskedOpsCollector {
316463

317464
collectMaskedOps(forOp.getOps<tt::LoadOp>(), maskedOps);
318465
collectMaskedOps(forOp.getOps<tt::StoreOp>(), maskedOps);
466+
collectMaskedOps(forOp.getOps<arith::SelectOp>(), maskedOps);
319467
return maskedOps.size();
320468
}
321469

322470
const MaskedOperations &getMaskedOps() const { return maskedOps; };
323471
const MaskValidator &getMaskValidator() const { return maskValidator; }
324472

473+
Value getMask(Operation *op) const {
474+
assert(op && "Expecting a valid operation");
475+
return TypeSwitch<Operation *, Value>(op)
476+
.Case<tt::LoadOp, tt::StoreOp>(
477+
[](auto maskedOp) { return maskedOp.getMask(); })
478+
.template Case<arith::SelectOp>(
479+
[](auto selectOp) { return selectOp.getCondition(); })
480+
.Default([](auto) { return nullptr; });
481+
}
482+
325483
private:
326484
scf::ForOp &forOp;
327485
MaskValidator &maskValidator;
@@ -515,6 +673,28 @@ struct TritonIntelRemoveMasksBase
515673
void runOnOperation() final {
516674
ModuleOp moduleOp = getOperation();
517675

676+
// Remove masks if the are not necessary
677+
moduleOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
678+
if (scf::ForOp forOp = dyn_cast<scf::ForOp>(op)) {
679+
// Nested loop aren't currently handled.
680+
if (forOp->template getParentOfType<scf::ForOp>())
681+
return WalkResult::advance();
682+
683+
if (!forOp.getSingleInductionVar())
684+
return WalkResult::advance();
685+
686+
RemovableMaskValidator maskValidator;
687+
MaskedOpsCollector collector(forOp, maskValidator);
688+
if (collector.collectMaskedOps()) {
689+
for (Operation *op : collector.getMaskedOps()) {
690+
bool maskVal = maskValidator.getMaskValue(op);
691+
dropMask(op, maskVal);
692+
}
693+
}
694+
}
695+
return WalkResult::advance();
696+
});
697+
518698
// Version loops containing masked operation in canonical form.
519699
moduleOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
520700
if (scf::ForOp forOp = dyn_cast<scf::ForOp>(op)) {

0 commit comments

Comments
 (0)