Commit 264d510

[NFI]: RemoveLayoutConversion sync (#4679)
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 61840ac commit 264d510

File tree: 1 file changed (+98, -38 lines)

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

@@ -3,12 +3,17 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
@@ -17,7 +22,9 @@
 
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
-#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include <deque>
 
@@ -106,6 +113,8 @@ class LayoutPropagation {
   // Return the mapped value in the given encoding. This will insert a convert
   // if the encoding is different than the encoding decided at resolve time.
   Value getValueAs(Value value, Attribute encoding);
+  // Return the original value mapped to the new desired encoding.
+  Value getRewrittenValue(Value value);
   // Dump the current stage of layout information.
   void dump();
 
@@ -190,7 +199,7 @@ bool isLayoutAnchor(Operation *op) {
     return ttgi::isExpensiveLoadOrStore(op);
   // TODO: we should estimate the cost of not propagating the layout for
   // AtomicCAS for further performance consideration.
-  if (isa<DotOp, AtomicCASOp>(op))
+  if (isa<DotOp, DotScaledOp, AtomicCASOp>(op))
     return true;
   if (isa<AtomicRMWOp>(op))
     if (auto tensorType =
@@ -304,6 +313,15 @@ SmallVector<Value> LayoutPropagation::propagateToUsers(Value value,
       setEncoding(user->getResults(), info, changed, user);
       continue;
     }
+    if (auto gatherOp = dyn_cast<GatherOp>(user)) {
+      // Propagate the layout through the indices operand only, and only if
+      // the gather does not already have an efficient layout set.
+      if (!gatherOp.getEfficientLayout() &&
+          &use == &gatherOp.getIndicesMutable()) {
+        setEncoding(gatherOp.getResult(), info, changed, user);
+        continue;
+      }
+    }
     if (auto storeOp = dyn_cast<StoreOp>(user)) {
       auto checkMMAorMMADerived = [](Attribute encoding) {
         bool isMMAorMMADerived = isa<MmaEncodingTrait>(encoding);
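
The new GatherOp case keys on operand identity: the candidate encoding flows to the gather result only when the use being visited is the `indices` operand and the op has no efficient layout pinned. A minimal standalone sketch of that address-based check, using plain C++ stand-ins rather than the MLIR/Triton types (`Use`, `Gather`, and `propagatesThrough` are hypothetical):

```cpp
#include <cassert>

struct Use { int tag; };          // stand-in for mlir::OpOperand
struct Gather {                   // stand-in for triton::GatherOp
  Use src, indices;               // two operands, like tt.gather
  Use &getIndicesMutable() { return indices; }
};

// Mirrors the shape of the new guard: compare addresses, not values, so a
// use of the gather's source operand never triggers propagation.
bool propagatesThrough(Gather &g, Use &use, bool hasEfficientLayout) {
  return !hasEfficientLayout && &use == &g.getIndicesMutable();
}

int main() {
  Gather g{{0}, {1}};
  assert(propagatesThrough(g, g.indices, /*hasEfficientLayout=*/false));
  assert(!propagatesThrough(g, g.src, /*hasEfficientLayout=*/false));
  assert(!propagatesThrough(g, g.indices, /*hasEfficientLayout=*/true));
}
```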
@@ -339,7 +357,7 @@ SmallVector<Value> LayoutPropagation::propagateToUsers(Value value,
 
 void LayoutPropagation::propagateLayout() {
   SmallVector<Value> queue;
-  for (const auto &it : layouts) {
+  for (auto it : layouts) {
     queue.push_back(it.first);
   }
   while (!queue.empty()) {
@@ -353,6 +371,7 @@ void LayoutPropagation::propagateLayout() {
              << info.encodings.size() << " candidate encoding(s):\n";
       for (Attribute encoding : info.encodings)
         DBGS() << "  " << encoding << "\n";
+      DBGS() << "changed: " << changed.size() << "\n";
     });
 
     queue.insert(queue.end(), changed.begin(), changed.end());
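
For context, `propagateLayout` is a fixed-point worklist: values with known layouts seed the queue, and any value whose candidate-encoding set grows is re-queued until nothing changes; the added debug line reports how many values changed on each visit. A standalone sketch of that loop in plain C++, with strings standing in for MLIR values and encodings (all names hypothetical):

```cpp
#include <map>
#include <set>
#include <string>
#include <vector>

using Value = std::string;
using Encoding = std::string;

int main() {
  // Anchor: the dot operand arrives with a known encoding.
  std::map<Value, std::set<Encoding>> layouts = {{"dotA", {"dpas"}}};
  // Def-use edges: value -> values its encoding may flow to.
  std::map<Value, std::vector<Value>> users = {{"dotA", {"add"}},
                                               {"add", {"store"}}};
  std::vector<Value> queue;
  for (auto &it : layouts)
    queue.push_back(it.first);
  while (!queue.empty()) {
    Value v = queue.back();
    queue.pop_back();
    for (const Value &user : users[v])
      for (const Encoding &e : layouts[v])
        if (layouts[user].insert(e).second) // set grew: revisit the user
          queue.push_back(user);
  }
  // layouts now maps dotA, add, and store to {"dpas"}.
}
```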
@@ -469,22 +488,25 @@ void LayoutPropagation::map(Value old, Value newV) {
       newV;
 }
 
+Value LayoutPropagation::getRewrittenValue(Value value) {
+  auto tensorType = dyn_cast<RankedTensorType>(value.getType());
+  if (!tensorType)
+    return value;
+  auto layoutIt = layouts.find(value);
+  if (layoutIt == layouts.end()) {
+    return value;
+  }
+  assert(layoutIt->second.encodings.size() == 1 &&
+         "we should have resolved to a single encoding");
+  Attribute encodingPicked = *(layoutIt->second.encodings.begin());
+  if (encodingPicked == tensorType.getEncoding())
+    return value;
+  return rewriteMapping.at({value, encodingPicked});
+}
+
 Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
   if (auto tensorType = dyn_cast<RankedTensorType>(value.getType())) {
-    Value rewrittenValue;
-    auto layoutIt = layouts.find(value);
-    if (layoutIt == layouts.end()) {
-      rewrittenValue = value;
-    } else {
-      assert(layoutIt->second.encodings.size() == 1 &&
-             "we should have resolved to a single encoding");
-      Attribute encodingPicked = *(layoutIt->second.encodings.begin());
-      if (encodingPicked == tensorType.getEncoding())
-        rewrittenValue = value;
-      else
-        rewrittenValue = rewriteMapping[{value, encodingPicked}];
-    }
-    assert(rewrittenValue);
+    Value rewrittenValue = getRewrittenValue(value);
     if (cast<RankedTensorType>(rewrittenValue.getType()).getEncoding() ==
         encoding)
       return rewrittenValue;
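
One detail of the extracted helper worth noting: the lookup switches from `rewriteMapping[{value, encodingPicked}]` plus a trailing `assert(rewrittenValue)` to `rewriteMapping.at(...)`, which fails at the lookup site instead of handing back a default-constructed null `Value`. A standalone sketch of the difference using `std::map` (LLVM's `DenseMap::at` asserts rather than throws, but the intent is the same):

```cpp
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

int main() {
  std::map<std::pair<std::string, std::string>, std::string> rewriteMapping;
  rewriteMapping[{"v0", "blocked"}] = "v0_blocked";

  // Hit: at() returns the recorded rewrite.
  std::string hit = rewriteMapping.at({"v0", "blocked"});

  // Miss with operator[]: quietly inserts an empty ("null") entry,
  // which the old code had to catch after the fact with assert().
  std::string miss = rewriteMapping[{"v0", "dpas"}];

  // Miss with at(): fails loudly at the lookup site.
  try {
    rewriteMapping.at({"v1", "dpas"});
  } catch (const std::out_of_range &) {
    // Expected: no rewrite recorded for this (value, encoding) pair.
  }
  (void)hit;
  (void)miss;
}
```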
@@ -922,7 +944,7 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
   }
   if (op->hasTrait<OpTrait::SameOperandsAndResultEncoding>() ||
       op->hasTrait<OpTrait::Elementwise>() ||
-      isa<ReduceOp, ExpandDimsOp, ReshapeOp, TransOp, JoinOp, SplitOp,
+      isa<ReduceOp, ExpandDimsOp, ReshapeOp, TransOp, JoinOp, SplitOp, GatherOp,
           ConvertLayoutOp>(op)) {
     Operation *newOp = cloneElementwise(rewriter, op, encoding);
     for (auto [oldResult, newResult] :
@@ -944,6 +966,9 @@ bool canBeRemat(Operation *op) {
     return !ttgi::isExpensiveLoadOrStore(op);
   if (isa<AtomicRMWOp, AtomicCASOp, DotOp>(op))
     return false;
+  if (auto gather = dyn_cast<GatherOp>(op))
+    return !gather.getEfficientLayout();
+
   if (isa<scf::WhileOp, scf::ConditionOp>(op))
     return false;
 
@@ -1211,8 +1236,8 @@ LogicalResult LayoutRematerialization::getRematerializableSlice(
     OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
     DenseMap<Value, Attribute> &layout,
     std::function<bool(Operation *)> stopPropagation) {
-  LogicalResult result = getConvertBackwardSlice(
-      root, rootEncoding, slice, layout, std::move(stopPropagation));
+  LogicalResult result = getConvertBackwardSlice(root, rootEncoding, slice,
+                                                 layout, stopPropagation);
   if (result.failed() || slice.empty())
     return failure();
 
@@ -1226,13 +1251,13 @@ LogicalResult LayoutRematerialization::getRematerializableSlice(
   return success();
 }
 
-void LayoutRematerialization::hoistConvertIntoConditionals() {
+void LayoutRematerialization::backwardRematerialization() {
   // Go through each ConvertLayoutOp.
   SmallVector<ConvertLayoutOp> convertOps;
   funcOp.walk(
       [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
   for (ConvertLayoutOp convertOp : convertOps) {
-    hoistConvertIntoConditionals(convertOp);
+    backwardRematerialization(convertOp);
     if (!opToDelete.contains(convertOp)) {
       // If the conversion didn't get removed, consider it for reuse in future
       // backward slices.
@@ -1242,13 +1267,13 @@ void LayoutRematerialization::hoistConvertIntoConditionals() {
   }
 }
 
-void LayoutRematerialization::backwardRematerialization() {
+void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
   // Go through each ConvertLayoutOp.
   SmallVector<ConvertLayoutOp> convertOps;
   funcOp.walk(
       [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
   for (ConvertLayoutOp convertOp : convertOps) {
-    backwardRematerialization(convertOp);
+    hoistConvertOnTopOfExtOrBroadcast(convertOp);
     if (!opToDelete.contains(convertOp)) {
       // If the conversion didn't get removed, consider it for reuse in future
       // backward slices.
@@ -1258,13 +1283,13 @@ void LayoutRematerialization::backwardRematerialization() {
   }
 }
 
-void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
+void LayoutRematerialization::hoistConvertIntoConditionals() {
   // Go through each ConvertLayoutOp.
   SmallVector<ConvertLayoutOp> convertOps;
   funcOp.walk(
       [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
   for (ConvertLayoutOp convertOp : convertOps) {
-    hoistConvertOnTopOfExtOrBroadcast(convertOp);
+    hoistConvertIntoConditionals(convertOp);
    if (!opToDelete.contains(convertOp)) {
       // If the conversion didn't get removed, consider it for reuse in future
       // backward slices.
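
These three hunks only reorder the walker definitions in the file (presumably to match the upstream layout); each wrapper keeps the same shape: collect every ConvertLayoutOp, run the per-op transform, and keep surviving conversions around for reuse. Since the three bodies are identical up to the callee, a sketch of how the duplicated driver could be factored out, in plain C++ with hypothetical stand-in types:

```cpp
#include <algorithm>
#include <functional>
#include <iostream>
#include <vector>

struct ConvertLayoutOp { int id; };

struct LayoutRematerialization {
  std::vector<ConvertLayoutOp> convertOps; // stand-in for funcOp.walk(...)
  std::vector<int> opToDelete;             // ids of erased conversions

  // Shared driver: one loop, three possible per-op handlers.
  void forEachConvert(const std::function<void(ConvertLayoutOp)> &handler) {
    for (ConvertLayoutOp op : convertOps) {
      handler(op);
      bool removed = std::count(opToDelete.begin(), opToDelete.end(), op.id);
      if (!removed) {
        // Conversion survived: keep it for reuse in future backward slices.
        std::cout << "keeping convert #" << op.id << "\n";
      }
    }
  }
};

int main() {
  LayoutRematerialization remat{{{1}, {2}}, {2}};
  remat.forEachConvert([](ConvertLayoutOp op) { /* per-op transform */ });
}
```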
@@ -1274,6 +1299,40 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
   }
 }
 
+static bool isExpensiveMathOp(Operation *op) {
+  // These operations are either multiple instructions or have throughput
+  // lower than 16 according to the arithmetic instructions table in:
+  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions
+  return isa<arith::DivFOp, math::ErfcOp, math::SinhOp, math::CoshOp,
+             math::TanhOp, math::AsinhOp, math::AcoshOp, math::AtanhOp,
+             math::CtPopOp, math::CountLeadingZerosOp,
+             math::CountTrailingZerosOp, math::ExpOp, math::Exp2Op,
+             math::ExpM1Op, math::LogOp, math::Log2Op, math::Log10Op,
+             math::Log1pOp, math::SinOp, math::CosOp, math::TanOp, math::AsinOp,
+             math::AcosOp, math::AtanOp, math::Atan2Op, math::PowFOp,
+             math::SqrtOp, math::RsqrtOp, math::ErfOp, math::CbrtOp>(op);
+}
+
+static int64_t getByteCount(Value result, int64_t minElementCount = 0,
+                            int64_t minBitWidth = 0) {
+  int64_t elementCount = 0;
+  int64_t dtypeBitWidth = 0;
+  if (auto tensorTy = dyn_cast<RankedTensorType>(result.getType())) {
+    elementCount = tensorTy.getNumElements();
+    auto elemType = tensorTy.getElementType();
+    if (elemType.isIntOrFloat()) {
+      dtypeBitWidth = elemType.getIntOrFloatBitWidth();
+    }
+  }
+  if (elementCount < minElementCount) {
+    elementCount = minElementCount;
+  }
+  if (dtypeBitWidth < minBitWidth) {
+    dtypeBitWidth = minBitWidth;
+  }
+  return (elementCount * dtypeBitWidth) >> 3;
+}
+
 void LayoutRematerialization::backwardRematerialization(
     ConvertLayoutOp convertOp) {
   RankedTensorType targetType = convertOp.getType();
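
`getByteCount` is straight arithmetic: elements times element bit-width, shifted right by 3 to convert bits to bytes, with optional floors for callers that must assume a minimum size. A standalone sketch with concrete numbers (the helper below mirrors the logic but takes the counts directly instead of an MLIR `Value`):

```cpp
#include <cassert>
#include <cstdint>

static int64_t byteCount(int64_t elementCount, int64_t dtypeBitWidth,
                         int64_t minElementCount = 0, int64_t minBitWidth = 0) {
  if (elementCount < minElementCount)
    elementCount = minElementCount;
  if (dtypeBitWidth < minBitWidth)
    dtypeBitWidth = minBitWidth;
  return (elementCount * dtypeBitWidth) >> 3; // bits -> bytes
}

int main() {
  assert(byteCount(128 * 64, 16) == 16384); // tensor<128x64xf16>: 8192 * 16 bits
  assert(byteCount(0, 0, 1024, 8) == 1024); // unknown type: floors kick in
  assert(byteCount(32, 1) == 4);            // i1 tensor: 32 bits = 4 bytes
}
```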
@@ -1373,30 +1432,32 @@ void LayoutRematerialization::hoistConvertDotOperand(
           { DBGS() << "  Block arguments not supported. Got " << v << "\n"; });
       return;
     }
-    auto loadOp = dyn_cast<LoadOp>(v.getDefiningOp());
-    // We expect the leaves of the slice to be Load or arith::Constant
-    // This could be generalised if necessary
-    if (!loadOp) {
+
+    // We expect the leaves of the slice to be Load, DescriptorLoad or
+    // arith::Constant. This could be generalised if necessary.
+    if (!isa<LoadOp, DescriptorLoadOp>(v.getDefiningOp())) {
      auto op = v.getDefiningOp();
       if (isa<arith::ConstantOp>(op) || noDataMovement(op)) {
         innerSlice.insert(v);
         continue;
       } else {
         LLVM_DEBUG({
-          DBGS() << "  Leaves must be Load or Constant. Got " << v << "\n";
+          DBGS() << "  Leaves must be Load, DescriptorLoad or Constant. Got "
+                 << v << "\n";
         });
         return;
       }
     }
+    Operation *loadOp = v.getDefiningOp();
     builder.setInsertionPointAfter(loadOp);
-    auto type = dyn_cast<RankedTensorType>(loadOp.getType());
+    auto type = dyn_cast<RankedTensorType>(loadOp->getResult(0).getType());
     if (!type)
       continue;
     auto newType = RankedTensorType::get(type.getShape(), type.getElementType(),
-                                         layout[loadOp]);
+                                         layout[loadOp->getResult(0)]);
     auto newConvertOp = builder.create<ConvertLayoutOp>(
-        convertOp.getLoc(), newType, loadOp.getResult());
-    mapping.map(loadOp.getResult(), newConvertOp.getResult());
+        convertOp.getLoc(), newType, loadOp->getResult(0));
+    mapping.map(loadOp->getResult(0), newConvertOp.getResult());
   }
 
   if (innerSlice.empty()) {
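
Accepting both `LoadOp` and `DescriptorLoadOp` as leaves is what forces the switch from a typed `loadOp` handle to a plain `Operation *` with `getResult(0)`: there is no longer a single concrete op class to hold. A minimal standalone sketch of that type-erasure pattern in plain C++ (`Operation`, `LoadOp`, and `DescriptorLoadOp` here are hypothetical stand-ins, not the MLIR classes):

```cpp
#include <cassert>
#include <memory>
#include <vector>

struct Operation {                  // stand-in for mlir::Operation
  std::vector<int> results;
  int &getResult(unsigned i) { return results[i]; }
  virtual ~Operation() = default;
};
struct LoadOp : Operation {};
struct DescriptorLoadOp : Operation {};

int main() {
  std::unique_ptr<Operation> op = std::make_unique<DescriptorLoadOp>();
  op->results = {42};
  // Typed access (old code): only works for one concrete kind.
  assert(dynamic_cast<LoadOp *>(op.get()) == nullptr);
  // Erased access (new code): uniform across both accepted kinds.
  assert(op->getResult(0) == 42);
}
```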
@@ -1418,7 +1479,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
     ConvertLayoutOp convertOp) {
   // DotOperand is hoisted by hoistDotOperand
   RankedTensorType targetType = convertOp.getType();
-  if (mlir::isa<DotOperandEncodingAttr>(targetType.getEncoding()))
+  if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
     return;
 
   auto isExtOrBroadcastOp = [](Operation *op) {
@@ -1641,6 +1702,7 @@ void hoistConvert(ModuleOp module) {
     layoutRemat.cleanup();
   });
 }
+} // namespace
 
 class TritonIntelGPURemoveLayoutConversionsPass
     : public triton::gpu::intel::impl::
@@ -1722,5 +1784,3 @@ class TritonIntelGPURemoveLayoutConversionsPass
     });
   }
 };
-
-} // namespace
