11#include " mlir/Analysis/SliceAnalysis.h"
22#include " mlir/Dialect/SCF/IR/SCF.h"
33#include " mlir/IR/BuiltinAttributes.h"
4+ #include " mlir/IR/Dominance.h"
45#include " mlir/IR/IRMapping.h"
56#include " mlir/IR/PatternMatch.h"
67#include " mlir/IR/Verifier.h"
@@ -125,9 +126,6 @@ class LayoutRematerialization {
125126 return rematMapping.lookup ({value, encoding});
126127 }
127128
128- bool hasRematValue (Value value, Attribute encoding) {
129- return rematMapping.contains ({value, encoding});
130- }
131129 void cleanup ();
132130 void backwardRematerialization ();
133131 void backwardRematerialization (ConvertLayoutOp convertOp);
@@ -987,8 +985,8 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
987985 auto layoutIt = layout.find (v);
988986 assert (layoutIt != layout.end ());
989987 // If we already have a remat value for this value, use it.
990- if (hasRematValue (v, layoutIt->second )) {
991- mapping.map (v, getRematValue (v, layoutIt-> second ) );
988+ if (Value remat = getRematValue (v, layoutIt->second )) {
989+ mapping.map (v, remat );
992990 valuesWithExistingRemat.insert (v);
993991 continue ;
994992 }
@@ -1212,6 +1210,12 @@ void LayoutRematerialization::backwardRematerialization() {
12121210 [&](ConvertLayoutOp convertOp) { convertOps.push_back (convertOp); });
12131211 for (ConvertLayoutOp convertOp : convertOps) {
12141212 backwardRematerialization (convertOp);
1213+ if (!opToDelete.contains (convertOp)) {
1214+ // If the conversion didn't get removed, consider it for reuse in future
1215+ // backward slices.
1216+ addRematValue (convertOp.getSrc (), convertOp.getType ().getEncoding (),
1217+ convertOp.getResult ());
1218+ }
12151219 }
12161220}
12171221
@@ -1222,6 +1226,12 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
12221226 [&](ConvertLayoutOp convertOp) { convertOps.push_back (convertOp); });
12231227 for (ConvertLayoutOp convertOp : convertOps) {
12241228 hoistConvertOnTopOfExtOrBroadcast (convertOp);
1229+ if (!opToDelete.contains (convertOp)) {
1230+ // If the conversion didn't get removed, consider it for reuse in future
1231+ // backward slices.
1232+ addRematValue (convertOp.getSrc (), convertOp.getType ().getEncoding (),
1233+ convertOp.getResult ());
1234+ }
12251235 }
12261236}
12271237
@@ -1234,14 +1244,14 @@ void LayoutRematerialization::backwardRematerialization(
12341244 dyn_cast<DotOperandEncodingAttr>(targetType.getEncoding ()))
12351245 if (isa<BlockedEncodingAttr>(dotLayout.getParent ()))
12361246 return ;
1237- Value oldV = convertOp-> getOperand ( 0 );
1247+ Value oldV = convertOp. getSrc ( );
12381248 LDBG (" check backward remat with source " << oldV << " encoding "
12391249 << targetType.getEncoding ());
12401250 // Check to see if there are existing remat'ed values for the pair of oldValue
1241- // and encoding.
1242- if (hasRematValue (oldV, targetType.getEncoding ())) {
1251+ // and encoding. Make sure it dominates the current conversion.
1252+ Value newV = getRematValue (oldV, targetType.getEncoding ());
1253+ if (newV && domInfo.properlyDominates (newV, convertOp)) {
12431254 // Replace it with the remat'ed value.
1244- Value newV = getRematValue (oldV, targetType.getEncoding ());
12451255 convertOp.replaceAllUsesWith (newV);
12461256 opToDelete.insert (convertOp);
12471257 LDBG (" found remat'ed value" << newV);
0 commit comments