Skip to content

Commit 278a05b

Browse files
committed
[intel] align with changes from 032fa41
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 2dc8c3f commit 278a05b

File tree

1 file changed

+24
-57
lines changed

1 file changed

+24
-57
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 24 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,27 +1458,19 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
14581458
// These are the conditional edges above which conversions should be hoisted.
14591459
// The value represents the `scf.if` op result and the operand represents the
14601460
// edge into one of the branches.
1461-
SmallVector<std::pair<OpResult, OpOperand *>> hoistAbove;
1461+
SmallVector<std::pair<Value, OpOperand *>> hoistAbove;
14621462

14631463
// The list of `scf.if` op results in the slice that are not rematerializable.
14641464
// Hoisting is terminated at these values.
14651465
SmallVector<OpResult> terminals;
14661466

1467-
// Process the whole backward slice in subslices that stop at each conditional.
1468-
// This is so we can apply more specific rules about when to hoist.
1469-
struct Subslice {
1470-
OpResult v;
1471-
OpOperand *edge;
1472-
SetVector<Value> slice;
1473-
DenseMap<Value, Attribute> layout;
1474-
};
1475-
SmallVector<Subslice> subslices;
1476-
1477-
// Check a value in the subslice.
1478-
auto visitValue = [&](OpResult v) {
1467+
// This loop recurses through the subslices of the backwards dependencies, so
1468+
// re-query the size of `slice`.
1469+
for (unsigned i = 0; i != slice.size(); ++i) {
1470+
Value v = slice[i];
14791471
auto ifOp = v.getDefiningOp<scf::IfOp>();
14801472
if (!ifOp)
1481-
return;
1473+
continue;
14821474

14831475
Attribute rootLayout = layout.at(v);
14841476
unsigned resIdx = cast<OpResult>(v).getResultNumber();
@@ -1507,66 +1499,41 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
15071499
slice.insert(elseSlice.begin(), elseSlice.end());
15081500
layout.insert(thenLayout.begin(), thenLayout.end());
15091501
layout.insert(elseLayout.begin(), elseLayout.end());
1510-
return;
1502+
continue;
15111503
}
15121504

15131505
// If propagation across both edges failed, then this conditional
15141506
// terminates backwards rematerialization.
15151507
if (failed(thenResult) && failed(elseResult)) {
1516-
terminals.push_back(v);
1517-
return;
1508+
terminals.push_back(cast<OpResult>(v));
1509+
continue;
1510+
}
1511+
1512+
// Only hoist into conditionals inside loops. The assumption is that an if
1513+
// inside a loop executes fewer than the total number of loop iterations,
1514+
// making this hoist profitable.
1515+
if (!isa<scf::ForOp>(ifOp->getParentOp())) {
1516+
terminals.push_back(cast<OpResult>(v));
1517+
continue;
15181518
}
15191519

15201520
// The layout conversion can be rematerialized along one edge but not the
15211521
// other. We can hoist the conversion into the other branch. Record the
15221522
// hoist edge and continue the analysis along the rematerializable branch.
15231523
if (succeeded(thenResult)) {
1524-
subslices.push_back(
1525-
{v, &elseRes, std::move(thenSlice), std::move(thenLayout)});
1524+
hoistAbove.emplace_back(v, &elseRes);
1525+
slice.insert(thenSlice.begin(), thenSlice.end());
1526+
layout.insert(thenLayout.begin(), thenLayout.end());
15261527
} else {
1527-
subslices.push_back(
1528-
{v, &thenRes, std::move(elseSlice), std::move(elseLayout)});
1529-
}
1530-
};
1531-
1532-
// Process the whole slice in subslices.
1533-
unsigned i = 0;
1534-
bool isLoneHoist = false;
1535-
do {
1536-
// Visit values in the current subslice.
1537-
for (; i != slice.size(); ++i) {
1538-
if (auto v = dyn_cast<OpResult>(slice[i]))
1539-
visitValue(v);
1540-
}
1541-
// Check the next chunk of subslices. When a conditional is marked as being
1542-
// valid to be hoisted across, we have to recurse on a new subslice rooted
1543-
// at the corresponding yield operand.
1544-
//
1545-
// Hoist across conditionals when:
1546-
// 1. The conditional is directly inside a loop.
1547-
// 2. The whole slice contains only one conditional.
1548-
for (auto &[v, edge, subslice, layouts] : subslices) {
1549-
bool oneHoist = false;
1550-
if (isa<LoopLikeOpInterface>(v.getDefiningOp()->getParentOp()) ||
1551-
(oneHoist = subslices.size() == 1 && hoistAbove.empty())) {
1552-
isLoneHoist |= oneHoist;
1553-
hoistAbove.push_back({v, edge});
1554-
// Recurse on the subslice.
1555-
slice.insert(subslice.begin(), subslice.end());
1556-
layout.insert(layouts.begin(), layouts.end());
1557-
} else {
1558-
terminals.push_back(v);
1559-
}
1528+
hoistAbove.emplace_back(v, &thenRes);
1529+
slice.insert(elseSlice.begin(), elseSlice.end());
1530+
layout.insert(elseLayout.begin(), elseLayout.end());
15601531
}
1561-
subslices.clear();
1562-
} while (i != slice.size());
1532+
}
15631533

15641534
// Exit early if there is nothing to do.
15651535
if (hoistAbove.empty())
15661536
return;
1567-
// Check if this is a lone hoist. There should be no other terminals.
1568-
if (isLoneHoist && !terminals.empty())
1569-
return;
15701537

15711538
// Rematerialize failed hoists right before the conditional, and hoist those
15721539
// that succeeded into the branch and then rewrite the slice.

0 commit comments

Comments
 (0)