@@ -1458,27 +1458,19 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
1458
1458
// These are the conditional edges above which conversions should be hoisted.
1459
1459
// The value represents the `scf.if` op result and the operand represents the
1460
1460
// edge into one of the branches.
1461
- SmallVector<std::pair<OpResult , OpOperand *>> hoistAbove;
1461
+ SmallVector<std::pair<Value , OpOperand *>> hoistAbove;
1462
1462
1463
1463
// The list of `scf.if` op results in the slice that are not rematerializable.
1464
1464
// Hoisting is terminated at these values.
1465
1465
SmallVector<OpResult> terminals;
1466
1466
1467
- // Process the whole backward slice in subslices that stop at each condtional.
1468
- // This is so we can apply more specific rules about when to hoist.
1469
- struct Subslice {
1470
- OpResult v;
1471
- OpOperand *edge;
1472
- SetVector<Value> slice;
1473
- DenseMap<Value, Attribute> layout;
1474
- };
1475
- SmallVector<Subslice> subslices;
1476
-
1477
- // Check a value in the subslice.
1478
- auto visitValue = [&](OpResult v) {
1467
+ // This loop recurses through the subslices of the backwards dependencies, so
1468
+ // re-query the size of `slice`.
1469
+ for (unsigned i = 0 ; i != slice.size (); ++i) {
1470
+ Value v = slice[i];
1479
1471
auto ifOp = v.getDefiningOp <scf::IfOp>();
1480
1472
if (!ifOp)
1481
- return ;
1473
+ continue ;
1482
1474
1483
1475
Attribute rootLayout = layout.at (v);
1484
1476
unsigned resIdx = cast<OpResult>(v).getResultNumber ();
@@ -1507,66 +1499,41 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
1507
1499
slice.insert (elseSlice.begin (), elseSlice.end ());
1508
1500
layout.insert (thenLayout.begin (), thenLayout.end ());
1509
1501
layout.insert (elseLayout.begin (), elseLayout.end ());
1510
- return ;
1502
+ continue ;
1511
1503
}
1512
1504
1513
1505
// If propagation across both edges failed, then this conditional
1514
1506
// terminates backwards rematerialization.
1515
1507
if (failed (thenResult) && failed (elseResult)) {
1516
- terminals.push_back (v);
1517
- return ;
1508
+ terminals.push_back (cast<OpResult>(v));
1509
+ continue ;
1510
+ }
1511
+
1512
+ // Only hoist into conditionals inside loops. The assumption is that an if
1513
+ // inside a loop executes fewer than the total number of loop iterations,
1514
+ // making this hoist profitable.
1515
+ if (!isa<scf::ForOp>(ifOp->getParentOp ())) {
1516
+ terminals.push_back (cast<OpResult>(v));
1517
+ continue ;
1518
1518
}
1519
1519
1520
1520
// The layout conversion can be rematerialized along one edge but not the
1521
1521
// other. We can hoist the conversion into the other branch. Push this
1522
1522
// into the subslice list for analysis.
1523
1523
if (succeeded (thenResult)) {
1524
- subslices.push_back (
1525
- {v, &elseRes, std::move (thenSlice), std::move (thenLayout)});
1524
+ hoistAbove.emplace_back (v, &elseRes);
1525
+ slice.insert (thenSlice.begin (), thenSlice.end ());
1526
+ layout.insert (thenLayout.begin (), thenLayout.end ());
1526
1527
} else {
1527
- subslices.push_back (
1528
- {v, &thenRes, std::move (elseSlice), std::move (elseLayout)});
1529
- }
1530
- };
1531
-
1532
- // Process the whole slice in subslices.
1533
- unsigned i = 0 ;
1534
- bool isLoneHoist = false ;
1535
- do {
1536
- // Visit values in the current subslice.
1537
- for (; i != slice.size (); ++i) {
1538
- if (auto v = dyn_cast<OpResult>(slice[i]))
1539
- visitValue (v);
1540
- }
1541
- // Check the next chunk of subslices. When a condtional is marked as being
1542
- // valid to be hoisted across, we have to recurse on a new subslice rooted
1543
- // at the corresopnding yield operand.
1544
- //
1545
- // Hoist across condtionals when:
1546
- // 1. The conditional is directly inside a loop.
1547
- // 2. The whole slice contains only one conditional.
1548
- for (auto &[v, edge, subslice, layouts] : subslices) {
1549
- bool oneHoist = false ;
1550
- if (isa<LoopLikeOpInterface>(v.getDefiningOp ()->getParentOp ()) ||
1551
- (oneHoist = subslices.size () == 1 && hoistAbove.empty ())) {
1552
- isLoneHoist |= oneHoist;
1553
- hoistAbove.push_back ({v, edge});
1554
- // Recurse on the subslice.
1555
- slice.insert (subslice.begin (), subslice.end ());
1556
- layout.insert (layouts.begin (), layouts.end ());
1557
- } else {
1558
- terminals.push_back (v);
1559
- }
1528
+ hoistAbove.emplace_back (v, &thenRes);
1529
+ slice.insert (elseSlice.begin (), elseSlice.end ());
1530
+ layout.insert (elseLayout.begin (), elseLayout.end ());
1560
1531
}
1561
- subslices.clear ();
1562
- } while (i != slice.size ());
1532
+ }
1563
1533
1564
1534
// Exit early if there is nothing to do.
1565
1535
if (hoistAbove.empty ())
1566
1536
return ;
1567
- // Check if this is a lone hoist. There should be no other terminals.
1568
- if (isLoneHoist && !terminals.empty ())
1569
- return ;
1570
1537
1571
1538
// Rematerialize failed hoists right before the condtional, and hoist those
1572
1539
// that succeeded into the branch and then rewrite the slice.
0 commit comments