Skip to content

Commit 3871595

Browse files
authored
Revert "Revert " Reapply "[Layouts] Propagate layouts into conditionals (#5610)" (#5725)"" (#3347)
This reverts commit 942cf94. --------- Signed-off-by: Anatoly Myachev <[email protected]>
1 parent b7d1eaa commit 3871595

File tree

5 files changed

+588
-10
lines changed

5 files changed

+588
-10
lines changed

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ class LayoutRematerialization {
135135
void hoistConvertDotOperand(ConvertLayoutOp convertOp);
136136
void hoistConvertOnTopOfExtOrBroadcast();
137137
void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp);
138+
void hoistConvertIntoConditionals();
139+
void hoistConvertIntoConditionals(ConvertLayoutOp convertOp);
138140
void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
139141
ConvertLayoutOp convertOp, IRMapping &mapping);
140142
void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
@@ -1042,6 +1044,22 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
10421044
}
10431045
}
10441046

1047+
// Driver for hoisting layout conversions into conditionals: collects every
// ConvertLayoutOp found by walking `funcOp` and runs the per-op overload of
// hoistConvertIntoConditionals on each. Conversions that are not in
// `opToDelete` afterwards are registered with addRematValue.
void LayoutRematerialization::hoistConvertIntoConditionals() {
1048+
// Go through each ConvertLayoutOp. The ops are gathered into a vector before
// any of them is processed (presumably so that IR changes made while hoisting
// do not interfere with the walk -- confirm against the other hoist drivers).
1049+
SmallVector<ConvertLayoutOp> convertOps;
1050+
funcOp.walk(
1051+
[&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
1052+
for (ConvertLayoutOp convertOp : convertOps) {
1053+
hoistConvertIntoConditionals(convertOp);
1054+
if (!opToDelete.contains(convertOp)) {
1055+
// If the conversion didn't get removed, consider it for reuse in future
1056+
// backward slices: record its result as the value of `getSrc()` in the
1057+
// conversion's target encoding.
addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
1058+
convertOp.getResult());
1059+
}
1060+
}
1061+
}
1062+
10451063
void LayoutRematerialization::backwardRematerialization(
10461064
ConvertLayoutOp convertOp) {
10471065
// DotOperand is hoisted by hoistDotOperand
@@ -1268,6 +1286,155 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
12681286
rewriteSlice(slice, layout, convertOp, mapping);
12691287
}
12701288

1289+
// Try to hoist `convertOp` into the branches of the `scf.if` ops that appear
// in its backward slice.
//
// Outline:
//  1. Compute the rematerializable backward slice of the conversion source,
//     stopping at `scf.if` ops. Return if that fails.
//  2. For every `scf.if` result in the slice, compute the slice along the
//     then-edge and the else-edge separately and classify the result: both
//     edges succeed (merge into the slice), both fail (terminal), or exactly
//     one succeeds (candidate for hoisting into the failing branch).
//  3. Accept a candidate when its `scf.if` is directly nested in a loop-like
//     op, or when it is the only candidate of its round and nothing has been
//     hoisted yet; otherwise treat it as a terminal.
//  4. If at least one hoist was accepted, create new ConvertLayoutOps at the
//     terminals and at the hoisted edges and call rewriteSlice.
void LayoutRematerialization::hoistConvertIntoConditionals(
1290+
ConvertLayoutOp convertOp) {
1291+
// Take the backward slice of tensor dependencies rooted at the conversion,
1292+
// stopping at conditionals. This subslice is used to initialize the analysis.
1293+
SetVector<Value> slice;
1294+
DenseMap<Value, Attribute> layout;
1295+
auto isIfOp = [](Operation *op) { return isa<scf::IfOp>(op); };
1296+
if (failed(getRematerializableSlice(convertOp.getSrcMutable(),
1297+
convertOp.getType().getEncoding(), slice,
1298+
layout, isIfOp)))
1299+
return;
1300+
1301+
// These are the conditional edges above which conversions should be hoisted.
1302+
// The value represents the `scf.if` op result and the operand represents the
1303+
// edge into one of the branches.
1304+
SmallVector<std::pair<OpResult, OpOperand *>> hoistAbove;
1305+
1306+
// The list of `scf.if` op results in the slice that are not rematerializable.
1307+
// Hoisting is terminated at these values.
1308+
SmallVector<OpResult> terminals;
1309+
1310+
// Process the whole backward slice in subslices that stop at each conditional.
1311+
// This is so we can apply more specific rules about when to hoist.
1312+
struct Subslice {
1313+
// The `scf.if` result this subslice was computed for.
OpResult v;
1314+
// The yield operand of the branch whose slice computation failed, i.e. the
// edge a conversion would be hoisted onto.
OpOperand *edge;
1315+
// Slice and layout assignment of the branch whose slice computation
// succeeded.
SetVector<Value> slice;
1316+
DenseMap<Value, Attribute> layout;
1317+
};
1318+
SmallVector<Subslice> subslices;
1319+
1320+
// Check a value in the subslice. Only results of `scf.if` ops are of
// interest; anything else is ignored.
1321+
auto visitValue = [&](OpResult v) {
1322+
auto ifOp = v.getDefiningOp<scf::IfOp>();
1323+
if (!ifOp)
1324+
return;
1325+
1326+
// Copy the layout out of the map by value: `layout` may be modified below.
Attribute rootLayout = layout.at(v);
1327+
unsigned resIdx = cast<OpResult>(v).getResultNumber();
1328+
1329+
// Take the backward slice along each branch.
1330+
auto thenYield =
1331+
cast<scf::YieldOp>(ifOp.getThenRegion().front().getTerminator());
1332+
auto elseYield =
1333+
cast<scf::YieldOp>(ifOp.getElseRegion().front().getTerminator());
1334+
1335+
OpOperand &thenRes = thenYield.getResultsMutable()[resIdx];
1336+
OpOperand &elseRes = elseYield.getResultsMutable()[resIdx];
1337+
1338+
SetVector<Value> thenSlice, elseSlice;
1339+
DenseMap<Value, Attribute> thenLayout, elseLayout;
1340+
1341+
LogicalResult thenResult = getRematerializableSlice(
1342+
thenRes, rootLayout, thenSlice, thenLayout, isIfOp);
1343+
LogicalResult elseResult = getRematerializableSlice(
1344+
elseRes, rootLayout, elseSlice, elseLayout, isIfOp);
1345+
1346+
// If propagation across both edges of this conditional succeeded, then we
1347+
// don't need to hoist across it. Merge into the current slice.
1348+
if (succeeded(thenResult) && succeeded(elseResult)) {
1349+
slice.insert(thenSlice.begin(), thenSlice.end());
1350+
slice.insert(elseSlice.begin(), elseSlice.end());
1351+
layout.insert(thenLayout.begin(), thenLayout.end());
1352+
layout.insert(elseLayout.begin(), elseLayout.end());
1353+
return;
1354+
}
1355+
1356+
// If propagation across both edges failed, then this conditional
1357+
// terminates backwards rematerialization.
1358+
if (failed(thenResult) && failed(elseResult)) {
1359+
terminals.push_back(v);
1360+
return;
1361+
}
1362+
1363+
// The layout conversion can be rematerialized along one edge but not the
1364+
// other. We can hoist the conversion into the other branch. Push this
1365+
// into the subslice list for analysis. The recorded edge is the one that
// failed; the recorded slice/layout belong to the one that succeeded.
1366+
if (succeeded(thenResult)) {
1367+
subslices.push_back(
1368+
{v, &elseRes, std::move(thenSlice), std::move(thenLayout)});
1369+
} else {
1370+
subslices.push_back(
1371+
{v, &thenRes, std::move(elseSlice), std::move(elseLayout)});
1372+
}
1373+
};
1374+
1375+
// Process the whole slice in subslices. `i` is the index of the next
// unvisited value in `slice`; it persists across rounds of the do-while.
1376+
unsigned i = 0;
1377+
bool isLoneHoist = false;
1378+
do {
1379+
// Visit values in the current subslice. `slice` can grow during this loop
// (visitValue inserts into it), so its size is re-read on every iteration.
1380+
for (; i != slice.size(); ++i) {
1381+
if (auto v = dyn_cast<OpResult>(slice[i]))
1382+
visitValue(v);
1383+
}
1384+
// Check the next chunk of subslices. When a conditional is marked as being
1385+
// valid to be hoisted across, we have to recurse on a new subslice rooted
1386+
// at the corresponding yield operand.
1387+
//
1388+
// Hoist across conditionals when:
1389+
// 1. The conditional is directly inside a loop.
1390+
// 2. The whole slice contains only one conditional.
1391+
for (auto &[v, edge, subslice, layouts] : subslices) {
1392+
bool oneHoist = false;
1393+
// Note the assignment inside the condition: `oneHoist` is only evaluated
// (and possibly set) when the loop-parent check is false.
if (isa<LoopLikeOpInterface>(v.getDefiningOp()->getParentOp()) ||
1394+
(oneHoist = subslices.size() == 1 && hoistAbove.empty())) {
1395+
isLoneHoist |= oneHoist;
1396+
hoistAbove.push_back({v, edge});
1397+
// Recurse on the subslice: the newly inserted values are visited by the
// next round of the enclosing do-while.
1398+
slice.insert(subslice.begin(), subslice.end());
1399+
layout.insert(layouts.begin(), layouts.end());
1400+
} else {
1401+
terminals.push_back(v);
1402+
}
1403+
}
1404+
subslices.clear();
1405+
} while (i != slice.size());
1406+
1407+
// Exit early if there is nothing to do.
1408+
if (hoistAbove.empty())
1409+
return;
1410+
// Check if this is a lone hoist. There should be no other terminals.
1411+
if (isLoneHoist && !terminals.empty())
1412+
return;
1413+
1414+
// Rematerialize failed hoists right after the conditional, and hoist those
1415+
// that succeeded into the branch and then rewrite the slice.
1416+
IRMapping mapping;
1417+
// Creates a ConvertLayoutOp of `v` to `encoding` at the builder's insertion
// point, maps `v` to the new conversion, and removes `v` from the slice.
auto hoistRemat = [&](OpBuilder &b, Value v, Attribute encoding) {
1418+
auto tensorType = cast<RankedTensorType>(v.getType());
1419+
auto newType = RankedTensorType::get(tensorType.getShape(),
1420+
tensorType.getElementType(), encoding);
1421+
Value newCvt = b.create<ConvertLayoutOp>(convertOp.getLoc(), newType, v);
1422+
1423+
mapping.map(v, newCvt);
1424+
slice.remove(v);
1425+
};
1426+
// Terminals: convert the `scf.if` result itself, immediately after the op.
for (Value v : terminals) {
1427+
OpBuilder b(v.getContext());
1428+
b.setInsertionPointAfter(v.getDefiningOp());
1429+
hoistRemat(b, v, layout.at(v));
1430+
}
1431+
// Hoisted edges: the builder is constructed from the operand's owner (the
// yield), and the yielded value is converted to the layout assigned to the
// corresponding `scf.if` result.
for (auto [result, edge] : hoistAbove) {
1432+
OpBuilder b(edge->getOwner());
1433+
hoistRemat(b, edge->get(), layout.at(result));
1434+
}
1435+
rewriteSlice(slice, layout, convertOp, mapping);
1436+
}
1437+
12711438
void backwardRematerialization(ModuleOp module) {
12721439
module.walk([](FuncOp funcOp) {
12731440
LayoutRematerialization layoutRemat(funcOp);
@@ -1283,6 +1450,10 @@ void hoistConvert(ModuleOp module) {
12831450
layoutRemat.hoistConvertOnTopOfExtOrBroadcast();
12841451
layoutRemat.cleanup();
12851452

1453+
layoutRemat = LayoutRematerialization(funcOp);
1454+
layoutRemat.hoistConvertIntoConditionals();
1455+
layoutRemat.cleanup();
1456+
12861457
layoutRemat = LayoutRematerialization(funcOp);
12871458
layoutRemat.hoistConvertDotOperand();
12881459
layoutRemat.cleanup();

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -803,11 +803,10 @@ LogicalResult getConvertBackwardSlice(
803803
auto updateLayout = [&](Value value, Attribute encoding) {
804804
assert((isa<RankedTensorType>(value.getType())));
805805
slice.insert(value);
806-
if (layout.find(value) != layout.end()) {
807-
if (layout[value] != encoding)
808-
return failure();
809-
}
810-
layout[value] = encoding;
806+
Attribute &existing = layout[value];
807+
if (existing && existing != encoding)
808+
return failure();
809+
existing = encoding;
811810
return success();
812811
};
813812

@@ -833,6 +832,8 @@ LogicalResult getConvertBackwardSlice(
833832
}
834833

835834
if (auto ifOp = currentValue.getDefiningOp<scf::IfOp>()) {
835+
if (stopPropagation && stopPropagation(ifOp))
836+
continue;
836837
unsigned argIdx = mlir::cast<OpResult>(currentValue).getResultNumber();
837838

838839
OpOperand &thenValue = ifOp.thenYield()->getOpOperand(argIdx);

0 commit comments

Comments
 (0)