@@ -116,13 +116,17 @@ class LayoutPropagation {
116116class LayoutRematerialization {
117117public:
118118 LayoutRematerialization (FuncOp F) : funcOp(F) {}
119-
120119 // Map the original value to the remat'ed one.
121120 void addRematValue (Value old, Attribute encoding, Value newV);
122- // Get the remat'ed value in the given encoding, if one already exists and
123- // is different then the layout conversion root.
124- Value getRematValue (Value value, Attribute encoding, Value root) const ;
125-
121+ bool hasRematValue (Value value, Attribute encoding) {
122+ return rematMapping.contains ({value, encoding});
123+ }
124+ // Return the remat'ed value in the given encoding.
125+ Value getRematValue (Value value, Attribute encoding) {
126+ auto it = rematMapping.find ({value, encoding});
127+ assert (it != rematMapping.end ());
128+ return it->second ;
129+ }
126130 void cleanup ();
127131 void backwardRematerialization ();
128132 void backwardRematerialization (ConvertLayoutOp convertOp);
@@ -133,11 +137,6 @@ class LayoutRematerialization {
133137 void rewriteSlice (SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
134138 ConvertLayoutOp convertOp);
135139
136- LogicalResult getRematerializableSlice (
137- Value root, Attribute rootEncoding, SetVector<Value> &slice,
138- DenseMap<Value, Attribute> &layout,
139- std::function<bool (Operation *)> stopPropagation = nullptr);
140-
141140private:
142141 void updateRematMapping (SmallVector<std::tuple<Value, Value>> &values);
143142 // Existing tuples of (value, layout) that needs to be updated when recreating
@@ -158,21 +157,6 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
158157 mappedValues[old] = encoding;
159158}
160159
161- Value LayoutRematerialization::getRematValue (Value value, Attribute encoding,
162- Value root) const {
163- Value remat = rematMapping.lookup ({value, encoding});
164- if (!remat)
165- return {};
166- // If the remat'ed value is a conversion result, make sure it is different
167- // than the root of the one we're looking at.
168- if (auto cvt = remat.getDefiningOp <ConvertLayoutOp>()) {
169- if (cvt.getSrc () == root)
170- return {};
171- }
172- // This remat'ed value can be reused.
173- return remat;
174- }
175-
176160// Remove unneeded values now that we are done with the rematMapping.
177161void LayoutRematerialization::cleanup () {
178162 for (Operation *op : llvm::reverse (opToDelete))
@@ -794,8 +778,8 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
794778 auto layoutIt = layout.find (v);
795779 assert (layoutIt != layout.end ());
796780 // If we already have a remat value for this value, use it.
797- if (Value remat = getRematValue (v, layoutIt->second , convertOp. getSrc () )) {
798- mapping.map (v, remat );
781+ if (hasRematValue (v, layoutIt->second )) {
782+ mapping.map (v, getRematValue (v, layoutIt-> second ) );
799783 valuesWithExistingRemat.insert (v);
800784 continue ;
801785 }
@@ -956,17 +940,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
956940 rewriteSlice (slice, layout, convertOp, mapping);
957941}
958942
959- LogicalResult LayoutRematerialization:: getRematerializableSlice (
943+ LogicalResult getRematerializableSlice (
960944 Value root, Attribute rootEncoding, SetVector<Value> &slice,
961945 DenseMap<Value, Attribute> &layout,
962- std::function<bool (Operation *)> stopPropagation) {
963- // Allow re-using existing conversions for a value.
964- auto getExistingConversion = [&](Value value, Attribute encoding) -> Value {
965- return getRematValue (value, encoding, root);
966- };
967- LogicalResult result =
968- getConvertBackwardSlice (root, slice, rootEncoding, layout,
969- stopPropagation, getExistingConversion);
946+ std::function<bool (Operation *)> stopPropagation = nullptr) {
947+ LogicalResult result = getConvertBackwardSlice (root, slice, rootEncoding,
948+ layout, stopPropagation);
970949 if (result.failed () || slice.empty ())
971950 return failure ();
972951
@@ -983,14 +962,8 @@ LogicalResult LayoutRematerialization::getRematerializableSlice(
983962void LayoutRematerialization::backwardRematerialization () {
984963 // Go through each ConvertLayoutOp.
985964 SmallVector<ConvertLayoutOp> convertOps;
986- funcOp.walk ([&](ConvertLayoutOp convertOp) {
987- convertOps.push_back (convertOp);
988- // Add existing layout conversions as rematerializations of themselves. This
989- // enables rematerialization of other conversions to re-use existing
990- // conversions. Importantly, don't add them to `mappedValues`.
991- rematMapping.insert (
992- {{convertOp.getSrc (), convertOp.getType ().getEncoding ()}, convertOp});
993- });
965+ funcOp.walk (
966+ [&](ConvertLayoutOp convertOp) { convertOps.push_back (convertOp); });
994967 for (ConvertLayoutOp convertOp : convertOps) {
995968 backwardRematerialization (convertOp);
996969 }
@@ -1015,13 +988,14 @@ void LayoutRematerialization::backwardRematerialization(
1015988 // careful with the heuristics for both correctness and perf
1016989 if (isa<DotOperandEncodingAttr, LinearEncodingAttr>(targetType.getEncoding ()))
1017990 return ;
1018- Value oldV = convertOp. getSrc ( );
991+ Value oldV = convertOp-> getOperand ( 0 );
1019992 LDBG (" check backward remat with source " << oldV << " encoding "
1020993 << targetType.getEncoding ());
1021994 // Check to see if there are existing remat'ed values for the pair of oldValue
1022995 // and encoding.
1023- if (Value newV = getRematValue (oldV, targetType.getEncoding (), oldV )) {
996+ if (hasRematValue (oldV, targetType.getEncoding ())) {
1024997 // Replace it with the remat'ed value.
998+ Value newV = getRematValue (oldV, targetType.getEncoding ());
1025999 convertOp.replaceAllUsesWith (newV);
10261000 opToDelete.insert (convertOp);
10271001 LDBG (" found remat'ed value" << newV);
0 commit comments