Commit 71832a3

[mlir][Transforms] Make lookup without type converter unambiguous (#151747)
When a conversion pattern is initialized without a type converter, the driver implementation currently looks up the most recently mapped value. This is undesirable because the most recently mapped value could be a materialization. I.e., the type of the value being looked up could depend on which other patterns have run before. Such an implementation makes the type conversion infrastructure fragile and unpredictable.

The current implementation also contradicts the documentation in the markdown file. According to that documentation, the values provided by the adaptor should match the types of the operands of the matched operation when running without a type converter. This mechanism is not desirable, either, for two reasons:

1. Some patterns have started to rely on receiving the most recently mapped value. Changing the behavior to the documented behavior will cause regressions. (And there would be no easy way to fix those without forcing the use of a type converter or extending the `getRemappedValue` API.)
2. It is more useful to receive the most recently mapped value. A value of the original operand type can be retrieved by using the operand of the matched operation. The adaptor is not needed at all in that case.

To implement the new behavior, materializations are now annotated with a marker attribute. The marker is needed because not all `unrealized_conversion_cast` ops are materializations that act as "pure type conversions". E.g., when erasing an operation, its results are mapped to newly-created "out-of-thin-air values", which are materializations (with no input) that should be treated like regular replacement values during a lookup.

This marker-based lookup strategy is also compatible with the One-Shot Dialect Conversion implementation strategy, which does not utilize the mapping infrastructure anymore and queries all necessary information by examining the IR.
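The lookup rule described in the commit message can be illustrated with a small toy model. This is only a sketch and not the MLIR implementation: `ToyValue`, `ToyMapping`, and the free function `lookupOrDefault` below are hypothetical stand-ins that use plain strings instead of SSA values, assume a 1:1 and acyclic mapping, and ignore the `desiredTypes` parameter of the real driver.

```cpp
#include <cassert>
#include <map>
#include <string>

// Toy stand-in for an SSA value (hypothetical, not an MLIR type).
struct ToyValue {
  std::string name;        // e.g. "%producer"
  std::string type;        // e.g. "i16"
  bool pureTypeConversion; // true if defined by a marked materialization
};

// Toy 1:1 mapping from a value name to its replacement. Assumed acyclic.
using ToyMapping = std::map<std::string, ToyValue>;

// Walk the mapping chain starting at `from`. Without skipping, the deepest
// mapped value is returned. With skipping, the deepest value that is not a
// pure type conversion is returned instead.
ToyValue lookupOrDefault(const ToyMapping &mapping, const ToyValue &from,
                         bool skipPureTypeConversions) {
  assert(!from.name.empty() && "expected a named value");
  ToyValue current = from;
  ToyValue lastNonMaterialization = from;
  while (true) {
    // Remember the last value that is a regular replacement.
    if (!current.pureTypeConversion)
      lastNonMaterialization = current;
    auto it = mapping.find(current.name);
    if (it == mapping.end())
      break; // No further mapping: the lookup stops here.
    current = it->second;
  }
  return skipPureTypeConversions ? lastNonMaterialization : current;
}
```

With a chain `%0 -> %producer -> %cast` in which only `%cast` is marked as a pure type conversion, a lookup that skips pure type conversions yields `%producer`, whereas a lookup that does not skip them yields `%cast`. The first result no longer depends on whether some other pattern happened to create the cast.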
1 parent 0a72e6d commit 71832a3

5 files changed: +255 -85 lines changed


mlir/docs/DialectConversion.md

Lines changed: 56 additions & 11 deletions
@@ -202,17 +202,62 @@ struct MyConversionPattern : public ConversionPattern {

 #### Type Safety

-The types of the remapped operands provided to a conversion pattern must be of a
-type expected by the pattern. The expected types of a pattern are determined by
-a provided [TypeConverter](#type-converter). If no type converter is provided,
-the types of the remapped operands are expected to match the types of the
-original operands. If a type converter is provided, the types of the remapped
-operands are expected to be legal as determined by the converter. If the
-remapped operand types are not of an expected type, and a materialization to the
-expected type could not be performed, the pattern fails application before the
-`matchAndRewrite` hook is invoked. This ensures that patterns do not have to
-explicitly ensure type safety, or sanitize the types of the incoming remapped
-operands. More information on type conversion is detailed in the
+The types of the remapped operands provided to a conversion pattern (through
+the adaptor or `ArrayRef` of operands) depend on type conversion rules.
+
+If the pattern was initialized with a [type converter](#type-converter), the
+conversion driver passes values whose types match the legalized types of the
+operands of the matched operation as per the type converter. To that end, the
+conversion driver may insert target materializations to convert the most
+recently mapped values to the expected legalized types. The driver tries to
+reuse existing materializations on a best-effort basis, but this is not
+guaranteed by the infrastructure. If the operand types of the matched op could
+not be legalized, the pattern fails to apply before the `matchAndRewrite` hook
+is invoked.
+
+Example:
+```c++
+// Type converter that converts all FloatTypes to IntegerTypes.
+TypeConverter converter;
+converter.addConversion([](FloatType t) {
+  return IntegerType::get(t.getContext(), t.getWidth());
+});
+
+// Assuming that `MyConversionPattern` was initialized with `converter`.
+struct MyConversionPattern : public ConversionPattern {
+  virtual LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands, /* ... */) const {
+//                                               ^^^^^^^^
+// If `op` has a FloatType operand, the respective value in `operands`
+// is guaranteed to have the legalized IntegerType. If another pattern
+// previously replaced the operand SSA value with an SSA value of the
+// legalized type (via "replaceOp" or "applySignatureConversion"), you
+// will get that SSA value directly (unless the replacement value was
+// also replaced). Otherwise, you will get a materialization to the
+// legalized type.
+```
+
+If the pattern was initialized without a type converter, the conversion driver
+passes the most recently mapped values to the pattern, excluding any
+materializations. If a value with the same type as the original operand is
+desired, users can directly take the respective operand from the matched
+operation.
+
+Example: When initializing the pattern from the above example without a type
+converter, `operands` contains the most recent replacement values, regardless
+of their types.
+
+Note: When running without a type converter, materializations are intentionally
+excluded from the lookup process because their presence may depend on other
+patterns. Passing materializations would make the conversion infrastructure
+fragile and unpredictable. Moreover, there could be multiple materializations
+to different types. (This can be the case when multiple patterns are running
+with different type converters.) In such a case, it would be unclear which
+materialization to pass.
+
+The above rules ensure that patterns do not have to explicitly ensure type
+safety, or sanitize the types of the incoming remapped operands. More
+information on type conversion is detailed in the
 [dedicated section](#type-conversion) below.

 ## Type Conversion

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 131 additions & 72 deletions
@@ -121,17 +121,8 @@ struct ConversionValueMapping {
   /// false positives.
   bool isMappedTo(Value value) const { return mappedTo.contains(value); }

-  /// Lookup the most recently mapped values with the desired types in the
-  /// mapping.
-  ///
-  /// Special cases:
-  /// - If the desired type range is empty, simply return the most recently
-  ///   mapped values.
-  /// - If there is no mapping to the desired types, also return the most
-  ///   recently mapped values.
-  /// - If there is no mapping for the given values at all, return the given
-  ///   value.
-  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
+  /// Lookup a value in the mapping.
+  ValueVector lookup(const ValueVector &from) const;

   template <typename T>
   struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
@@ -185,54 +176,31 @@ struct ConversionValueMapping {
 };
 } // namespace

-ValueVector
-ConversionValueMapping::lookupOrDefault(Value from,
-                                        TypeRange desiredTypes) const {
-  // Try to find the deepest values that have the desired types. If there is no
-  // such mapping, simply return the deepest values.
-  ValueVector desiredValue;
-  ValueVector current{from};
-  do {
-    // Store the current value if the types match.
-    if (TypeRange(ValueRange(current)) == desiredTypes)
-      desiredValue = current;
-
-    // If possible, Replace each value with (one or multiple) mapped values.
-    ValueVector next;
-    for (Value v : current) {
-      auto it = mapping.find({v});
-      if (it != mapping.end()) {
-        llvm::append_range(next, it->second);
-      } else {
-        next.push_back(v);
-      }
-    }
-    if (next != current) {
-      // If at least one value was replaced, continue the lookup from there.
-      current = std::move(next);
-      continue;
-    }
-
-    // Otherwise: Check if there is a mapping for the entire vector. Such
-    // mappings are materializations. (N:M mapping are not supported for value
-    // replacements.)
-    //
-    // Note: From a correctness point of view, materializations do not have to
-    // be stored (and looked up) in the mapping. But for performance reasons,
-    // we choose to reuse existing IR (when possible) instead of creating it
-    // multiple times.
-    auto it = mapping.find(current);
-    if (it == mapping.end()) {
-      // No mapping found: The lookup stops here.
-      break;
-    }
-    current = it->second;
-  } while (true);
+/// Marker attribute for pure type conversions. I.e., mappings whose only
+/// purpose is to resolve a type mismatch. (In contrast, mappings that point to
+/// the replacement values of a "replaceOp" call, etc., are not pure type
+/// conversions.)
+static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
+
+/// A vector of values is a pure type conversion if all values are defined by
+/// the same operation and the operation has the `kPureTypeConversionMarker`
+/// attribute.
+static bool isPureTypeConversion(const ValueVector &values) {
+  assert(!values.empty() && "expected non-empty value vector");
+  Operation *op = values.front().getDefiningOp();
+  for (Value v : llvm::drop_begin(values))
+    if (v.getDefiningOp() != op)
+      return false;
+  return op && op->hasAttr(kPureTypeConversionMarker);
+}

-  // If the desired values were found use them, otherwise default to the leaf
-  // values.
-  // Note: If `desiredTypes` is empty, this function always returns `current`.
-  return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
+ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
+  auto it = mapping.find(from);
+  if (it == mapping.end()) {
+    // No mapping found: The lookup stops here.
+    return {};
+  }
+  return it->second;
 }

 //===----------------------------------------------------------------------===//
@@ -930,7 +898,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   ///   recently mapped values.
   /// - If there is no mapping for the given values at all, return the given
   ///   value.
-  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
+  ///
+  /// If `skipPureTypeConversions` is "true", materializations that are pure
+  /// type conversions are not considered.
+  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
+                              bool skipPureTypeConversions = false) const;

   /// Lookup the given value within the map, or return an empty vector if the
   /// value is not mapped. If it is mapped, this follows the same behavior
@@ -993,11 +965,19 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// If `valuesToMap` is set to a non-null Value, then that value is mapped to
   /// the results of the unresolved materialization in the conversion value
   /// mapping.
+  ///
+  /// If `isPureTypeConversion` is "true", the materialization is created only
+  /// to resolve a type mismatch. That means it is not a regular value
+  /// replacement issued by the user. (Replacement values that are created
+  /// "out of thin air" appear like unresolved materializations because they are
+  /// unrealized_conversion_cast ops. However, they must be treated like
+  /// regular value replacements.)
   ValueRange buildUnresolvedMaterialization(
       MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
       ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
       Type originalType, const TypeConverter *converter,
-      UnrealizedConversionCastOp *castOp = nullptr);
+      UnrealizedConversionCastOp *castOp = nullptr,
+      bool isPureTypeConversion = true);

   /// Find a replacement value for the given SSA value in the conversion value
   /// mapping. The replacement value must have the same type as the given SSA
@@ -1264,10 +1244,77 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 // State Management
 //===----------------------------------------------------------------------===//

-ValueVector
-ConversionPatternRewriterImpl::lookupOrDefault(Value from,
-                                               TypeRange desiredTypes) const {
-  return mapping.lookupOrDefault(from, desiredTypes);
+ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
+    Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
+  // Helper function that looks up each value in `values` individually and then
+  // composes the results. If that fails, it tries to look up the entire vector
+  // at once.
+  auto composedLookup = [&](const ValueVector &values) -> ValueVector {
+    // If possible, replace each value with (one or multiple) mapped values.
+    ValueVector next;
+    for (Value v : values) {
+      ValueVector r = mapping.lookup({v});
+      if (!r.empty()) {
+        llvm::append_range(next, r);
+      } else {
+        next.push_back(v);
+      }
+    }
+    if (next != values) {
+      // At least one value was replaced.
+      return next;
+    }
+
+    // Otherwise: Check if there is a mapping for the entire vector. Such
+    // mappings are materializations. (N:M mapping are not supported for value
+    // replacements.)
+    //
+    // Note: From a correctness point of view, materializations do not have to
+    // be stored (and looked up) in the mapping. But for performance reasons,
+    // we choose to reuse existing IR (when possible) instead of creating it
+    // multiple times.
+    ValueVector r = mapping.lookup(values);
+    if (r.empty()) {
+      // No mapping found: The lookup stops here.
+      return {};
+    }
+    return r;
+  };
+
+  // Try to find the deepest values that have the desired types. If there is no
+  // such mapping, simply return the deepest values.
+  ValueVector desiredValue;
+  ValueVector current{from};
+  ValueVector lastNonMaterialization{from};
+  do {
+    // Store the current value if the types match.
+    bool match = TypeRange(ValueRange(current)) == desiredTypes;
+    if (skipPureTypeConversions) {
+      // Skip pure type conversions, if requested.
+      bool pureConversion = isPureTypeConversion(current);
+      match &= !pureConversion;
+      // Keep track of the last mapped value that was not a pure type
+      // conversion.
+      if (!pureConversion)
+        lastNonMaterialization = current;
+    }
+    if (match)
+      desiredValue = current;
+
+    // Lookup next value in the mapping.
+    ValueVector next = composedLookup(current);
+    if (next.empty())
+      break;
+    current = std::move(next);
+  } while (true);
+
+  // If the desired values were found use them, otherwise default to the leaf
+  // values. (Skip pure type conversions, if requested.)
+  if (!desiredTypes.empty())
+    return desiredValue;
+  if (skipPureTypeConversions)
+    return lastNonMaterialization;
+  return current;
 }

 ValueVector
@@ -1324,10 +1371,13 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();

     if (!currentTypeConverter) {
-      // The current pattern does not have a type converter. I.e., it does not
-      // distinguish between legal and illegal types. For each operand, simply
-      // pass through the most recently mapped values.
-      remapped.push_back(lookupOrDefault(operand));
+      // The current pattern does not have a type converter. Pass the most
+      // recently mapped values, excluding materializations. Materializations
+      // are intentionally excluded because their presence may depend on other
+      // patterns. Including materializations would make the lookup fragile
+      // and unpredictable.
+      remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
+                                         /*skipPureTypeConversions=*/true));
       continue;
     }

@@ -1356,7 +1406,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     }

     // Create a materialization for the most recently mapped values.
-    repl = lookupOrDefault(operand);
+    repl = lookupOrDefault(operand, /*desiredTypes=*/{},
+                           /*skipPureTypeConversions=*/true);
     ValueRange castValues = buildUnresolvedMaterialization(
         MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
         /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
@@ -1482,7 +1533,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
               OpBuilder::InsertPoint(newBlock, newBlock->begin()),
               origArg.getLoc(),
               /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
-              /*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
+              /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
+              /*castOp=*/nullptr, /*isPureTypeConversion=*/false)
               .front();
       replaceUsesOfBlockArgument(origArg, mat, converter);
       continue;
@@ -1523,7 +1575,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
     ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
     Type originalType, const TypeConverter *converter,
-    UnrealizedConversionCastOp *castOp) {
+    UnrealizedConversionCastOp *castOp, bool isPureTypeConversion) {
   assert((!originalType || kind == MaterializationKind::Target) &&
          "original type is valid only for target materializations");
   assert(TypeRange(inputs) != outputTypes &&
@@ -1535,6 +1587,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
       UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
+  if (isPureTypeConversion)
+    convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
   if (!valuesToMap.empty())
     mapping.map(valuesToMap, convertOp.getResults());
   if (castOp)
@@ -1650,7 +1704,8 @@ void ConversionPatternRewriterImpl::replaceOp(
           MaterializationKind::Source, computeInsertPoint(result),
           result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
           /*outputTypes=*/result.getType(), /*originalType=*/Type(),
-          currentTypeConverter);
+          currentTypeConverter, /*castOp=*/nullptr,
+          /*isPureTypeConversion=*/false);
       continue;
     }

@@ -2901,6 +2956,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   SmallVector<UnrealizedConversionCastOp> remainingCastOps;
   reconcileUnrealizedCasts(allCastOps, &remainingCastOps);

+  // Drop markers.
+  for (UnrealizedConversionCastOp castOp : remainingCastOps)
+    castOp->removeAttr(kPureTypeConversionMarker);
+
   // Try to legalize all unresolved materializations.
   if (config.buildMaterializations) {
     IRRewriter rewriter(rewriterImpl.context, config.listener);
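
The marker check introduced in this file can be mimicked outside of MLIR to see which value vectors count as pure type conversions. The following is a hedged toy sketch, not the code from the patch: `ToyOp`, `ToyResult`, and `kToyMarker` are hypothetical names, attributes are modeled as a set of strings, and a null defining op stands in for a block argument.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Toy operation: a name plus a set of attribute names (hypothetical model).
struct ToyOp {
  std::string name;
  std::set<std::string> attrs;
};

// Toy op result. A null defining op models a block argument.
struct ToyResult {
  const ToyOp *definingOp;
};

// Marker name copied from the patch; the toy only checks for its presence.
static const char *const kToyMarker = "__pure_type_conversion__";

// A vector of values is a pure type conversion if all values are defined by
// the same op and that op carries the marker attribute.
bool isPureTypeConversion(const std::vector<ToyResult> &values) {
  assert(!values.empty() && "expected non-empty value vector");
  const ToyOp *op = values.front().definingOp;
  for (std::size_t i = 1; i < values.size(); ++i)
    if (values[i].definingOp != op)
      return false;
  return op != nullptr && op->attrs.count(kToyMarker) > 0;
}
```

In this model, an unmarked cast (such as an "out-of-thin-air" replacement value) is not a pure type conversion, so a lookup that skips pure type conversions would still return it, just like any other replacement value.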

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 17 additions & 0 deletions
@@ -415,3 +415,20 @@ func.func @test_multiple_1_to_n_replacement() {
   %0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
   "test.invalid"(%0) : (f16) -> ()
 }
+
+// -----
+
+// CHECK-LABEL: func @test_lookup_without_converter
+// CHECK: %[[producer:.*]] = "test.valid_producer"() : () -> i16
+// CHECK: %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64
+// CHECK: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
+// CHECK: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
+func.func @test_lookup_without_converter() {
+  %0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64)
+  "test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> ()
+  // Make sure that the second "replace_with_valid_consumer" lowering does not
+  // lookup the materialization that was created for the above op.
+  "test.replace_with_valid_consumer"(%0) : (i64) -> ()
+  // expected-remark@+1 {{op 'func.return' is not legalizable}}
+  return
+}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 4 additions & 0 deletions
@@ -2104,6 +2104,10 @@ def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
   Arguments<(ins Variadic<AnyType>)>;
 def TestTypeProducerOp : TEST_Op<"type_producer">,
   Results<(outs AnyType)>;
+def TestValidProducerOp : TEST_Op<"valid_producer">,
+  Results<(outs AnyType)>;
+def TestValidConsumerOp : TEST_Op<"valid_consumer">,
+  Arguments<(ins AnyType)>;
 def TestAnotherTypeProducerOp : TEST_Op<"another_type_producer">,
   Results<(outs AnyType)>;
 def TestTypeConsumerOp : TEST_Op<"type_consumer">,
