diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 65e279e046e88..45ad6f8586daa 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -138,7 +138,8 @@ class TypeConverter { }; /// Register a conversion function. A conversion function must be convertible - /// to any of the following forms(where `T` is a class derived from `Type`: + /// to any of the following forms (where `T` is a class derived from `Type`): + /// /// * std::optional(T) /// - This form represents a 1-1 type conversion. It should return nullptr /// or `std::nullopt` to signify failure. If `std::nullopt` is returned, @@ -151,15 +152,7 @@ class TypeConverter { /// existing value are expected to be removed during conversion. If /// `std::nullopt` is returned, the converter is allowed to try another /// conversion function to perform the conversion. - /// * std::optional(T, SmallVectorImpl &, - /// ArrayRef) - /// - This form represents a 1-N type conversion supporting recursive - /// types. The first two arguments and the return value are the same as - /// for the regular 1-N form. The third argument is contains is the - /// "call stack" of the recursive conversion: it contains the list of - /// types currently being converted, with the current type being the - /// last one. If it is present more than once in the list, the - /// conversion concerns a recursive type. + /// /// Note: When attempting to convert a type, e.g. via 'convertType', the /// mostly recently added conversions will be invoked first. template >::template arg_t<1>> void addTargetMaterialization(FnT &&callback) { targetMaterializations.emplace_back( - wrapMaterialization(std::forward(callback))); + wrapTargetMaterialization(std::forward(callback))); } /// Register a conversion function for attributes within types. Type @@ -303,21 +310,12 @@ class TypeConverter { /// `add*Materialization` for more information on the context for these /// methods. Value materializeArgumentConversion(OpBuilder &builder, Location loc, - Type resultType, - ValueRange inputs) const { - return materializeConversion(argumentMaterializations, builder, loc, - resultType, inputs); - } + Type resultType, ValueRange inputs) const; Value materializeSourceConversion(OpBuilder &builder, Location loc, - Type resultType, ValueRange inputs) const { - return materializeConversion(sourceMaterializations, builder, loc, - resultType, inputs); - } + Type resultType, ValueRange inputs) const; Value materializeTargetConversion(OpBuilder &builder, Location loc, - Type resultType, ValueRange inputs) const { - return materializeConversion(targetMaterializations, builder, loc, - resultType, inputs); - } + Type resultType, ValueRange inputs, + Type originalType = {}) const; /// Convert an attribute present `attr` from within the type `type` using /// the registered conversion functions. If no applicable conversion has been @@ -333,21 +331,23 @@ class TypeConverter { using ConversionCallbackFn = std::function( Type, SmallVectorImpl &)>; - /// The signature of the callback used to materialize a conversion. + /// The signature of the callback used to materialize a source/argument + /// conversion. + /// + /// Arguments: builder, result type, inputs, location using MaterializationCallbackFn = std::function( OpBuilder &, Type, ValueRange, Location)>; + /// The signature of the callback used to materialize a target conversion. + /// + /// Arguments: builder, result type, inputs, location, original type + using TargetMaterializationCallbackFn = std::function( + OpBuilder &, Type, ValueRange, Location, Type)>; + /// The signature of the callback used to convert a type attribute. using TypeAttributeConversionCallbackFn = std::function; - /// Attempt to materialize a conversion using one of the provided - /// materialization functions. - Value - materializeConversion(ArrayRef materializations, - OpBuilder &builder, Location loc, Type resultType, - ValueRange inputs) const; - /// Generate a wrapper for the given callback. This allows for accepting /// different callback forms, that all compose into a single version. /// With callback of form: `std::optional(T)` @@ -388,9 +388,10 @@ class TypeConverter { cachedMultiConversions.clear(); } - /// Generate a wrapper for the given materialization callback. The callback - /// may take any subclass of `Type` and the wrapper will check for the target - /// type to be of the expected class before calling the callback. + /// Generate a wrapper for the given argument/source materialization + /// callback. The callback may take any subclass of `Type` and the + /// wrapper will check for the target type to be of the expected class + /// before calling the callback. template MaterializationCallbackFn wrapMaterialization(FnT &&callback) const { return [callback = std::forward(callback)]( @@ -402,6 +403,41 @@ class TypeConverter { }; } + /// Generate a wrapper for the given target materialization callback. + /// The callback may take any subclass of `Type` and the wrapper will check + /// for the target type to be of the expected class before calling the + /// callback. + /// + /// With callback of form: + /// `Value(OpBuilder &, T, ValueRange, Location, Type)` + template + std::enable_if_t< + std::is_invocable_v, + TargetMaterializationCallbackFn> + wrapTargetMaterialization(FnT &&callback) const { + return [callback = std::forward(callback)]( + OpBuilder &builder, Type resultType, ValueRange inputs, + Location loc, Type originalType) -> std::optional { + if (T derivedType = dyn_cast(resultType)) + return callback(builder, derivedType, inputs, loc, originalType); + return std::nullopt; + }; + } + /// With callback of form: + /// `Value(OpBuilder &, T, ValueRange, Location)` + template + std::enable_if_t< + std::is_invocable_v, + TargetMaterializationCallbackFn> + wrapTargetMaterialization(FnT &&callback) const { + return wrapTargetMaterialization( + [callback = std::forward(callback)]( + OpBuilder &builder, T resultType, ValueRange inputs, Location loc, + Type originalType) -> std::optional { + return callback(builder, resultType, inputs, loc); + }); + } + /// Generate a wrapper for the given memory space conversion callback. The /// callback may take any subclass of `Attribute` and the wrapper will check /// for the target attribute to be of the expected class before calling the @@ -434,7 +470,7 @@ class TypeConverter { /// The list of registered materialization functions. SmallVector argumentMaterializations; SmallVector sourceMaterializations; - SmallVector targetMaterializations; + SmallVector targetMaterializations; /// The list of registered type attribute conversion functions. SmallVector typeAttributeConversions; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 97dd3ab1f4829..1baddd881f6aa 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -683,10 +683,10 @@ enum MaterializationKind { /// conversion. class UnresolvedMaterializationRewrite : public OperationRewrite { public: - UnresolvedMaterializationRewrite( - ConversionPatternRewriterImpl &rewriterImpl, - UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr, - MaterializationKind kind = MaterializationKind::Target); + UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl, + UnrealizedConversionCastOp op, + const TypeConverter *converter, + MaterializationKind kind, Type originalType); static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::UnresolvedMaterialization; @@ -708,11 +708,18 @@ class UnresolvedMaterializationRewrite : public OperationRewrite { return converterAndKind.getInt(); } + /// Return the original type of the SSA value. + Type getOriginalType() const { return originalType; } + private: /// The corresponding type converter to use when resolving this /// materialization, and the kind of this materialization. llvm::PointerIntPair converterAndKind; + + /// The original type of the SSA value. Only used for target + /// materializations. + Type originalType; }; } // namespace @@ -808,6 +815,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { Value buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueRange inputs, Type outputType, + Type originalType, const TypeConverter *converter); //===--------------------------------------------------------------------===// @@ -1034,9 +1042,12 @@ void CreateOperationRewrite::rollback() { UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite( ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op, - const TypeConverter *converter, MaterializationKind kind) + const TypeConverter *converter, MaterializationKind kind, Type originalType) : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), - converterAndKind(converter, kind) { + converterAndKind(converter, kind), originalType(originalType) { + assert(!originalType || + kind == MaterializationKind::Target && + "original type is valid only for target materializations"); rewriterImpl.unresolvedMaterializations[op] = this; } @@ -1139,7 +1150,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( Value castValue = buildUnresolvedMaterialization( MaterializationKind::Target, computeInsertPoint(newOperand), operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType, - currentTypeConverter); + /*originalType=*/origType, currentTypeConverter); mapping.map(newOperand, castValue); newOperand = castValue; } @@ -1255,7 +1266,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( MaterializationKind::Source, OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), /*inputs=*/ValueRange(), - /*outputType=*/origArgType, converter); + /*outputType=*/origArgType, /*originalType=*/Type(), converter); mapping.map(origArg, repl); appendRewrite(block, origArg); continue; @@ -1280,7 +1291,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( Value argMat = buildUnresolvedMaterialization( MaterializationKind::Argument, OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), - /*inputs=*/replArgs, origArgType, converter); + /*inputs=*/replArgs, /*outputType=*/origArgType, + /*originalType=*/Type(), converter); mapping.map(origArg, argMat); Type legalOutputType; @@ -1299,7 +1311,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( if (legalOutputType && legalOutputType != origArgType) { Value targetMat = buildUnresolvedMaterialization( MaterializationKind::Target, computeInsertPoint(argMat), - origArg.getLoc(), argMat, legalOutputType, converter); + origArg.getLoc(), /*inputs=*/argMat, /*outputType=*/legalOutputType, + /*originalType=*/origArgType, converter); mapping.map(argMat, targetMat); } appendRewrite(block, origArg); @@ -1322,7 +1335,12 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /// of input operands. Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, - ValueRange inputs, Type outputType, const TypeConverter *converter) { + ValueRange inputs, Type outputType, Type originalType, + const TypeConverter *converter) { + assert(!originalType || + kind == MaterializationKind::Target && + "original type is valid only for target materializations"); + // Avoid materializing an unnecessary cast. if (inputs.size() == 1 && inputs.front().getType() == outputType) return inputs.front(); @@ -1333,7 +1351,8 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); auto convertOp = builder.create(loc, outputType, inputs); - appendRewrite(convertOp, converter, kind); + appendRewrite(convertOp, converter, kind, + originalType); return convertOp.getResult(0); } @@ -1381,7 +1400,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, newValue = buildUnresolvedMaterialization( MaterializationKind::Source, computeInsertPoint(result), result.getLoc(), /*inputs=*/ValueRange(), - /*outputType=*/result.getType(), currentTypeConverter); + /*outputType=*/result.getType(), /*originalType=*/Type(), + currentTypeConverter); } // Remap, and check for any result type changes. @@ -2408,7 +2428,8 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter, [[fallthrough]]; case MaterializationKind::Target: newMaterialization = converter->materializeTargetConversion( - rewriter, op->getLoc(), outputType, inputOperands); + rewriter, op->getLoc(), outputType, inputOperands, + rewrite->getOriginalType()); break; case MaterializationKind::Source: newMaterialization = converter->materializeSourceConversion( @@ -2565,7 +2586,7 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) { MaterializationKind::Source, computeInsertPoint(newValue), originalValue.getLoc(), /*inputs=*/newValue, /*outputType=*/originalValue.getType(), - converter); + /*originalType=*/Type(), converter); rewriterImpl.mapping.map(originalValue, castValue); inverseMapping[castValue].push_back(originalValue); llvm::erase(inverseMapping[newValue], originalValue); @@ -2787,15 +2808,39 @@ TypeConverter::convertSignatureArgs(TypeRange types, return success(); } -Value TypeConverter::materializeConversion( - ArrayRef materializations, OpBuilder &builder, - Location loc, Type resultType, ValueRange inputs) const { - for (const MaterializationCallbackFn &fn : llvm::reverse(materializations)) +Value TypeConverter::materializeArgumentConversion(OpBuilder &builder, + Location loc, + Type resultType, + ValueRange inputs) const { + for (const MaterializationCallbackFn &fn : + llvm::reverse(argumentMaterializations)) + if (std::optional result = fn(builder, resultType, inputs, loc)) + return *result; + return nullptr; +} + +Value TypeConverter::materializeSourceConversion(OpBuilder &builder, + Location loc, Type resultType, + ValueRange inputs) const { + for (const MaterializationCallbackFn &fn : + llvm::reverse(sourceMaterializations)) if (std::optional result = fn(builder, resultType, inputs, loc)) return *result; return nullptr; } +Value TypeConverter::materializeTargetConversion(OpBuilder &builder, + Location loc, Type resultType, + ValueRange inputs, + Type originalType) const { + for (const TargetMaterializationCallbackFn &fn : + llvm::reverse(targetMaterializations)) + if (std::optional result = + fn(builder, resultType, inputs, loc, originalType)) + return *result; + return nullptr; +} + std::optional TypeConverter::convertBlockSignature(Block *block) const { SignatureConversion conversion(block->getNumArguments());