Skip to content

Commit 9818ca9

Browse files
make lookup unambiguous
1 parent dace67e commit 9818ca9

File tree

5 files changed

+253
-85
lines changed

5 files changed

+253
-85
lines changed

mlir/docs/DialectConversion.md

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -202,17 +202,60 @@ struct MyConversionPattern : public ConversionPattern {
202202
203203
#### Type Safety
204204
205-
The types of the remapped operands provided to a conversion pattern must be of a
206-
type expected by the pattern. The expected types of a pattern are determined by
207-
a provided [TypeConverter](#type-converter). If no type converter is provided,
208-
the types of the remapped operands are expected to match the types of the
209-
original operands. If a type converter is provided, the types of the remapped
210-
operands are expected to be legal as determined by the converter. If the
211-
remapped operand types are not of an expected type, and a materialization to the
212-
expected type could not be performed, the pattern fails application before the
213-
`matchAndRewrite` hook is invoked. This ensures that patterns do not have to
214-
explicitly ensure type safety, or sanitize the types of the incoming remapped
215-
operands. More information on type conversion is detailed in the
205+
The types of the remapped operands provided to a conversion pattern (through
206+
the adaptor or `ArrayRef` of operands) depend on type conversion rules.
207+
208+
If the pattern was initialized with a [type converter](#type-converter), the
209+
conversion driver passes values whose types match the legalized types of the
210+
operands of the matched operation as per the type converter. To that end, the
211+
conversion driver may insert target materializations to convert the most
212+
recently mapped values to the expected legalized types. The driver tries to
213+
reuse existing materializations on a best-effort basis, but this is not
214+
guaranteed by the infrastructure. If the operand types of the matched op could
215+
not be legalized, the pattern fails to apply before the `matchAndRewrite` hook
216+
is invoked.
217+
218+
Example:
219+
```c++
220+
// Type converter that converts all FloatTypes to IntegerTypes.
221+
TypeConverter converter;
222+
converter.addConversion([](FloatType t) {
223+
return IntegerType::get(t.getContext(), t.getWidth());
224+
});
225+
226+
// Assuming that `MyConversionPattern` was initialized with `converter`.
227+
struct MyConversionPattern : public ConversionPattern {
228+
virtual LogicalResult
229+
matchAndRewrite(Operation *op, ArrayRef<Value> operands, /* ... */) const {
230+
// ^^^^^^^^
231+
// If `op` has a FloatType operand, the respective value in `operands`
232+
// is guaranteed to have the legalized IntegerType. If another pattern
233+
// previously replaced the operand SSA value with an SSA value of the
234+
// legalized type, you will get that SSA value directly (unless it was
235+
// replaced a second time). Otherwise, you will get a materialization
236+
// to the legalized type.
237+
```
238+
239+
If the pattern was initialized without a type converter, the conversion driver
240+
passes the most recently mapped values to the pattern, excluding any
241+
materializations. If a value with the same type as the original operand is
242+
desired, users can directly take the respective operand from the matched
243+
operation.
244+
245+
Example: In the above example, `operands` contains the SSA values from the most
246+
recent replacement value, regardless of its type.
247+
248+
Note: When running without a type converter, materializations are intentionally
249+
excluded from the lookup process because their presence may depend on other
250+
patterns. Passing materializations would make the conversion infrastructure
251+
fragile and unpredictable. Moreover, there could be multiple materializations
252+
to different types. (This can be the case when multiple patterns are running
253+
with different type converters.) In such a case, it would be unclear which
254+
materialization to pass.
255+
256+
The above rules ensure that patterns do not have to explicitly ensure type
257+
safety, or sanitize the types of the incoming remapped operands. More
258+
information on type conversion is detailed in the
216259
[dedicated section](#type-conversion) below.
217260

218261
## Type Conversion

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 131 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -121,17 +121,8 @@ struct ConversionValueMapping {
121121
/// false positives.
122122
bool isMappedTo(Value value) const { return mappedTo.contains(value); }
123123

124-
/// Lookup the most recently mapped values with the desired types in the
125-
/// mapping.
126-
///
127-
/// Special cases:
128-
/// - If the desired type range is empty, simply return the most recently
129-
/// mapped values.
130-
/// - If there is no mapping to the desired types, also return the most
131-
/// recently mapped values.
132-
/// - If there is no mapping for the given values at all, return the given
133-
/// value.
134-
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
124+
/// Lookup a value in the mapping.
125+
ValueVector lookup(const ValueVector &from) const;
135126

136127
template <typename T>
137128
struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
@@ -185,54 +176,31 @@ struct ConversionValueMapping {
185176
};
186177
} // namespace
187178

188-
ValueVector
189-
ConversionValueMapping::lookupOrDefault(Value from,
190-
TypeRange desiredTypes) const {
191-
// Try to find the deepest values that have the desired types. If there is no
192-
// such mapping, simply return the deepest values.
193-
ValueVector desiredValue;
194-
ValueVector current{from};
195-
do {
196-
// Store the current value if the types match.
197-
if (TypeRange(ValueRange(current)) == desiredTypes)
198-
desiredValue = current;
199-
200-
// If possible, Replace each value with (one or multiple) mapped values.
201-
ValueVector next;
202-
for (Value v : current) {
203-
auto it = mapping.find({v});
204-
if (it != mapping.end()) {
205-
llvm::append_range(next, it->second);
206-
} else {
207-
next.push_back(v);
208-
}
209-
}
210-
if (next != current) {
211-
// If at least one value was replaced, continue the lookup from there.
212-
current = std::move(next);
213-
continue;
214-
}
215-
216-
// Otherwise: Check if there is a mapping for the entire vector. Such
217-
// mappings are materializations. (N:M mapping are not supported for value
218-
// replacements.)
219-
//
220-
// Note: From a correctness point of view, materializations do not have to
221-
// be stored (and looked up) in the mapping. But for performance reasons,
222-
// we choose to reuse existing IR (when possible) instead of creating it
223-
// multiple times.
224-
auto it = mapping.find(current);
225-
if (it == mapping.end()) {
226-
// No mapping found: The lookup stops here.
227-
break;
228-
}
229-
current = it->second;
230-
} while (true);
179+
/// Marker attribute for pure type conversions. I.e., mappings whose only
180+
/// purpose is to resolve a type mismatch. (In contrast, mappings that point to
181+
/// the replacement values of a "replaceOp" call, etc., are not pure type
182+
/// conversions.)
183+
static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
184+
185+
/// A vector of values is a pure type conversion if all values are defined by
186+
/// the same operation and the operation has the `kPureTypeConversionMarker`
187+
/// attribute.
188+
static bool isPureTypeConversion(const ValueVector &values) {
189+
assert(!values.empty() && "expected non-empty value vector");
190+
Operation *op = values.front().getDefiningOp();
191+
for (Value v : llvm::drop_begin(values))
192+
if (v.getDefiningOp() != op)
193+
return false;
194+
return op && op->hasAttr(kPureTypeConversionMarker);
195+
}
231196

232-
// If the desired values were found use them, otherwise default to the leaf
233-
// values.
234-
// Note: If `desiredTypes` is empty, this function always returns `current`.
235-
return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
197+
ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
198+
auto it = mapping.find(from);
199+
if (it == mapping.end()) {
200+
// No mapping found: The lookup stops here.
201+
return {};
202+
}
203+
return it->second;
236204
}
237205

238206
//===----------------------------------------------------------------------===//
@@ -930,7 +898,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
930898
/// recently mapped values.
931899
/// - If there is no mapping for the given values at all, return the given
932900
/// value.
933-
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
901+
///
902+
/// If `skipPureTypeConversions` is "true", materializations that are pure
903+
/// type conversions are not considered.
904+
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
905+
bool skipPureTypeConversions = false) const;
934906

935907
/// Lookup the given value within the map, or return an empty vector if the
936908
/// value is not mapped. If it is mapped, this follows the same behavior
@@ -993,11 +965,19 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
993965
/// If `valuesToMap` is set to a non-null Value, then that value is mapped to
994966
/// the results of the unresolved materialization in the conversion value
995967
/// mapping.
968+
///
969+
/// If `isPureTypeConversion` is "true", the materialization is created only
970+
/// to resolve a type mismatch. That means it is not a regular value
971+
/// replacement issued by the user. (Replacement values that are created
972+
/// "out of thin air" appear like unresolved materializations because they are
973+
/// unrealized_conversion_cast ops. However, they must be treated like
974+
/// regular value replacements.)
996975
ValueRange buildUnresolvedMaterialization(
997976
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
998977
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
999978
Type originalType, const TypeConverter *converter,
1000-
UnrealizedConversionCastOp *castOp = nullptr);
979+
UnrealizedConversionCastOp *castOp = nullptr,
980+
bool isPureTypeConversion = true);
1001981

1002982
/// Find a replacement value for the given SSA value in the conversion value
1003983
/// mapping. The replacement value must have the same type as the given SSA
@@ -1264,10 +1244,77 @@ void ConversionPatternRewriterImpl::applyRewrites() {
12641244
// State Management
12651245
//===----------------------------------------------------------------------===//
12661246

1267-
ValueVector
1268-
ConversionPatternRewriterImpl::lookupOrDefault(Value from,
1269-
TypeRange desiredTypes) const {
1270-
return mapping.lookupOrDefault(from, desiredTypes);
1247+
ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
1248+
Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
1249+
// Helper function that looks up each value in `values` individually and then
1250+
// composes the results. If that fails, it tries to look up the entire vector
1251+
// at once.
1252+
auto composedLookup = [&](const ValueVector &values) -> ValueVector {
1253+
// If possible, replace each value with (one or multiple) mapped values.
1254+
ValueVector next;
1255+
for (Value v : values) {
1256+
ValueVector r = mapping.lookup({v});
1257+
if (!r.empty()) {
1258+
llvm::append_range(next, r);
1259+
} else {
1260+
next.push_back(v);
1261+
}
1262+
}
1263+
if (next != values) {
1264+
// At least one value was replaced.
1265+
return next;
1266+
}
1267+
1268+
// Otherwise: Check if there is a mapping for the entire vector. Such
1269+
// mappings are materializations. (N:M mapping are not supported for value
1270+
// replacements.)
1271+
//
1272+
// Note: From a correctness point of view, materializations do not have to
1273+
// be stored (and looked up) in the mapping. But for performance reasons,
1274+
// we choose to reuse existing IR (when possible) instead of creating it
1275+
// multiple times.
1276+
ValueVector r = mapping.lookup(values);
1277+
if (r.empty()) {
1278+
// No mapping found: The lookup stops here.
1279+
return {};
1280+
}
1281+
return r;
1282+
};
1283+
1284+
// Try to find the deepest values that have the desired types. If there is no
1285+
// such mapping, simply return the deepest values.
1286+
ValueVector desiredValue;
1287+
ValueVector current{from};
1288+
ValueVector lastNonMaterialization{from};
1289+
do {
1290+
// Store the current value if the types match.
1291+
bool match = TypeRange(ValueRange(current)) == desiredTypes;
1292+
if (skipPureTypeConversions) {
1293+
// Skip pure type conversions, if requested.
1294+
bool pureConversion = isPureTypeConversion(current);
1295+
match &= !pureConversion;
1296+
// Keep track of the last mapped value that was not a pure type
1297+
// conversion.
1298+
if (!pureConversion)
1299+
lastNonMaterialization = current;
1300+
}
1301+
if (match)
1302+
desiredValue = current;
1303+
1304+
// Lookup next value in the mapping.
1305+
ValueVector next = composedLookup(current);
1306+
if (next.empty())
1307+
break;
1308+
current = std::move(next);
1309+
} while (true);
1310+
1311+
// If the desired values were found use them, otherwise default to the leaf
1312+
// values. (Skip pure type conversions, if requested.)
1313+
if (!desiredTypes.empty())
1314+
return desiredValue;
1315+
if (skipPureTypeConversions)
1316+
return lastNonMaterialization;
1317+
return current;
12711318
}
12721319

12731320
ValueVector
@@ -1324,10 +1371,13 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
13241371
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
13251372

13261373
if (!currentTypeConverter) {
1327-
// The current pattern does not have a type converter. I.e., it does not
1328-
// distinguish between legal and illegal types. For each operand, simply
1329-
// pass through the most recently mapped values.
1330-
remapped.push_back(lookupOrDefault(operand));
1374+
// The current pattern does not have a type converter. Pass the most
1375+
// recently mapped values, excluding materializations. Materializations
1376+
// are intentionally excluded because their presence may depend on other
1377+
// patterns. Including materializations would make the lookup fragile
1378+
// and unpredictable.
1379+
remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
1380+
/*skipPureTypeConversions=*/true));
13311381
continue;
13321382
}
13331383

@@ -1356,7 +1406,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
13561406
}
13571407

13581408
// Create a materialization for the most recently mapped values.
1359-
repl = lookupOrDefault(operand);
1409+
repl = lookupOrDefault(operand, /*desiredTypes=*/{},
1410+
/*skipPureTypeConversions=*/true);
13601411
ValueRange castValues = buildUnresolvedMaterialization(
13611412
MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
13621413
/*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
@@ -1482,7 +1533,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
14821533
OpBuilder::InsertPoint(newBlock, newBlock->begin()),
14831534
origArg.getLoc(),
14841535
/*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1485-
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
1536+
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
1537+
/*castOp=*/nullptr, /*isPureTypeConversion=*/false)
14861538
.front();
14871539
replaceUsesOfBlockArgument(origArg, mat, converter);
14881540
continue;
@@ -1523,7 +1575,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
15231575
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
15241576
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
15251577
Type originalType, const TypeConverter *converter,
1526-
UnrealizedConversionCastOp *castOp) {
1578+
UnrealizedConversionCastOp *castOp, bool isPureTypeConversion) {
15271579
assert((!originalType || kind == MaterializationKind::Target) &&
15281580
"original type is valid only for target materializations");
15291581
assert(TypeRange(inputs) != outputTypes &&
@@ -1535,6 +1587,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
15351587
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
15361588
auto convertOp =
15371589
UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1590+
if (isPureTypeConversion)
1591+
convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
15381592
if (!valuesToMap.empty())
15391593
mapping.map(valuesToMap, convertOp.getResults());
15401594
if (castOp)
@@ -1650,7 +1704,8 @@ void ConversionPatternRewriterImpl::replaceOp(
16501704
MaterializationKind::Source, computeInsertPoint(result),
16511705
result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
16521706
/*outputTypes=*/result.getType(), /*originalType=*/Type(),
1653-
currentTypeConverter);
1707+
currentTypeConverter, /*castOp=*/nullptr,
1708+
/*isPureTypeConversion=*/false);
16541709
continue;
16551710
}
16561711

@@ -2902,6 +2957,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
29022957
SmallVector<UnrealizedConversionCastOp> remainingCastOps;
29032958
reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
29042959

2960+
// Drop markers.
2961+
for (UnrealizedConversionCastOp castOp : remainingCastOps)
2962+
castOp->removeAttr(kPureTypeConversionMarker);
2963+
29052964
// Try to legalize all unresolved materializations.
29062965
if (config.buildMaterializations) {
29072966
IRRewriter rewriter(rewriterImpl.context, config.listener);

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,3 +415,20 @@ func.func @test_multiple_1_to_n_replacement() {
415415
%0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
416416
"test.invalid"(%0) : (f16) -> ()
417417
}
418+
419+
// -----
420+
421+
// CHECK-LABEL: func @test_lookup_without_converter
422+
// CHECK: %[[producer:.*]] = "test.valid_producer"() : () -> i16
423+
// CHECK: %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64
424+
// CHECK: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
425+
// CHECK: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
426+
func.func @test_lookup_without_converter() {
427+
%0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64)
428+
"test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> ()
429+
// Make sure that the second "replace_with_valid_consumer" lowering does not
430+
// lookup the materialization that was created for the above op.
431+
"test.replace_with_valid_consumer"(%0) : (i64) -> ()
432+
// expected-remark@+1 {{op 'func.return' is not legalizable}}
433+
return
434+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2104,6 +2104,10 @@ def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
21042104
Arguments<(ins Variadic<AnyType>)>;
21052105
def TestTypeProducerOp : TEST_Op<"type_producer">,
21062106
Results<(outs AnyType)>;
2107+
def TestValidProducerOp : TEST_Op<"valid_producer">,
2108+
Results<(outs AnyType)>;
2109+
def TestValidConsumerOp : TEST_Op<"valid_consumer">,
2110+
Arguments<(ins AnyType)>;
21072111
def TestAnotherTypeProducerOp : TEST_Op<"another_type_producer">,
21082112
Results<(outs AnyType)>;
21092113
def TestTypeConsumerOp : TEST_Op<"type_consumer">,

0 commit comments

Comments
 (0)