Skip to content

Commit d7490e4

Browse files
[BACKEND] Enhance the remove layout implementation to reduce the duplicated values with different layout in scf.for. (#4527)
1 parent 6d69871 commit d7490e4

File tree

1 file changed

+26
-17
lines changed

1 file changed

+26
-17
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,9 @@ class LayoutRematerialization {
172172
void reduceLoopCarriedValues();
173173
// Existing tuples of (value, layout) that needs to be updated when recreating
174174
// scf ops. This prevents keeping track of Values that have been delete when
175-
// rewriting slices.
176-
DenseMap<Value, Attribute> mappedValues;
175+
// rewriting slices. The Value maybe mapped to different attributes in remove
176+
// layout.
177+
DenseMap<Value, SmallVector<Attribute>> mappedValues;
177178
// map of the values remat based on encoding.
178179
DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
179180
// DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -187,7 +188,10 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
187188
Value newV) {
188189
LDBG("addRematValue " << old << " encoding " << encoding << " " << newV);
189190
rematMapping[{old, encoding}] = newV;
190-
mappedValues[old] = encoding;
191+
if (mappedValues.contains(old))
192+
mappedValues[old].push_back(encoding);
193+
else
194+
mappedValues[old] = {encoding};
191195
}
192196

193197
// Remove unneeded values now that we are done with the rematMapping.
@@ -992,22 +996,27 @@ void LayoutRematerialization::updateRematMapping(
992996
for (auto [old, newV] : values) {
993997
auto it = mappedValues.find(old);
994998
if (it != mappedValues.end()) {
995-
Attribute encoding = it->second;
996-
auto rematIt = rematMapping.find({old, it->second});
997-
assert(rematIt != rematMapping.end());
998-
Value replacedValue = rematIt->second;
999-
rematMapping.erase(rematIt);
1000-
mappedValues.erase(it);
1001-
// Loop through the replacement value to find the new version of remat
1002-
// value. This should be okay as the number of values should be small.
1003-
for (auto [before, after] : values) {
1004-
if (before == replacedValue) {
1005-
replacedValue = after;
1006-
break;
999+
SmallVector<Attribute> encodings = it->second;
1000+
for (Attribute encoding : encodings) {
1001+
auto rematIt = rematMapping.find({old, encoding});
1002+
assert(rematIt != rematMapping.end());
1003+
Value replacedValue = rematIt->second;
1004+
rematMapping.erase(rematIt);
1005+
// Loop through the replacement value to find the new version of remat
1006+
// value. This should be okay as the number of values should be small.
1007+
for (auto [before, after] : values) {
1008+
if (before == replacedValue) {
1009+
replacedValue = after;
1010+
break;
1011+
}
10071012
}
1013+
rematMapping[{newV, encoding}] = replacedValue;
10081014
}
1009-
rematMapping[{newV, encoding}] = replacedValue;
1010-
mappedValues[newV] = encoding;
1015+
mappedValues.erase(it);
1016+
if (mappedValues.contains(newV))
1017+
mappedValues[newV].append(encodings);
1018+
else
1019+
mappedValues[newV] = std::move(encodings);
10111020
}
10121021
}
10131022
}

0 commit comments

Comments
 (0)