@@ -149,19 +149,22 @@ class LayoutRematerialization {
149
149
getConvertBackwardSlice (OpOperand &root, Attribute rootEncoding,
150
150
SetVector<Value> &slice,
151
151
DenseMap<Value, Attribute> &layout,
152
- std::function<bool (Operation *)> stopPropagation);
152
+ std::function<bool (Operation *)> stopPropagation,
153
+ bool includeForOp = false );
153
154
154
155
LogicalResult getRematerializableSlice (
155
156
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
156
157
DenseMap<Value, Attribute> &layout,
157
- std::function<bool (Operation *)> stopPropagation = nullptr);
158
+ std::function<bool (Operation *)> stopPropagation = nullptr,
159
+ bool includeForOp = false);
158
160
159
161
private:
160
162
void updateRematMapping (SmallVector<std::tuple<Value, Value>> &values);
161
163
// Existing tuples of (value, layout) that needs to be updated when recreating
162
164
// scf ops. This prevents keeping track of Values that have been delete when
163
- // rewriting slices.
164
- DenseMap<Value, Attribute> mappedValues;
165
+ // rewriting slices. The Value maybe mapped to different attributes in remove
166
+ // layout.
167
+ DenseMap<Value, SmallVector<Attribute>> mappedValues;
165
168
// map of the values remat based on encoding.
166
169
DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
167
170
// DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -174,7 +177,11 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
174
177
Value newV) {
175
178
LDBG (" addRematValue " << old << " encoding " << encoding << " " << newV);
176
179
rematMapping[{old, encoding}] = newV;
177
- mappedValues[old] = encoding;
180
+ if (mappedValues.contains (old)) {
181
+ mappedValues[old].push_back (encoding);
182
+ } else {
183
+ mappedValues[old] = {encoding};
184
+ }
178
185
}
179
186
180
187
// Remove unneeded values now that we are done with the rematMapping.
@@ -955,22 +962,28 @@ void LayoutRematerialization::updateRematMapping(
955
962
for (auto [old, newV] : values) {
956
963
auto it = mappedValues.find (old);
957
964
if (it != mappedValues.end ()) {
958
- Attribute encoding = it->second ;
959
- auto rematIt = rematMapping.find ({old, it->second });
960
- assert (rematIt != rematMapping.end ());
961
- Value replacedValue = rematIt->second ;
962
- rematMapping.erase (rematIt);
963
- mappedValues.erase (it);
964
- // Loop through the replacement value to find the new version of remat
965
- // value. This should be okay as the number of values should be small.
966
- for (auto [before, after] : values) {
967
- if (before == replacedValue) {
968
- replacedValue = after;
969
- break ;
965
+ SmallVector<Attribute> encodings = it->second ;
966
+ for (auto encoding : encodings) {
967
+ auto rematIt = rematMapping.find ({old, encoding});
968
+ assert (rematIt != rematMapping.end ());
969
+ Value replacedValue = rematIt->second ;
970
+ rematMapping.erase (rematIt);
971
+ // Loop through the replacement value to find the new version of remat
972
+ // value. This should be okay as the number of values should be small.
973
+ for (auto [before, after] : values) {
974
+ if (before == replacedValue) {
975
+ replacedValue = after;
976
+ break ;
977
+ }
970
978
}
979
+ rematMapping[{newV, encoding}] = replacedValue;
980
+ }
981
+ mappedValues.erase (it);
982
+ if (mappedValues.contains (newV)) {
983
+ mappedValues[newV].append (encodings);
984
+ } else {
985
+ mappedValues[newV] = std::move (encodings);
971
986
}
972
- rematMapping[{newV, encoding}] = replacedValue;
973
- mappedValues[newV] = encoding;
974
987
}
975
988
}
976
989
}
@@ -1045,6 +1058,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
1045
1058
deadOps.push_back (forOp.getOperation ());
1046
1059
Block &loopBody = *newForOp.getBody ();
1047
1060
for (auto m : argMapping) {
1061
+ mapping.map (newForOp.getResult (m.first ), newForOp.getResult (m.second ));
1048
1062
mapping.map (forOp.getResult (m.first ), newForOp.getResult (m.second ));
1049
1063
int numIndVars = newForOp.getNumInductionVars ();
1050
1064
mapping.map (loopBody.getArgument (m.first + numIndVars),
@@ -1161,8 +1175,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
1161
1175
builder.replaceAllUsesWith (std::get<0 >(kv), std::get<1 >(kv));
1162
1176
}
1163
1177
1164
- for (Operation *op : deadOps)
1165
- opToDelete.insert (op);
1178
+ for (Operation *op : deadOps) {
1179
+ if (!isa<scf::ForOp>(op))
1180
+ opToDelete.insert (op);
1181
+ else
1182
+ op->erase ();
1183
+ }
1166
1184
}
1167
1185
1168
1186
void LayoutRematerialization::rewriteSlice (SetVector<Value> &slice,
@@ -1175,7 +1193,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
1175
1193
LogicalResult LayoutRematerialization::getConvertBackwardSlice (
1176
1194
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
1177
1195
DenseMap<Value, Attribute> &layout,
1178
- std::function<bool (Operation *)> stopPropagation) {
1196
+ std::function<bool (Operation *)> stopPropagation, bool includeForOp ) {
1179
1197
// Allow re-using existing conversions for a value. Check dominance of any
1180
1198
// reusable materializations against the root value. This is sufficient
1181
1199
// because the conversions are processed in post-order.
@@ -1204,15 +1222,18 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
1204
1222
};
1205
1223
1206
1224
return ttgi::getConvertBackwardSlice (root, slice, rootEncoding, layout,
1207
- stopPropagation, getExistingConversion);
1225
+ stopPropagation, getExistingConversion,
1226
+ includeForOp);
1208
1227
}
1209
1228
1210
1229
LogicalResult LayoutRematerialization::getRematerializableSlice (
1211
1230
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
1212
1231
DenseMap<Value, Attribute> &layout,
1213
- std::function<bool (Operation *)> stopPropagation) {
1214
- LogicalResult result = getConvertBackwardSlice (
1215
- root, rootEncoding, slice, layout, std::move (stopPropagation));
1232
+ std::function<bool (Operation *)> stopPropagation, bool includeForOp) {
1233
+
1234
+ LogicalResult result =
1235
+ getConvertBackwardSlice (root, rootEncoding, slice, layout,
1236
+ std::move (stopPropagation), includeForOp);
1216
1237
if (result.failed () || slice.empty ())
1217
1238
return failure ();
1218
1239
@@ -1301,8 +1322,9 @@ void LayoutRematerialization::backwardRematerialization(
1301
1322
// rematerialized.
1302
1323
SetVector<Value> slice;
1303
1324
DenseMap<Value, Attribute> layout;
1304
- LogicalResult result = getRematerializableSlice (
1305
- convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout);
1325
+ LogicalResult result = getRematerializableSlice (convertOp.getSrcMutable (),
1326
+ targetType.getEncoding (),
1327
+ slice, layout, nullptr , true );
1306
1328
if (result.failed ()) {
1307
1329
LDBG (" getRematerializableSlice failed" );
1308
1330
return ;
0 commit comments