@@ -683,10 +683,10 @@ enum MaterializationKind {
683683// / conversion.
684684class UnresolvedMaterializationRewrite : public OperationRewrite {
685685public:
686- UnresolvedMaterializationRewrite (
687- ConversionPatternRewriterImpl &rewriterImpl ,
688- UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr ,
689- MaterializationKind kind = MaterializationKind::Target );
686+ UnresolvedMaterializationRewrite (ConversionPatternRewriterImpl &rewriterImpl,
687+ UnrealizedConversionCastOp op ,
688+ const TypeConverter *converter,
689+ MaterializationKind kind, Type originalType );
690690
691691 static bool classof (const IRRewrite *rewrite) {
692692 return rewrite->getKind () == Kind::UnresolvedMaterialization;
@@ -708,11 +708,18 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
708708 return converterAndKind.getInt ();
709709 }
710710
711+ // / Return the original type of the SSA value.
712+ Type getOriginalType () const { return originalType; }
713+
711714private:
712715 // / The corresponding type converter to use when resolving this
713716 // / materialization, and the kind of this materialization.
714717 llvm::PointerIntPair<const TypeConverter *, 2 , MaterializationKind>
715718 converterAndKind;
719+
720+ // / The original type of the SSA value. Only used for target
721+ // / materializations.
722+ Type originalType;
716723};
717724} // namespace
718725
@@ -808,6 +815,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
808815 Value buildUnresolvedMaterialization (MaterializationKind kind,
809816 OpBuilder::InsertPoint ip, Location loc,
810817 ValueRange inputs, Type outputType,
818+ Type originalType,
811819 const TypeConverter *converter);
812820
813821 // ===--------------------------------------------------------------------===//
@@ -1034,9 +1042,12 @@ void CreateOperationRewrite::rollback() {
10341042
10351043UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite (
10361044 ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1037- const TypeConverter *converter, MaterializationKind kind)
1045+ const TypeConverter *converter, MaterializationKind kind, Type originalType )
10381046 : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1039- converterAndKind(converter, kind) {
1047+ converterAndKind(converter, kind), originalType(originalType) {
1048+ assert (!originalType ||
1049+ kind == MaterializationKind::Target &&
1050+ " original type is valid only for target materializations" );
10401051 rewriterImpl.unresolvedMaterializations [op] = this ;
10411052}
10421053
@@ -1139,7 +1150,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11391150 Value castValue = buildUnresolvedMaterialization (
11401151 MaterializationKind::Target, computeInsertPoint (newOperand),
11411152 operandLoc, /* inputs=*/ newOperand, /* outputType=*/ desiredType,
1142- currentTypeConverter);
1153+ /* originalType= */ origType, currentTypeConverter);
11431154 mapping.map (newOperand, castValue);
11441155 newOperand = castValue;
11451156 }
@@ -1255,7 +1266,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12551266 MaterializationKind::Source,
12561267 OpBuilder::InsertPoint (newBlock, newBlock->begin ()), origArg.getLoc (),
12571268 /* inputs=*/ ValueRange (),
1258- /* outputType=*/ origArgType, converter);
1269+ /* outputType=*/ origArgType, /* originalType= */ Type (), converter);
12591270 mapping.map (origArg, repl);
12601271 appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
12611272 continue ;
@@ -1280,7 +1291,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12801291 Value argMat = buildUnresolvedMaterialization (
12811292 MaterializationKind::Argument,
12821293 OpBuilder::InsertPoint (newBlock, newBlock->begin ()), origArg.getLoc (),
1283- /* inputs=*/ replArgs, origArgType, converter);
1294+ /* inputs=*/ replArgs, /* outputType=*/ origArgType,
1295+ /* originalType=*/ Type (), converter);
12841296 mapping.map (origArg, argMat);
12851297
12861298 Type legalOutputType;
@@ -1299,7 +1311,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12991311 if (legalOutputType && legalOutputType != origArgType) {
13001312 Value targetMat = buildUnresolvedMaterialization (
13011313 MaterializationKind::Target, computeInsertPoint (argMat),
1302- origArg.getLoc (), argMat, legalOutputType, converter);
1314+ origArg.getLoc (), /* inputs=*/ argMat, /* outputType=*/ legalOutputType,
1315+ /* originalType=*/ origArgType, converter);
13031316 mapping.map (argMat, targetMat);
13041317 }
13051318 appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1322,7 +1335,12 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13221335// / of input operands.
13231336Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization (
13241337 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1325- ValueRange inputs, Type outputType, const TypeConverter *converter) {
1338+ ValueRange inputs, Type outputType, Type originalType,
1339+ const TypeConverter *converter) {
1340+ assert (!originalType ||
1341+ kind == MaterializationKind::Target &&
1342+ " original type is valid only for target materializations" );
1343+
13261344 // Avoid materializing an unnecessary cast.
13271345 if (inputs.size () == 1 && inputs.front ().getType () == outputType)
13281346 return inputs.front ();
@@ -1333,7 +1351,8 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
13331351 builder.setInsertionPoint (ip.getBlock (), ip.getPoint ());
13341352 auto convertOp =
13351353 builder.create <UnrealizedConversionCastOp>(loc, outputType, inputs);
1336- appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
1354+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1355+ originalType);
13371356 return convertOp.getResult (0 );
13381357}
13391358
@@ -1381,7 +1400,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
13811400 newValue = buildUnresolvedMaterialization (
13821401 MaterializationKind::Source, computeInsertPoint (result),
13831402 result.getLoc (), /* inputs=*/ ValueRange (),
1384- /* outputType=*/ result.getType (), currentTypeConverter);
1403+ /* outputType=*/ result.getType (), /* originalType=*/ Type (),
1404+ currentTypeConverter);
13851405 }
13861406
13871407 // Remap, and check for any result type changes.
@@ -2408,7 +2428,8 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
24082428 [[fallthrough]];
24092429 case MaterializationKind::Target:
24102430 newMaterialization = converter->materializeTargetConversion (
2411- rewriter, op->getLoc (), outputType, inputOperands);
2431+ rewriter, op->getLoc (), outputType, inputOperands,
2432+ rewrite->getOriginalType ());
24122433 break ;
24132434 case MaterializationKind::Source:
24142435 newMaterialization = converter->materializeSourceConversion (
@@ -2565,7 +2586,7 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
25652586 MaterializationKind::Source, computeInsertPoint (newValue),
25662587 originalValue.getLoc (),
25672588 /* inputs=*/ newValue, /* outputType=*/ originalValue.getType (),
2568- converter);
2589+ /* originalType= */ Type (), converter);
25692590 rewriterImpl.mapping .map (originalValue, castValue);
25702591 inverseMapping[castValue].push_back (originalValue);
25712592 llvm::erase (inverseMapping[newValue], originalValue);
@@ -2787,15 +2808,39 @@ TypeConverter::convertSignatureArgs(TypeRange types,
27872808 return success ();
27882809}
27892810
2790- Value TypeConverter::materializeConversion (
2791- ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
2792- Location loc, Type resultType, ValueRange inputs) const {
2793- for (const MaterializationCallbackFn &fn : llvm::reverse (materializations))
2811+ Value TypeConverter::materializeArgumentConversion (OpBuilder &builder,
2812+ Location loc,
2813+ Type resultType,
2814+ ValueRange inputs) const {
2815+ for (const MaterializationCallbackFn &fn :
2816+ llvm::reverse (argumentMaterializations))
2817+ if (std::optional<Value> result = fn (builder, resultType, inputs, loc))
2818+ return *result;
2819+ return nullptr ;
2820+ }
2821+
2822+ Value TypeConverter::materializeSourceConversion (OpBuilder &builder,
2823+ Location loc, Type resultType,
2824+ ValueRange inputs) const {
2825+ for (const MaterializationCallbackFn &fn :
2826+ llvm::reverse (sourceMaterializations))
27942827 if (std::optional<Value> result = fn (builder, resultType, inputs, loc))
27952828 return *result;
27962829 return nullptr ;
27972830}
27982831
2832+ Value TypeConverter::materializeTargetConversion (OpBuilder &builder,
2833+ Location loc, Type resultType,
2834+ ValueRange inputs,
2835+ Type originalType) const {
2836+ for (const TargetMaterializationCallbackFn &fn :
2837+ llvm::reverse (targetMaterializations))
2838+ if (std::optional<Value> result =
2839+ fn (builder, resultType, inputs, loc, originalType))
2840+ return *result;
2841+ return nullptr ;
2842+ }
2843+
27992844std::optional<TypeConverter::SignatureConversion>
28002845TypeConverter::convertBlockSignature (Block *block) const {
28012846 SignatureConversion conversion (block->getNumArguments ());
0 commit comments