
[mlir][Transforms] Make lookup without type converter unambiguous #151747


Merged: 1 commit, Aug 7, 2025
67 changes: 56 additions & 11 deletions mlir/docs/DialectConversion.md
@@ -202,17 +202,62 @@ struct MyConversionPattern : public ConversionPattern {

#### Type Safety

The types of the remapped operands provided to a conversion pattern must be of a
type expected by the pattern. The expected types of a pattern are determined by
a provided [TypeConverter](#type-converter). If no type converter is provided,
the types of the remapped operands are expected to match the types of the
original operands. If a type converter is provided, the types of the remapped
operands are expected to be legal as determined by the converter. If the
remapped operand types are not of an expected type, and a materialization to the
expected type could not be performed, the pattern fails application before the
`matchAndRewrite` hook is invoked. This ensures that patterns do not have to
explicitly ensure type safety, or sanitize the types of the incoming remapped
operands. More information on type conversion is detailed in the
The types of the remapped operands provided to a conversion pattern (through
the adaptor or `ArrayRef` of operands) depend on type conversion rules.

If the pattern was initialized with a [type converter](#type-converter), the
conversion driver passes values whose types match the legalized types of the
operands of the matched operation as per the type converter. To that end, the
conversion driver may insert target materializations to convert the most
recently mapped values to the expected legalized types. The driver tries to
reuse existing materializations on a best-effort basis, but this is not
guaranteed by the infrastructure. If the operand types of the matched op could
not be legalized, the pattern fails to apply before the `matchAndRewrite` hook
is invoked.

Example:
```c++
// Type converter that converts all FloatTypes to IntegerTypes.
TypeConverter converter;
converter.addConversion([](FloatType t) {
  return IntegerType::get(t.getContext(), t.getWidth());
});

// Assuming that `MyConversionPattern` was initialized with `converter`.
struct MyConversionPattern : public ConversionPattern {
  virtual LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands, /* ... */) const {
    //                                           ^^^^^^^^
    // If `op` has a FloatType operand, the respective value in `operands`
    // is guaranteed to have the legalized IntegerType. If another pattern
    // previously replaced the operand SSA value with an SSA value of the
    // legalized type (via "replaceOp" or "applySignatureConversion"), you
    // will get that SSA value directly (unless the replacement value was
    // also replaced). Otherwise, you will get a materialization to the
    // legalized type.
    // ...
  }
};
```

If the pattern was initialized without a type converter, the conversion driver
passes the most recently mapped values to the pattern, excluding any
materializations. If a value with the same type as the original operand is
desired, users can directly take the respective operand from the matched
operation.

Example: When initializing the pattern from the above example without a type
converter, `operands` contains the most recent replacement values, regardless
of their types.
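
A minimal sketch of such a pattern is shown below. The pattern and op names
are placeholders chosen for illustration; they are not part of this change.

```c++
// Hypothetical pattern registered *without* a type converter.
struct MyPatternWithoutConverter : public ConversionPattern {
  MyPatternWithoutConverter(MLIRContext *ctx)
      : ConversionPattern("test.foo", /*benefit=*/1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // `operands[0]` is the most recent replacement value of the first
    // operand. It may have any type, depending on earlier replacements;
    // materializations are excluded from the lookup.
    Value mostRecent = operands.front();
    // If a value with the original operand type is needed, take it from the
    // matched operation directly.
    Value originalTyped = op->getOperand(0);
    // ... build the rewritten IR from `mostRecent` / `originalTyped` ...
    return success();
  }
};
```

Taking the operand from `op` directly works because replacements are
materialized lazily: the matched operation still references its original
operands while the pattern runs.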

Note: When running without a type converter, materializations are intentionally
excluded from the lookup process because their presence may depend on other
patterns. Passing materializations would make the conversion infrastructure
fragile and unpredictable. Moreover, there could be multiple materializations
to different types. (This can be the case when multiple patterns are running
with different type converters.) In such a case, it would be unclear which
materialization to pass.

The above rules ensure that patterns do not have to explicitly enforce type
safety or sanitize the types of the incoming remapped operands. More
information on type conversion is detailed in the
[dedicated section](#type-conversion) below.

## Type Conversion
203 changes: 131 additions & 72 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -121,17 +121,8 @@ struct ConversionValueMapping {
/// false positives.
bool isMappedTo(Value value) const { return mappedTo.contains(value); }

/// Lookup the most recently mapped values with the desired types in the
/// mapping.
///
/// Special cases:
/// - If the desired type range is empty, simply return the most recently
/// mapped values.
/// - If there is no mapping to the desired types, also return the most
/// recently mapped values.
/// - If there is no mapping for the given values at all, return the given
/// value.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
/// Lookup a value in the mapping.
ValueVector lookup(const ValueVector &from) const;

template <typename T>
struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
@@ -185,54 +176,31 @@ struct ConversionValueMapping {
};
} // namespace

ValueVector
ConversionValueMapping::lookupOrDefault(Value from,
TypeRange desiredTypes) const {
// Try to find the deepest values that have the desired types. If there is no
// such mapping, simply return the deepest values.
ValueVector desiredValue;
ValueVector current{from};
do {
// Store the current value if the types match.
if (TypeRange(ValueRange(current)) == desiredTypes)
desiredValue = current;

// If possible, Replace each value with (one or multiple) mapped values.
ValueVector next;
for (Value v : current) {
auto it = mapping.find({v});
if (it != mapping.end()) {
llvm::append_range(next, it->second);
} else {
next.push_back(v);
}
}
if (next != current) {
// If at least one value was replaced, continue the lookup from there.
current = std::move(next);
continue;
}

// Otherwise: Check if there is a mapping for the entire vector. Such
// mappings are materializations. (N:M mapping are not supported for value
// replacements.)
//
// Note: From a correctness point of view, materializations do not have to
// be stored (and looked up) in the mapping. But for performance reasons,
// we choose to reuse existing IR (when possible) instead of creating it
// multiple times.
auto it = mapping.find(current);
if (it == mapping.end()) {
// No mapping found: The lookup stops here.
break;
}
current = it->second;
} while (true);
/// Marker attribute for pure type conversions. I.e., mappings whose only
/// purpose is to resolve a type mismatch. (In contrast, mappings that point to
/// the replacement values of a "replaceOp" call, etc., are not pure type
/// conversions.)
static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";

/// A vector of values is a pure type conversion if all values are defined by
/// the same operation and the operation has the `kPureTypeConversionMarker`
/// attribute.
static bool isPureTypeConversion(const ValueVector &values) {
assert(!values.empty() && "expected non-empty value vector");
Operation *op = values.front().getDefiningOp();
for (Value v : llvm::drop_begin(values))
if (v.getDefiningOp() != op)
return false;
return op && op->hasAttr(kPureTypeConversionMarker);
}
Comment on lines +179 to +195
Member:

I like the definition here, but it seems to me that it is basically a redefinition of target materializations. Do we need to distinguish between target and pure type conversions or can we just rename these and say we track and exclude target materializations?

Member Author:

It's not just target materializations, but also source materializations. But not all source materializations. The source materializations that have no inputs and were created when erasing an op or dropping a block argument are excluded. I was struggling to find a good name for that.

The two kinds of mappings that we have to distinguish between:

  1. User wants to replace value A with value B. That's a "regular" replacement.
  2. Driver notices that it needs to insert a materialization to make the IR pass type checking. That's a pure type conversion.
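
For illustration (not part of the patch, and with placeholder names), a rough
sketch of the two cases from a pattern's point of view:

```c++
// Case 1: a "regular" replacement, issued explicitly by the pattern. Assumes
// the matched op has exactly one result and at least one operand.
struct ForwardOperand : public ConversionPattern {
  ForwardOperand(MLIRContext *ctx)
      : ConversionPattern("test.my_op", /*benefit=*/1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, operands.front());
    return success();
  }
};

// Case 2 has no user-facing call: the driver itself builds an
// unrealized_conversion_cast when the most recently mapped values do not have
// the types that a pattern (with a type converter) expects, and it tags that
// cast with `kPureTypeConversionMarker` so later lookups can skip it.
```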


// If the desired values were found use them, otherwise default to the leaf
// values.
// Note: If `desiredTypes` is empty, this function always returns `current`.
return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
auto it = mapping.find(from);
if (it == mapping.end()) {
// No mapping found: The lookup stops here.
return {};
}
return it->second;
}

//===----------------------------------------------------------------------===//
@@ -930,7 +898,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// recently mapped values.
/// - If there is no mapping for the given values at all, return the given
/// value.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
///
/// If `skipPureTypeConversions` is "true", materializations that are pure
/// type conversions are not considered.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
bool skipPureTypeConversions = false) const;

/// Lookup the given value within the map, or return an empty vector if the
/// value is not mapped. If it is mapped, this follows the same behavior
@@ -993,11 +965,19 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// If `valuesToMap` is set to a non-null Value, then that value is mapped to
/// the results of the unresolved materialization in the conversion value
/// mapping.
///
/// If `isPureTypeConversion` is "true", the materialization is created only
/// to resolve a type mismatch. That means it is not a regular value
/// replacement issued by the user. (Replacement values that are created
/// "out of thin air" appear like unresolved materializations because they are
/// unrealized_conversion_cast ops. However, they must be treated like
/// regular value replacements.)
ValueRange buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
Type originalType, const TypeConverter *converter,
UnrealizedConversionCastOp *castOp = nullptr);
UnrealizedConversionCastOp *castOp = nullptr,
bool isPureTypeConversion = true);

/// Find a replacement value for the given SSA value in the conversion value
/// mapping. The replacement value must have the same type as the given SSA
@@ -1264,10 +1244,77 @@ void ConversionPatternRewriterImpl::applyRewrites() {
// State Management
//===----------------------------------------------------------------------===//

ValueVector
ConversionPatternRewriterImpl::lookupOrDefault(Value from,
TypeRange desiredTypes) const {
return mapping.lookupOrDefault(from, desiredTypes);
ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
// Helper function that looks up each value in `values` individually and then
// composes the results. If that fails, it tries to look up the entire vector
// at once.
auto composedLookup = [&](const ValueVector &values) -> ValueVector {
// If possible, replace each value with (one or multiple) mapped values.
ValueVector next;
for (Value v : values) {
ValueVector r = mapping.lookup({v});
if (!r.empty()) {
llvm::append_range(next, r);
} else {
next.push_back(v);
}
}
if (next != values) {
// At least one value was replaced.
return next;
}

// Otherwise: Check if there is a mapping for the entire vector. Such
// mappings are materializations. (N:M mapping are not supported for value
// replacements.)
//
// Note: From a correctness point of view, materializations do not have to
// be stored (and looked up) in the mapping. But for performance reasons,
// we choose to reuse existing IR (when possible) instead of creating it
// multiple times.
ValueVector r = mapping.lookup(values);
if (r.empty()) {
// No mapping found: The lookup stops here.
return {};
}
return r;
};

// Try to find the deepest values that have the desired types. If there is no
// such mapping, simply return the deepest values.
ValueVector desiredValue;
ValueVector current{from};
ValueVector lastNonMaterialization{from};
do {
// Store the current value if the types match.
bool match = TypeRange(ValueRange(current)) == desiredTypes;
if (skipPureTypeConversions) {
// Skip pure type conversions, if requested.
bool pureConversion = isPureTypeConversion(current);
match &= !pureConversion;
// Keep track of the last mapped value that was not a pure type
// conversion.
if (!pureConversion)
lastNonMaterialization = current;
}
if (match)
desiredValue = current;

// Lookup next value in the mapping.
ValueVector next = composedLookup(current);
if (next.empty())
break;
current = std::move(next);
} while (true);

// If the desired values were found use them, otherwise default to the leaf
// values. (Skip pure type conversions, if requested.)
if (!desiredTypes.empty())
return desiredValue;
if (skipPureTypeConversions)
return lastNonMaterialization;
return current;
}

ValueVector
@@ -1324,10 +1371,13 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();

if (!currentTypeConverter) {
// The current pattern does not have a type converter. I.e., it does not
// distinguish between legal and illegal types. For each operand, simply
// pass through the most recently mapped values.
remapped.push_back(lookupOrDefault(operand));
// The current pattern does not have a type converter. Pass the most
// recently mapped values, excluding materializations. Materializations
// are intentionally excluded because their presence may depend on other
// patterns. Including materializations would make the lookup fragile
// and unpredictable.
remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
/*skipPureTypeConversions=*/true));
continue;
}

@@ -1356,7 +1406,8 @@ }
}

// Create a materialization for the most recently mapped values.
repl = lookupOrDefault(operand);
repl = lookupOrDefault(operand, /*desiredTypes=*/{},
/*skipPureTypeConversions=*/true);
ValueRange castValues = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
/*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
@@ -1482,7 +1533,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
OpBuilder::InsertPoint(newBlock, newBlock->begin()),
origArg.getLoc(),
/*valuesToMap=*/{}, /*inputs=*/ValueRange(),
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
/*castOp=*/nullptr, /*isPureTypeConversion=*/false)
.front();
replaceUsesOfBlockArgument(origArg, mat, converter);
continue;
@@ -1523,7 +1575,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
Type originalType, const TypeConverter *converter,
UnrealizedConversionCastOp *castOp) {
UnrealizedConversionCastOp *castOp, bool isPureTypeConversion) {
assert((!originalType || kind == MaterializationKind::Target) &&
"original type is valid only for target materializations");
assert(TypeRange(inputs) != outputTypes &&
@@ -1535,6 +1587,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
if (isPureTypeConversion)
convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
if (!valuesToMap.empty())
mapping.map(valuesToMap, convertOp.getResults());
if (castOp)
@@ -1650,7 +1704,8 @@ void ConversionPatternRewriterImpl::replaceOp(
MaterializationKind::Source, computeInsertPoint(result),
result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
/*outputTypes=*/result.getType(), /*originalType=*/Type(),
currentTypeConverter);
currentTypeConverter, /*castOp=*/nullptr,
/*isPureTypeConversion=*/false);
continue;
}

@@ -2902,6 +2957,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
SmallVector<UnrealizedConversionCastOp> remainingCastOps;
reconcileUnrealizedCasts(allCastOps, &remainingCastOps);

// Drop markers.
for (UnrealizedConversionCastOp castOp : remainingCastOps)
castOp->removeAttr(kPureTypeConversionMarker);

// Try to legalize all unresolved materializations.
if (config.buildMaterializations) {
IRRewriter rewriter(rewriterImpl.context, config.listener);
17 changes: 17 additions & 0 deletions mlir/test/Transforms/test-legalizer.mlir
@@ -415,3 +415,20 @@ func.func @test_multiple_1_to_n_replacement() {
%0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
"test.invalid"(%0) : (f16) -> ()
}

// -----

// CHECK-LABEL: func @test_lookup_without_converter
// CHECK: %[[producer:.*]] = "test.valid_producer"() : () -> i16
// CHECK: %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64
// CHECK: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
// CHECK: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
func.func @test_lookup_without_converter() {
%0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64)
"test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> ()
// Make sure that the second "replace_with_valid_consumer" lowering does not
// look up the materialization that was created for the above op.
"test.replace_with_valid_consumer"(%0) : (i64) -> ()
// expected-remark@+1 {{op 'func.return' is not legalizable}}
return
}
4 changes: 4 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
@@ -2104,6 +2104,10 @@ def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
Arguments<(ins Variadic<AnyType>)>;
def TestTypeProducerOp : TEST_Op<"type_producer">,
Results<(outs AnyType)>;
def TestValidProducerOp : TEST_Op<"valid_producer">,
Results<(outs AnyType)>;
def TestValidConsumerOp : TEST_Op<"valid_consumer">,
Arguments<(ins AnyType)>;
def TestAnotherTypeProducerOp : TEST_Op<"another_type_producer">,
Results<(outs AnyType)>;
def TestTypeConsumerOp : TEST_Op<"type_consumer">,