@@ -1458,27 +1458,19 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
14581458 // These are the conditional edges above which conversions should be hoisted.
14591459 // The value represents the `scf.if` op result and the operand represents the
14601460 // edge into one of the branches.
1461- SmallVector<std::pair<OpResult , OpOperand *>> hoistAbove;
1461+ SmallVector<std::pair<Value , OpOperand *>> hoistAbove;
14621462
14631463 // The list of `scf.if` op results in the slice that are not rematerializable.
14641464 // Hoisting is terminated at these values.
14651465 SmallVector<OpResult> terminals;
14661466
1467- // Process the whole backward slice in subslices that stop at each condtional.
1468- // This is so we can apply more specific rules about when to hoist.
1469- struct Subslice {
1470- OpResult v;
1471- OpOperand *edge;
1472- SetVector<Value> slice;
1473- DenseMap<Value, Attribute> layout;
1474- };
1475- SmallVector<Subslice> subslices;
1476-
1477- // Check a value in the subslice.
1478- auto visitValue = [&](OpResult v) {
1467+ // This loop recurses through the subslices of the backwards dependencies, so
1468+ // re-query the size of `slice`.
1469+ for (unsigned i = 0 ; i != slice.size (); ++i) {
1470+ Value v = slice[i];
14791471 auto ifOp = v.getDefiningOp <scf::IfOp>();
14801472 if (!ifOp)
1481- return ;
1473+ continue ;
14821474
14831475 Attribute rootLayout = layout.at (v);
14841476 unsigned resIdx = cast<OpResult>(v).getResultNumber ();
@@ -1507,66 +1499,41 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
15071499 slice.insert (elseSlice.begin (), elseSlice.end ());
15081500 layout.insert (thenLayout.begin (), thenLayout.end ());
15091501 layout.insert (elseLayout.begin (), elseLayout.end ());
1510- return ;
1502+ continue ;
15111503 }
15121504
15131505 // If propagation across both edges failed, then this conditional
15141506 // terminates backwards rematerialization.
15151507 if (failed (thenResult) && failed (elseResult)) {
1516- terminals.push_back (v);
1517- return ;
1508+ terminals.push_back (cast<OpResult>(v));
1509+ continue ;
1510+ }
1511+
1512+ // Only hoist into conditionals inside loops. The assumption is that an if
1513+ // inside a loop executes fewer than the total number of loop iterations,
1514+ // making this hoist profitable.
1515+ if (!isa<scf::ForOp>(ifOp->getParentOp ())) {
1516+ terminals.push_back (cast<OpResult>(v));
1517+ continue ;
15181518 }
15191519
15201520 // The layout conversion can be rematerialized along one edge but not the
15211521 // other. We can hoist the conversion into the other branch. Push this
15221522 // into the subslice list for analysis.
15231523 if (succeeded (thenResult)) {
1524- subslices.push_back (
1525- {v, &elseRes, std::move (thenSlice), std::move (thenLayout)});
1524+ hoistAbove.emplace_back (v, &elseRes);
1525+ slice.insert (thenSlice.begin (), thenSlice.end ());
1526+ layout.insert (thenLayout.begin (), thenLayout.end ());
15261527 } else {
1527- subslices.push_back (
1528- {v, &thenRes, std::move (elseSlice), std::move (elseLayout)});
1529- }
1530- };
1531-
1532- // Process the whole slice in subslices.
1533- unsigned i = 0 ;
1534- bool isLoneHoist = false ;
1535- do {
1536- // Visit values in the current subslice.
1537- for (; i != slice.size (); ++i) {
1538- if (auto v = dyn_cast<OpResult>(slice[i]))
1539- visitValue (v);
1540- }
1541- // Check the next chunk of subslices. When a condtional is marked as being
1542- // valid to be hoisted across, we have to recurse on a new subslice rooted
1543- // at the corresopnding yield operand.
1544- //
1545- // Hoist across condtionals when:
1546- // 1. The conditional is directly inside a loop.
1547- // 2. The whole slice contains only one conditional.
1548- for (auto &[v, edge, subslice, layouts] : subslices) {
1549- bool oneHoist = false ;
1550- if (isa<LoopLikeOpInterface>(v.getDefiningOp ()->getParentOp ()) ||
1551- (oneHoist = subslices.size () == 1 && hoistAbove.empty ())) {
1552- isLoneHoist |= oneHoist;
1553- hoistAbove.push_back ({v, edge});
1554- // Recurse on the subslice.
1555- slice.insert (subslice.begin (), subslice.end ());
1556- layout.insert (layouts.begin (), layouts.end ());
1557- } else {
1558- terminals.push_back (v);
1559- }
1528+ hoistAbove.emplace_back (v, &thenRes);
1529+ slice.insert (elseSlice.begin (), elseSlice.end ());
1530+ layout.insert (elseLayout.begin (), elseLayout.end ());
15601531 }
1561- subslices.clear ();
1562- } while (i != slice.size ());
1532+ }
15631533
15641534 // Exit early if there is nothing to do.
15651535 if (hoistAbove.empty ())
15661536 return ;
1567- // Check if this is a lone hoist. There should be no other terminals.
1568- if (isLoneHoist && !terminals.empty ())
1569- return ;
15701537
15711538 // Rematerialize failed hoists right before the condtional, and hoist those
15721539 // that succeeded into the branch and then rewrite the slice.
0 commit comments