@@ -135,6 +135,8 @@ class LayoutRematerialization {
135135 void hoistConvertDotOperand (ConvertLayoutOp convertOp);
136136 void hoistConvertOnTopOfExtOrBroadcast ();
137137 void hoistConvertOnTopOfExtOrBroadcast (ConvertLayoutOp convertOp);
138+ void hoistConvertIntoConditionals ();
139+ void hoistConvertIntoConditionals (ConvertLayoutOp convertOp);
138140 void rewriteSlice (SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
139141 ConvertLayoutOp convertOp, IRMapping &mapping);
140142 void rewriteSlice (SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
@@ -1042,6 +1044,22 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
10421044 }
10431045}
10441046
1047+ void LayoutRematerialization::hoistConvertIntoConditionals () {
1048+ // Go through each ConvertLayoutOp.
1049+ SmallVector<ConvertLayoutOp> convertOps;
1050+ funcOp.walk (
1051+ [&](ConvertLayoutOp convertOp) { convertOps.push_back (convertOp); });
1052+ for (ConvertLayoutOp convertOp : convertOps) {
1053+ hoistConvertIntoConditionals (convertOp);
1054+ if (!opToDelete.contains (convertOp)) {
1055+ // If the conversion didn't get removed, consider it for reuse in future
1056+ // backward slices.
1057+ addRematValue (convertOp.getSrc (), convertOp.getType ().getEncoding (),
1058+ convertOp.getResult ());
1059+ }
1060+ }
1061+ }
1062+
10451063void LayoutRematerialization::backwardRematerialization (
10461064 ConvertLayoutOp convertOp) {
10471065 // DotOperand is hoisted by hoistDotOperand
@@ -1268,6 +1286,155 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
12681286 rewriteSlice (slice, layout, convertOp, mapping);
12691287}
12701288
1289+ void LayoutRematerialization::hoistConvertIntoConditionals (
1290+ ConvertLayoutOp convertOp) {
1291+ // Take the backward slice of tensor dependencies rooted at the conversion,
1292+ // stopping at conditionals. This subslice is used to initialize the analysis.
1293+ SetVector<Value> slice;
1294+ DenseMap<Value, Attribute> layout;
1295+ auto isIfOp = [](Operation *op) { return isa<scf::IfOp>(op); };
1296+ if (failed (getRematerializableSlice (convertOp.getSrcMutable (),
1297+ convertOp.getType ().getEncoding (), slice,
1298+ layout, isIfOp)))
1299+ return ;
1300+
1301+ // These are the conditional edges above which conversions should be hoisted.
1302+ // The value represents the `scf.if` op result and the operand represents the
1303+ // edge into one of the branches.
1304+ SmallVector<std::pair<OpResult, OpOperand *>> hoistAbove;
1305+
1306+ // The list of `scf.if` op results in the slice that are not rematerializable.
1307+ // Hoisting is terminated at these values.
1308+ SmallVector<OpResult> terminals;
1309+
1310+ // Process the whole backward slice in subslices that stop at each condtional.
1311+ // This is so we can apply more specific rules about when to hoist.
1312+ struct Subslice {
1313+ OpResult v;
1314+ OpOperand *edge;
1315+ SetVector<Value> slice;
1316+ DenseMap<Value, Attribute> layout;
1317+ };
1318+ SmallVector<Subslice> subslices;
1319+
1320+ // Check a value in the subslice.
1321+ auto visitValue = [&](OpResult v) {
1322+ auto ifOp = v.getDefiningOp <scf::IfOp>();
1323+ if (!ifOp)
1324+ return ;
1325+
1326+ Attribute rootLayout = layout.at (v);
1327+ unsigned resIdx = cast<OpResult>(v).getResultNumber ();
1328+
1329+ // Take the backward slice along each branch.
1330+ auto thenYield =
1331+ cast<scf::YieldOp>(ifOp.getThenRegion ().front ().getTerminator ());
1332+ auto elseYield =
1333+ cast<scf::YieldOp>(ifOp.getElseRegion ().front ().getTerminator ());
1334+
1335+ OpOperand &thenRes = thenYield.getResultsMutable ()[resIdx];
1336+ OpOperand &elseRes = elseYield.getResultsMutable ()[resIdx];
1337+
1338+ SetVector<Value> thenSlice, elseSlice;
1339+ DenseMap<Value, Attribute> thenLayout, elseLayout;
1340+
1341+ LogicalResult thenResult = getRematerializableSlice (
1342+ thenRes, rootLayout, thenSlice, thenLayout, isIfOp);
1343+ LogicalResult elseResult = getRematerializableSlice (
1344+ elseRes, rootLayout, elseSlice, elseLayout, isIfOp);
1345+
1346+ // If propagation across both edges of this conditional succeeded, then we
1347+ // don't need to hoist across it. Merge into the current slice.
1348+ if (succeeded (thenResult) && succeeded (elseResult)) {
1349+ slice.insert (thenSlice.begin (), thenSlice.end ());
1350+ slice.insert (elseSlice.begin (), elseSlice.end ());
1351+ layout.insert (thenLayout.begin (), thenLayout.end ());
1352+ layout.insert (elseLayout.begin (), elseLayout.end ());
1353+ return ;
1354+ }
1355+
1356+ // If propagation across both edges failed, then this conditional
1357+ // terminates backwards rematerialization.
1358+ if (failed (thenResult) && failed (elseResult)) {
1359+ terminals.push_back (v);
1360+ return ;
1361+ }
1362+
1363+ // The layout conversion can be rematerialized along one edge but not the
1364+ // other. We can hoist the conversion into the other branch. Push this
1365+ // into the subslice list for analysis.
1366+ if (succeeded (thenResult)) {
1367+ subslices.push_back (
1368+ {v, &elseRes, std::move (thenSlice), std::move (thenLayout)});
1369+ } else {
1370+ subslices.push_back (
1371+ {v, &thenRes, std::move (elseSlice), std::move (elseLayout)});
1372+ }
1373+ };
1374+
1375+ // Process the whole slice in subslices.
1376+ unsigned i = 0 ;
1377+ bool isLoneHoist = false ;
1378+ do {
1379+ // Visit values in the current subslice.
1380+ for (; i != slice.size (); ++i) {
1381+ if (auto v = dyn_cast<OpResult>(slice[i]))
1382+ visitValue (v);
1383+ }
1384+ // Check the next chunk of subslices. When a condtional is marked as being
1385+ // valid to be hoisted across, we have to recurse on a new subslice rooted
1386+ // at the corresopnding yield operand.
1387+ //
1388+ // Hoist across condtionals when:
1389+ // 1. The conditional is directly inside a loop.
1390+ // 2. The whole slice contains only one conditional.
1391+ for (auto &[v, edge, subslice, layouts] : subslices) {
1392+ bool oneHoist = false ;
1393+ if (isa<LoopLikeOpInterface>(v.getDefiningOp ()->getParentOp ()) ||
1394+ (oneHoist = subslices.size () == 1 && hoistAbove.empty ())) {
1395+ isLoneHoist |= oneHoist;
1396+ hoistAbove.push_back ({v, edge});
1397+ // Recurse on the subslice.
1398+ slice.insert (subslice.begin (), subslice.end ());
1399+ layout.insert (layouts.begin (), layouts.end ());
1400+ } else {
1401+ terminals.push_back (v);
1402+ }
1403+ }
1404+ subslices.clear ();
1405+ } while (i != slice.size ());
1406+
1407+ // Exit early if there is nothing to do.
1408+ if (hoistAbove.empty ())
1409+ return ;
1410+ // Check if this is a lone hoist. There should be no other terminals.
1411+ if (isLoneHoist && !terminals.empty ())
1412+ return ;
1413+
1414+ // Rematerialize failed hoists right before the condtional, and hoist those
1415+ // that succeeded into the branch and then rewrite the slice.
1416+ IRMapping mapping;
1417+ auto hoistRemat = [&](OpBuilder &b, Value v, Attribute encoding) {
1418+ auto tensorType = cast<RankedTensorType>(v.getType ());
1419+ auto newType = RankedTensorType::get (tensorType.getShape (),
1420+ tensorType.getElementType (), encoding);
1421+ Value newCvt = b.create <ConvertLayoutOp>(convertOp.getLoc (), newType, v);
1422+
1423+ mapping.map (v, newCvt);
1424+ slice.remove (v);
1425+ };
1426+ for (Value v : terminals) {
1427+ OpBuilder b (v.getContext ());
1428+ b.setInsertionPointAfter (v.getDefiningOp ());
1429+ hoistRemat (b, v, layout.at (v));
1430+ }
1431+ for (auto [result, edge] : hoistAbove) {
1432+ OpBuilder b (edge->getOwner ());
1433+ hoistRemat (b, edge->get (), layout.at (result));
1434+ }
1435+ rewriteSlice (slice, layout, convertOp, mapping);
1436+ }
1437+
12711438void backwardRematerialization (ModuleOp module ) {
12721439 module .walk ([](FuncOp funcOp) {
12731440 LayoutRematerialization layoutRemat (funcOp);
@@ -1283,6 +1450,10 @@ void hoistConvert(ModuleOp module) {
12831450 layoutRemat.hoistConvertOnTopOfExtOrBroadcast ();
12841451 layoutRemat.cleanup ();
12851452
1453+ layoutRemat = LayoutRematerialization (funcOp);
1454+ layoutRemat.hoistConvertIntoConditionals ();
1455+ layoutRemat.cleanup ();
1456+
12861457 layoutRemat = LayoutRematerialization (funcOp);
12871458 layoutRemat.hoistConvertDotOperand ();
12881459 layoutRemat.cleanup ();
0 commit comments