diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 69036e947ebdb..4693edadfb5ee 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -71,10 +71,16 @@ namespace { /// This class wraps a IRMapping to provide recursive lookup /// functionality, i.e. we will traverse if the mapped value also has a mapping. struct ConversionValueMapping { - /// Lookup a mapped value within the map. If a mapping for the provided value - /// does not exist then return the provided value. If `desiredType` is - /// non-null, returns the most recently mapped value with that type. If an - /// operand of that type does not exist, defaults to normal behavior. + /// Lookup the most recently mapped value with the desired type in the + /// mapping. + /// + /// Special cases: + /// - If the desired type is "null", simply return the most recently mapped + /// value. + /// - If there is no mapping to the desired type, also return the most + /// recently mapped value. + /// - If there is no mapping for the given value at all, return the given + /// value. Value lookupOrDefault(Value from, Type desiredType = nullptr) const; /// Lookup a mapped value within the map, or return null if a mapping does not @@ -115,19 +121,11 @@ struct ConversionValueMapping { Value ConversionValueMapping::lookupOrDefault(Value from, Type desiredType) const { - // If there was no desired type, simply find the leaf value. - if (!desiredType) { - // If this value had a valid mapping, unmap that value as well in the case - // that it was also replaced. - while (auto mappedValue = mapping.lookupOrNull(from)) - from = mappedValue; - return from; - } - - // Otherwise, try to find the deepest value that has the desired type. + // Try to find the deepest value that has the desired type. If there is no + // such value, simply return the deepest value. Value desiredValue; do { - if (from.getType() == desiredType) + if (!desiredType || from.getType() == desiredType) desiredValue = from; Value mappedValue = mapping.lookupOrNull(from); @@ -1136,7 +1134,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( MaterializationKind::Target, computeInsertPoint(newOperand), operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType, currentTypeConverter); - mapping.map(mapping.lookupOrDefault(newOperand), castValue); + mapping.map(newOperand, castValue); newOperand = castValue; } remapped.push_back(newOperand);