@@ -154,17 +154,18 @@ class LayoutPropagation {
154154class LayoutRematerialization {
155155public:
156156 LayoutRematerialization (FuncOp F) : funcOp(F) {}
157+
157158 // Map the original value to the remat'ed one.
158159 void addRematValue (Value old, Attribute encoding, Value newV);
160+ // Get the remat'ed value in the given encoding, if one already exists and
161+ // is different then the layout conversion root.
162+ Value getRematValue (Value value, Attribute encoding) const {
163+ return rematMapping.lookup ({value, encoding});
164+ }
165+
159166 bool hasRematValue (Value value, Attribute encoding) {
160167 return rematMapping.contains ({value, encoding});
161168 }
162- // Return the remat'ed value in the given encoding.
163- Value getRematValue (Value value, Attribute encoding) {
164- auto it = rematMapping.find ({value, encoding});
165- assert (it != rematMapping.end ());
166- return it->second ;
167- }
168169 void cleanup ();
169170 void backwardRematerialization ();
170171 void backwardRematerialization (ConvertLayoutOp convertOp);
@@ -175,6 +176,11 @@ class LayoutRematerialization {
175176 void rewriteSlice (SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
176177 ConvertLayoutOp convertOp);
177178
179+ LogicalResult getRematerializableSlice (
180+ OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
181+ DenseMap<Value, Attribute> &layout,
182+ std::function<bool (Operation *)> stopPropagation = nullptr);
183+
178184private:
179185 void updateRematMapping (SmallVector<std::tuple<Value, Value>> &values);
180186 // Existing tuples of (value, layout) that needs to be updated when recreating
@@ -186,6 +192,7 @@ class LayoutRematerialization {
186192 // DenseMap<std::pair<Operation*, Attribute>, Operation*>
187193 SetVector<Operation *> opToDelete;
188194 FuncOp funcOp;
195+ DominanceInfo domInfo;
189196};
190197
191198void LayoutRematerialization::addRematValue (Value old, Attribute encoding,
@@ -1188,10 +1195,33 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11881195 rewriteSlice (slice, layout, convertOp, mapping);
11891196}
11901197
1191- LogicalResult getRematerializableSlice (
1192- Value root, Attribute rootEncoding, SetVector<Value> &slice,
1198+ LogicalResult LayoutRematerialization:: getRematerializableSlice (
1199+ OpOperand & root, Attribute rootEncoding, SetVector<Value> &slice,
11931200 DenseMap<Value, Attribute> &layout,
1194- std::function<bool (Operation *)> stopPropagation = nullptr) {
1201+ std::function<bool (Operation *)> stopPropagation) {
1202+ // Allow re-using existing conversions for a value. Check dominance of any
1203+ // reusable materializations against the root value. This is sufficient
1204+ // because the conversions are processed in post-order.
1205+ auto getExistingConversion = [&](OpOperand &value, Attribute encoding) {
1206+ Value remat = getRematValue (value.get (), encoding);
1207+ if (!remat)
1208+ return Value ();
1209+ // `value` can be replaced with an existing rematerialization if it
1210+ // dominates the current use of value.
1211+ Operation *user = value.getOwner ();
1212+ if (domInfo.properlyDominates (remat, user)) {
1213+ return remat;
1214+ }
1215+ // Alternatively, if the current use can be sunk below the existing
1216+ // rematerialization, then it is okay to use as well. E.g. the current use
1217+ // is a conversion that will be folded away when its result is
1218+ // rematerialized.
1219+ if (isa<ConvertLayoutOp>(user) && remat.getDefiningOp () &&
1220+ domInfo.properlyDominates (user, remat.getDefiningOp ())) {
1221+ return remat;
1222+ }
1223+ return Value ();
1224+ };
11951225 LogicalResult result = ttgi::getConvertBackwardSlice (
11961226 root, slice, rootEncoding, layout, std::move (stopPropagation));
11971227 if (result.failed () || slice.empty ())
@@ -1255,7 +1285,7 @@ void LayoutRematerialization::backwardRematerialization(
12551285 SetVector<Value> slice;
12561286 DenseMap<Value, Attribute> layout;
12571287 LogicalResult result = getRematerializableSlice (
1258- convertOp.getSrc (), targetType.getEncoding (), slice, layout);
1288+ convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout);
12591289 if (result.failed ()) {
12601290 LDBG (" getRematerializableSlice failed" );
12611291 return ;
@@ -1287,9 +1317,9 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
12871317 // 1. Take a backward slice of all the tensor dependencies.
12881318 SetVector<Value> slice;
12891319 DenseMap<Value, Attribute> layout;
1290- LogicalResult result =
1291- getRematerializableSlice ( convertOp.getSrc (), targetType.getEncoding (),
1292- slice, layout, isExtOrBroadcastOp);
1320+ LogicalResult result = getRematerializableSlice (
1321+ convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout ,
1322+ isExtOrBroadcastOp);
12931323 if (result.failed ())
12941324 return ;
12951325
@@ -1307,7 +1337,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
13071337 if (!srcEncoding)
13081338 return ;
13091339 LogicalResult result = getRematerializableSlice (
1310- op->getOperand (0 ), srcEncoding, tempSlice, tempLayout);
1340+ op->getOpOperand (0 ), srcEncoding, tempSlice, tempLayout);
13111341 // If we can rematerialize the rest of the ext slice we can ignore this
13121342 // ext as it won't need a convert.
13131343 if (result.succeeded ()) {
0 commit comments