@@ -172,8 +172,9 @@ class LayoutRematerialization {
172172 void reduceLoopCarriedValues ();
173173 // Existing tuples of (value, layout) that needs to be updated when recreating
174174 // scf ops. This prevents keeping track of Values that have been delete when
175- // rewriting slices.
176- DenseMap<Value, Attribute> mappedValues;
175+ // rewriting slices. The Value maybe mapped to different attributes in remove
176+ // layout.
177+ DenseMap<Value, SmallVector<Attribute>> mappedValues;
177178 // map of the values remat based on encoding.
178179 DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
179180 // DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -187,7 +188,10 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
187188 Value newV) {
188189 LDBG (" addRematValue " << old << " encoding " << encoding << " " << newV);
189190 rematMapping[{old, encoding}] = newV;
190- mappedValues[old] = encoding;
191+ if (mappedValues.contains (old))
192+ mappedValues[old].push_back (encoding);
193+ else
194+ mappedValues[old] = {encoding};
191195}
192196
193197// Remove unneeded values now that we are done with the rematMapping.
@@ -992,22 +996,27 @@ void LayoutRematerialization::updateRematMapping(
992996 for (auto [old, newV] : values) {
993997 auto it = mappedValues.find (old);
994998 if (it != mappedValues.end ()) {
995- Attribute encoding = it->second ;
996- auto rematIt = rematMapping.find ({old, it->second });
997- assert (rematIt != rematMapping.end ());
998- Value replacedValue = rematIt->second ;
999- rematMapping.erase (rematIt);
1000- mappedValues.erase (it);
1001- // Loop through the replacement value to find the new version of remat
1002- // value. This should be okay as the number of values should be small.
1003- for (auto [before, after] : values) {
1004- if (before == replacedValue) {
1005- replacedValue = after;
1006- break ;
999+ SmallVector<Attribute> encodings = it->second ;
1000+ for (Attribute encoding : encodings) {
1001+ auto rematIt = rematMapping.find ({old, encoding});
1002+ assert (rematIt != rematMapping.end ());
1003+ Value replacedValue = rematIt->second ;
1004+ rematMapping.erase (rematIt);
1005+ // Loop through the replacement value to find the new version of remat
1006+ // value. This should be okay as the number of values should be small.
1007+ for (auto [before, after] : values) {
1008+ if (before == replacedValue) {
1009+ replacedValue = after;
1010+ break ;
1011+ }
10071012 }
1013+ rematMapping[{newV, encoding}] = replacedValue;
10081014 }
1009- rematMapping[{newV, encoding}] = replacedValue;
1010- mappedValues[newV] = encoding;
1015+ mappedValues.erase (it);
1016+ if (mappedValues.contains (newV))
1017+ mappedValues[newV].append (encodings);
1018+ else
1019+ mappedValues[newV] = std::move (encodings);
10111020 }
10121021 }
10131022}
0 commit comments