diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md index cf577eca5b9a6..556e73c2d56c7 100644 --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -202,17 +202,62 @@ struct MyConversionPattern : public ConversionPattern { #### Type Safety -The types of the remapped operands provided to a conversion pattern must be of a -type expected by the pattern. The expected types of a pattern are determined by -a provided [TypeConverter](#type-converter). If no type converter is provided, -the types of the remapped operands are expected to match the types of the -original operands. If a type converter is provided, the types of the remapped -operands are expected to be legal as determined by the converter. If the -remapped operand types are not of an expected type, and a materialization to the -expected type could not be performed, the pattern fails application before the -`matchAndRewrite` hook is invoked. This ensures that patterns do not have to -explicitly ensure type safety, or sanitize the types of the incoming remapped -operands. More information on type conversion is detailed in the +The types of the remapped operands provided to a conversion pattern (through +the adaptor or `ArrayRef` of operands) depend on type conversion rules. + +If the pattern was initialized with a [type converter](#type-converter), the +conversion driver passes values whose types match the legalized types of the +operands of the matched operation as per the type converter. To that end, the +conversion driver may insert target materializations to convert the most +recently mapped values to the expected legalized types. The driver tries to +reuse existing materializations on a best-effort basis, but this is not +guaranteed by the infrastructure. If the operand types of the matched op could +not be legalized, the pattern fails to apply before the `matchAndRewrite` hook +is invoked. + +Example: +```c++ +// Type converter that converts all FloatTypes to IntegerTypes. +TypeConverter converter; +converter.addConversion([](FloatType t) { + return IntegerType::get(t.getContext(), t.getWidth()); +}); + +// Assuming that `MyConversionPattern` was initialized with `converter`. +struct MyConversionPattern : public ConversionPattern { + virtual LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, /* ... */) const { +// ^^^^^^^^ +// If `op` has a FloatType operand, the respective value in `operands` +// is guaranteed to have the legalized IntegerType. If another pattern +// previously replaced the operand SSA value with an SSA value of the +// legalized type (via "replaceOp" or "applySignatureConversion"), you +// will get that SSA value directly (unless the replacement value was +// also replaced). Otherwise, you will get a materialization to the +// legalized type. +``` + +If the pattern was initialized without a type converter, the conversion driver +passes the most recently mapped values to the pattern, excluding any +materializations. If a value with the same type as the original operand is +desired, users can directly take the respective operand from the matched +operation. + +Example: When initializing the pattern from the above example without a type +converter, `operands` contains the most recent replacement values, regardless +of their types. + +Note: When running without a type converter, materializations are intentionally +excluded from the lookup process because their presence may depend on other +patterns. Passing materializations would make the conversion infrastructure +fragile and unpredictable. Moreover, there could be multiple materializations +to different types. (This can be the case when multiple patterns are running +with different type converters.) In such a case, it would be unclear which +materialization to pass. + +The above rules ensure that patterns do not have to explicitly ensure type +safety, or sanitize the types of the incoming remapped operands. More +information on type conversion is detailed in the [dedicated section](#type-conversion) below. ## Type Conversion diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index f23c6197accd5..dedc84f1adde9 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -121,17 +121,8 @@ struct ConversionValueMapping { /// false positives. bool isMappedTo(Value value) const { return mappedTo.contains(value); } - /// Lookup the most recently mapped values with the desired types in the - /// mapping. - /// - /// Special cases: - /// - If the desired type range is empty, simply return the most recently - /// mapped values. - /// - If there is no mapping to the desired types, also return the most - /// recently mapped values. - /// - If there is no mapping for the given values at all, return the given - /// value. - ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; + /// Lookup a value in the mapping. + ValueVector lookup(const ValueVector &from) const; template struct IsValueVector : std::is_same, ValueVector> {}; @@ -185,54 +176,31 @@ struct ConversionValueMapping { }; } // namespace -ValueVector -ConversionValueMapping::lookupOrDefault(Value from, - TypeRange desiredTypes) const { - // Try to find the deepest values that have the desired types. If there is no - // such mapping, simply return the deepest values. - ValueVector desiredValue; - ValueVector current{from}; - do { - // Store the current value if the types match. - if (TypeRange(ValueRange(current)) == desiredTypes) - desiredValue = current; - - // If possible, Replace each value with (one or multiple) mapped values. - ValueVector next; - for (Value v : current) { - auto it = mapping.find({v}); - if (it != mapping.end()) { - llvm::append_range(next, it->second); - } else { - next.push_back(v); - } - } - if (next != current) { - // If at least one value was replaced, continue the lookup from there. - current = std::move(next); - continue; - } - - // Otherwise: Check if there is a mapping for the entire vector. Such - // mappings are materializations. (N:M mapping are not supported for value - // replacements.) - // - // Note: From a correctness point of view, materializations do not have to - // be stored (and looked up) in the mapping. But for performance reasons, - // we choose to reuse existing IR (when possible) instead of creating it - // multiple times. - auto it = mapping.find(current); - if (it == mapping.end()) { - // No mapping found: The lookup stops here. - break; - } - current = it->second; - } while (true); +/// Marker attribute for pure type conversions. I.e., mappings whose only +/// purpose is to resolve a type mismatch. (In contrast, mappings that point to +/// the replacement values of a "replaceOp" call, etc., are not pure type +/// conversions.) +static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__"; + +/// A vector of values is a pure type conversion if all values are defined by +/// the same operation and the operation has the `kPureTypeConversionMarker` +/// attribute. +static bool isPureTypeConversion(const ValueVector &values) { + assert(!values.empty() && "expected non-empty value vector"); + Operation *op = values.front().getDefiningOp(); + for (Value v : llvm::drop_begin(values)) + if (v.getDefiningOp() != op) + return false; + return op && op->hasAttr(kPureTypeConversionMarker); +} - // If the desired values were found use them, otherwise default to the leaf - // values. - // Note: If `desiredTypes` is empty, this function always returns `current`. - return !desiredValue.empty() ? std::move(desiredValue) : std::move(current); +ValueVector ConversionValueMapping::lookup(const ValueVector &from) const { + auto it = mapping.find(from); + if (it == mapping.end()) { + // No mapping found: The lookup stops here. + return {}; + } + return it->second; } //===----------------------------------------------------------------------===// @@ -930,7 +898,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// recently mapped values. /// - If there is no mapping for the given values at all, return the given /// value. - ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; + /// + /// If `skipPureTypeConversions` is "true", materializations that are pure + /// type conversions are not considered. + ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}, + bool skipPureTypeConversions = false) const; /// Lookup the given value within the map, or return an empty vector if the /// value is not mapped. If it is mapped, this follows the same behavior @@ -993,11 +965,19 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// If `valuesToMap` is set to a non-null Value, then that value is mapped to /// the results of the unresolved materialization in the conversion value /// mapping. + /// + /// If `isPureTypeConversion` is "true", the materialization is created only + /// to resolve a type mismatch. That means it is not a regular value + /// replacement issued by the user. (Replacement values that are created + /// "out of thin air" appear like unresolved materializations because they are + /// unrealized_conversion_cast ops. However, they must be treated like + /// regular value replacements.) ValueRange buildUnresolvedMaterialization( MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, - UnrealizedConversionCastOp *castOp = nullptr); + UnrealizedConversionCastOp *castOp = nullptr, + bool isPureTypeConversion = true); /// Find a replacement value for the given SSA value in the conversion value /// mapping. The replacement value must have the same type as the given SSA @@ -1264,10 +1244,77 @@ void ConversionPatternRewriterImpl::applyRewrites() { // State Management //===----------------------------------------------------------------------===// -ValueVector -ConversionPatternRewriterImpl::lookupOrDefault(Value from, - TypeRange desiredTypes) const { - return mapping.lookupOrDefault(from, desiredTypes); +ValueVector ConversionPatternRewriterImpl::lookupOrDefault( + Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const { + // Helper function that looks up each value in `values` individually and then + // composes the results. If that fails, it tries to look up the entire vector + // at once. + auto composedLookup = [&](const ValueVector &values) -> ValueVector { + // If possible, replace each value with (one or multiple) mapped values. + ValueVector next; + for (Value v : values) { + ValueVector r = mapping.lookup({v}); + if (!r.empty()) { + llvm::append_range(next, r); + } else { + next.push_back(v); + } + } + if (next != values) { + // At least one value was replaced. + return next; + } + + // Otherwise: Check if there is a mapping for the entire vector. Such + // mappings are materializations. (N:M mapping are not supported for value + // replacements.) + // + // Note: From a correctness point of view, materializations do not have to + // be stored (and looked up) in the mapping. But for performance reasons, + // we choose to reuse existing IR (when possible) instead of creating it + // multiple times. + ValueVector r = mapping.lookup(values); + if (r.empty()) { + // No mapping found: The lookup stops here. + return {}; + } + return r; + }; + + // Try to find the deepest values that have the desired types. If there is no + // such mapping, simply return the deepest values. + ValueVector desiredValue; + ValueVector current{from}; + ValueVector lastNonMaterialization{from}; + do { + // Store the current value if the types match. + bool match = TypeRange(ValueRange(current)) == desiredTypes; + if (skipPureTypeConversions) { + // Skip pure type conversions, if requested. + bool pureConversion = isPureTypeConversion(current); + match &= !pureConversion; + // Keep track of the last mapped value that was not a pure type + // conversion. + if (!pureConversion) + lastNonMaterialization = current; + } + if (match) + desiredValue = current; + + // Lookup next value in the mapping. + ValueVector next = composedLookup(current); + if (next.empty()) + break; + current = std::move(next); + } while (true); + + // If the desired values were found use them, otherwise default to the leaf + // values. (Skip pure type conversions, if requested.) + if (!desiredTypes.empty()) + return desiredValue; + if (skipPureTypeConversions) + return lastNonMaterialization; + return current; } ValueVector @@ -1324,10 +1371,13 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( Location operandLoc = inputLoc ? *inputLoc : operand.getLoc(); if (!currentTypeConverter) { - // The current pattern does not have a type converter. I.e., it does not - // distinguish between legal and illegal types. For each operand, simply - // pass through the most recently mapped values. - remapped.push_back(lookupOrDefault(operand)); + // The current pattern does not have a type converter. Pass the most + // recently mapped values, excluding materializations. Materializations + // are intentionally excluded because their presence may depend on other + // patterns. Including materializations would make the lookup fragile + // and unpredictable. + remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{}, + /*skipPureTypeConversions=*/true)); continue; } @@ -1356,7 +1406,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( } // Create a materialization for the most recently mapped values. - repl = lookupOrDefault(operand); + repl = lookupOrDefault(operand, /*desiredTypes=*/{}, + /*skipPureTypeConversions=*/true); ValueRange castValues = buildUnresolvedMaterialization( MaterializationKind::Target, computeInsertPoint(repl), operandLoc, /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes, @@ -1482,7 +1533,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), /*valuesToMap=*/{}, /*inputs=*/ValueRange(), - /*outputTypes=*/origArgType, /*originalType=*/Type(), converter) + /*outputTypes=*/origArgType, /*originalType=*/Type(), converter, + /*castOp=*/nullptr, /*isPureTypeConversion=*/false) .front(); replaceUsesOfBlockArgument(origArg, mat, converter); continue; @@ -1523,7 +1575,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, - UnrealizedConversionCastOp *castOp) { + UnrealizedConversionCastOp *castOp, bool isPureTypeConversion) { assert((!originalType || kind == MaterializationKind::Target) && "original type is valid only for target materializations"); assert(TypeRange(inputs) != outputTypes && @@ -1535,6 +1587,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); auto convertOp = UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs); + if (isPureTypeConversion) + convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr()); if (!valuesToMap.empty()) mapping.map(valuesToMap, convertOp.getResults()); if (castOp) @@ -1650,7 +1704,8 @@ void ConversionPatternRewriterImpl::replaceOp( MaterializationKind::Source, computeInsertPoint(result), result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(), /*outputTypes=*/result.getType(), /*originalType=*/Type(), - currentTypeConverter); + currentTypeConverter, /*castOp=*/nullptr, + /*isPureTypeConversion=*/false); continue; } @@ -2902,6 +2957,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { SmallVector remainingCastOps; reconcileUnrealizedCasts(allCastOps, &remainingCastOps); + // Drop markers. + for (UnrealizedConversionCastOp castOp : remainingCastOps) + castOp->removeAttr(kPureTypeConversionMarker); + // Try to legalize all unresolved materializations. if (config.buildMaterializations) { IRRewriter rewriter(rewriterImpl.context, config.listener); diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index e4406e60ffead..5630d1540e4d5 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -415,3 +415,20 @@ func.func @test_multiple_1_to_n_replacement() { %0 = "test.multiple_1_to_n_replacement"() : () -> (f16) "test.invalid"(%0) : (f16) -> () } + +// ----- + +// CHECK-LABEL: func @test_lookup_without_converter +// CHECK: %[[producer:.*]] = "test.valid_producer"() : () -> i16 +// CHECK: %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64 +// CHECK: "test.valid_consumer"(%[[cast]]) : (f64) -> () +// CHECK: "test.valid_consumer"(%[[producer]]) : (i16) -> () +func.func @test_lookup_without_converter() { + %0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64) + "test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> () + // Make sure that the second "replace_with_valid_consumer" lowering does not + // lookup the materialization that was created for the above op. + "test.replace_with_valid_consumer"(%0) : (i64) -> () + // expected-remark@+1 {{op 'func.return' is not legalizable}} + return +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 2eaad552a7a3a..843bd30a51aff 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2104,6 +2104,10 @@ def TestInvalidOp : TEST_Op<"invalid", [Terminator]>, Arguments<(ins Variadic)>; def TestTypeProducerOp : TEST_Op<"type_producer">, Results<(outs AnyType)>; +def TestValidProducerOp : TEST_Op<"valid_producer">, + Results<(outs AnyType)>; +def TestValidConsumerOp : TEST_Op<"valid_consumer">, + Arguments<(ins AnyType)>; def TestAnotherTypeProducerOp : TEST_Op<"another_type_producer">, Results<(outs AnyType)>; def TestTypeConsumerOp : TEST_Op<"type_consumer">, diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index eda618f5b09c6..7150401bdbdce 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1198,6 +1198,47 @@ class TestEraseOp : public ConversionPattern { } }; +/// Pattern that replaces test.replace_with_valid_producer with +/// test.valid_producer and the specified type. +class TestReplaceWithValidProducer : public ConversionPattern { +public: + TestReplaceWithValidProducer(MLIRContext *ctx) + : ConversionPattern("test.replace_with_valid_producer", 1, ctx) {} + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto attr = op->getAttrOfType("type"); + if (!attr) + return failure(); + rewriter.replaceOpWithNewOp(op, attr.getValue()); + return success(); + } +}; + +/// Pattern that replaces test.replace_with_valid_consumer with +/// test.valid_consumer. Can be used with and without a type converter. +class TestReplaceWithValidConsumer : public ConversionPattern { +public: + TestReplaceWithValidConsumer(MLIRContext *ctx, const TypeConverter &converter) + : ConversionPattern(converter, "test.replace_with_valid_consumer", 1, + ctx) {} + TestReplaceWithValidConsumer(MLIRContext *ctx) + : ConversionPattern("test.replace_with_valid_consumer", 1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // with_converter present: pattern must have been initialized with a type + // converter. + // with_converter absent: pattern must have been initialized without a type + // converter. + if (op->hasAttr("with_converter") != static_cast(getTypeConverter())) + return failure(); + rewriter.replaceOpWithNewOp(op, operands[0]); + return success(); + } +}; + /// This pattern matches a test.convert_block_args op. It either: /// a) Duplicates all block arguments, /// b) or: drops all block arguments and replaces each with 2x the first @@ -1314,6 +1355,7 @@ struct TestTypeConverter : public TypeConverter { TestTypeConverter() { addConversion(convertType); addSourceMaterialization(materializeCast); + addTargetMaterialization(materializeCast); } static LogicalResult convertType(Type t, SmallVectorImpl &results) { @@ -1389,10 +1431,12 @@ struct TestLegalizePatternDriver TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore, TestUndoPropertiesModification, TestEraseOp, + TestReplaceWithValidProducer, TestReplaceWithValidConsumer, TestRepetitive1ToNConsumer>(&getContext()); patterns.add(&getContext(), converter); + TestBlockArgReplace, TestReplaceWithValidConsumer>( + &getContext(), converter); patterns.add(converter, &getContext()); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter); @@ -1402,7 +1446,8 @@ struct TestLegalizePatternDriver ConversionTarget target(getContext()); target.addLegalOp(); target.addLegalOp(); + TerminatorOp, OneRegionOp, TestValidProducerOp, + TestValidConsumerOp>(); target.addLegalOp(OperationName("test.legal_op", &getContext())); target .addIllegalOp();