@@ -116,17 +116,13 @@ class LayoutPropagation {
116116class LayoutRematerialization {
117117public:
118118 LayoutRematerialization (FuncOp F) : funcOp(F) {}
119+
119120 // Map the original value to the remat'ed one.
120121 void addRematValue (Value old, Attribute encoding, Value newV);
121- bool hasRematValue (Value value, Attribute encoding) {
122- return rematMapping.contains ({value, encoding});
123- }
124- // Return the remat'ed value in the given encoding.
125- Value getRematValue (Value value, Attribute encoding) {
126- auto it = rematMapping.find ({value, encoding});
127- assert (it != rematMapping.end ());
128- return it->second ;
129- }
122+ // Get the remat'ed value in the given encoding, if one already exists and
123+ // is different then the layout conversion root.
124+ Value getRematValue (Value value, Attribute encoding, Value root) const ;
125+
130126 void cleanup ();
131127 void backwardRematerialization ();
132128 void backwardRematerialization (ConvertLayoutOp convertOp);
@@ -137,6 +133,11 @@ class LayoutRematerialization {
137133 void rewriteSlice (SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
138134 ConvertLayoutOp convertOp);
139135
136+ LogicalResult getRematerializableSlice (
137+ Value root, Attribute rootEncoding, SetVector<Value> &slice,
138+ DenseMap<Value, Attribute> &layout,
139+ std::function<bool (Operation *)> stopPropagation = nullptr);
140+
140141private:
141142 void updateRematMapping (SmallVector<std::tuple<Value, Value>> &values);
142143 // Existing tuples of (value, layout) that needs to be updated when recreating
@@ -157,6 +158,21 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
157158 mappedValues[old] = encoding;
158159}
159160
161+ Value LayoutRematerialization::getRematValue (Value value, Attribute encoding,
162+ Value root) const {
163+ Value remat = rematMapping.lookup ({value, encoding});
164+ if (!remat)
165+ return {};
166+ // If the remat'ed value is a conversion result, make sure it is different
167+ // than the root of the one we're looking at.
168+ if (auto cvt = remat.getDefiningOp <ConvertLayoutOp>()) {
169+ if (cvt.getSrc () == root)
170+ return {};
171+ }
172+ // This remat'ed value can be reused.
173+ return remat;
174+ }
175+
160176// Remove unneeded values now that we are done with the rematMapping.
161177void LayoutRematerialization::cleanup () {
162178 for (Operation *op : llvm::reverse (opToDelete))
@@ -766,8 +782,8 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
766782 auto layoutIt = layout.find (v);
767783 assert (layoutIt != layout.end ());
768784 // If we already have a remat value for this value, use it.
769- if (hasRematValue (v, layoutIt->second )) {
770- mapping.map (v, getRematValue (v, layoutIt-> second ) );
785+ if (Value remat = getRematValue (v, layoutIt->second , convertOp. getSrc () )) {
786+ mapping.map (v, remat );
771787 valuesWithExistingRemat.insert (v);
772788 continue ;
773789 }
@@ -928,12 +944,17 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
928944 rewriteSlice (slice, layout, convertOp, mapping);
929945}
930946
931- LogicalResult getRematerializableSlice (
947+ LogicalResult LayoutRematerialization:: getRematerializableSlice (
932948 Value root, Attribute rootEncoding, SetVector<Value> &slice,
933949 DenseMap<Value, Attribute> &layout,
934- std::function<bool (Operation *)> stopPropagation = nullptr) {
935- LogicalResult result = getConvertBackwardSlice (root, slice, rootEncoding,
936- layout, stopPropagation);
950+ std::function<bool (Operation *)> stopPropagation) {
951+ // Allow re-using existing conversions for a value.
952+ auto getExistingConversion = [&](Value value, Attribute encoding) -> Value {
953+ return getRematValue (value, encoding, root);
954+ };
955+ LogicalResult result =
956+ getConvertBackwardSlice (root, slice, rootEncoding, layout,
957+ stopPropagation, getExistingConversion);
937958 if (result.failed () || slice.empty ())
938959 return failure ();
939960
@@ -950,8 +971,14 @@ LogicalResult getRematerializableSlice(
950971void LayoutRematerialization::backwardRematerialization () {
951972 // Go through each ConvertLayoutOp.
952973 SmallVector<ConvertLayoutOp> convertOps;
953- funcOp.walk (
954- [&](ConvertLayoutOp convertOp) { convertOps.push_back (convertOp); });
974+ funcOp.walk ([&](ConvertLayoutOp convertOp) {
975+ convertOps.push_back (convertOp);
976+ // Add existing layout conversions as rematerializations of themselves. This
977+ // enables rematerialization of other conversions to re-use existing
978+ // conversions. Importantly, don't add them to `mappedValues`.
979+ rematMapping.insert (
980+ {{convertOp.getSrc (), convertOp.getType ().getEncoding ()}, convertOp});
981+ });
955982 for (ConvertLayoutOp convertOp : convertOps) {
956983 backwardRematerialization (convertOp);
957984 }
@@ -976,14 +1003,13 @@ void LayoutRematerialization::backwardRematerialization(
9761003 // careful with the heuristics for both correctness and perf
9771004 if (isa<DotOperandEncodingAttr, LinearEncodingAttr>(targetType.getEncoding ()))
9781005 return ;
979- Value oldV = convertOp-> getOperand ( 0 );
1006+ Value oldV = convertOp. getSrc ( );
9801007 LDBG (" check backward remat with source " << oldV << " encoding "
9811008 << targetType.getEncoding ());
9821009 // Check to see if there are existing remat'ed values for the pair of oldValue
9831010 // and encoding.
984- if (hasRematValue (oldV, targetType.getEncoding ())) {
1011+ if (Value newV = getRematValue (oldV, targetType.getEncoding (), oldV )) {
9851012 // Replace it with the remat'ed value.
986- Value newV = getRematValue (oldV, targetType.getEncoding ());
9871013 convertOp.replaceAllUsesWith (newV);
9881014 opToDelete.insert (convertOp);
9891015 LDBG (" found remat'ed value" << newV);
0 commit comments