@@ -789,26 +789,13 @@ enum MaterializationKind {
789789 Source
790790};
791791
792- // / An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
793- // / op. Unresolved materializations are erased at the end of the dialect
794- // / conversion.
795- class UnresolvedMaterializationRewrite : public OperationRewrite {
792+ // / Helper class that stores metadata about an unresolved materialization.
793+ class UnresolvedMaterializationInfo {
796794public:
797- UnresolvedMaterializationRewrite (ConversionPatternRewriterImpl &rewriterImpl,
798- UnrealizedConversionCastOp op,
799- const TypeConverter *converter,
800- MaterializationKind kind, Type originalType,
801- ValueVector mappedValues);
802-
803- static bool classof (const IRRewrite *rewrite) {
804- return rewrite->getKind () == Kind::UnresolvedMaterialization;
805- }
806-
807- void rollback () override ;
808-
809- UnrealizedConversionCastOp getOperation () const {
810- return cast<UnrealizedConversionCastOp>(op);
811- }
795+ UnresolvedMaterializationInfo () = default ;
796+ UnresolvedMaterializationInfo (const TypeConverter *converter,
797+ MaterializationKind kind, Type originalType)
798+ : converterAndKind(converter, kind), originalType(originalType) {}
812799
813800 // / Return the type converter of this materialization (which may be null).
814801 const TypeConverter *getConverter () const {
@@ -832,7 +819,30 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
832819 // / The original type of the SSA value. Only used for target
833820 // / materializations.
834821 Type originalType;
822+ };
823+
824+ // / An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
825+ // / op. Unresolved materializations fold away or are replaced with
826+ // / source/target materializations at the end of the dialect conversion.
827+ class UnresolvedMaterializationRewrite : public OperationRewrite {
828+ public:
829+ UnresolvedMaterializationRewrite (ConversionPatternRewriterImpl &rewriterImpl,
830+ UnrealizedConversionCastOp op,
831+ ValueVector mappedValues)
832+ : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
833+ mappedValues (std::move(mappedValues)) {}
834+
835+ static bool classof (const IRRewrite *rewrite) {
836+ return rewrite->getKind () == Kind::UnresolvedMaterialization;
837+ }
838+
839+ void rollback () override ;
835840
841+ UnrealizedConversionCastOp getOperation () const {
842+ return cast<UnrealizedConversionCastOp>(op);
843+ }
844+
845+ private:
836846 // / The values in the conversion value mapping that are being replaced by the
837847 // / results of this unresolved materialization.
838848 ValueVector mappedValues;
@@ -1088,9 +1098,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
10881098 // / by the current pattern.
10891099 SetVector<Block *> patternInsertedBlocks;
10901100
1091- // / A mapping of all unresolved materializations (UnrealizedConversionCastOp)
1092- // / to the corresponding rewrite objects.
1093- DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
1101+ // / A mapping for looking up metadata of unresolved materializations.
1102+ DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
10941103 unresolvedMaterializations;
10951104
10961105 // / The current type converter, or nullptr if no type converter is currently
@@ -1210,18 +1219,6 @@ void CreateOperationRewrite::rollback() {
12101219 op->erase ();
12111220}
12121221
1213- UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite (
1214- ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1215- const TypeConverter *converter, MaterializationKind kind, Type originalType,
1216- ValueVector mappedValues)
1217- : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1218- converterAndKind(converter, kind), originalType(originalType),
1219- mappedValues(std::move(mappedValues)) {
1220- assert ((!originalType || kind == MaterializationKind::Target) &&
1221- " original type is valid only for target materializations" );
1222- rewriterImpl.unresolvedMaterializations [op] = this ;
1223- }
1224-
12251222void UnresolvedMaterializationRewrite::rollback () {
12261223 if (!mappedValues.empty ())
12271224 rewriterImpl.mapping .erase (mappedValues);
@@ -1510,8 +1507,10 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
15101507 mapping.map (valuesToMap, convertOp.getResults ());
15111508 if (castOp)
15121509 *castOp = convertOp;
1513- appendRewrite<UnresolvedMaterializationRewrite>(
1514- convertOp, converter, kind, originalType, std::move (valuesToMap));
1510+ unresolvedMaterializations[convertOp] =
1511+ UnresolvedMaterializationInfo (converter, kind, originalType);
1512+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
1513+ std::move (valuesToMap));
15151514 return convertOp.getResults ();
15161515}
15171516
@@ -2678,21 +2677,21 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
26782677
26792678static LogicalResult
26802679legalizeUnresolvedMaterialization (RewriterBase &rewriter,
2681- UnresolvedMaterializationRewrite *rewrite) {
2682- UnrealizedConversionCastOp op = rewrite-> getOperation ();
2680+ UnrealizedConversionCastOp op,
2681+ const UnresolvedMaterializationInfo &info) {
26832682 assert (!op.use_empty () &&
26842683 " expected that dead materializations have already been DCE'd" );
26852684 Operation::operand_range inputOperands = op.getOperands ();
26862685
26872686 // Try to materialize the conversion.
2688- if (const TypeConverter *converter = rewrite-> getConverter ()) {
2687+ if (const TypeConverter *converter = info. getConverter ()) {
26892688 rewriter.setInsertionPoint (op);
26902689 SmallVector<Value> newMaterialization;
2691- switch (rewrite-> getMaterializationKind ()) {
2690+ switch (info. getMaterializationKind ()) {
26922691 case MaterializationKind::Target:
26932692 newMaterialization = converter->materializeTargetConversion (
26942693 rewriter, op->getLoc (), op.getResultTypes (), inputOperands,
2695- rewrite-> getOriginalType ());
2694+ info. getOriginalType ());
26962695 break ;
26972696 case MaterializationKind::Source:
26982697 assert (op->getNumResults () == 1 && " expected single result" );
@@ -2767,7 +2766,7 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
27672766
27682767 // Gather all unresolved materializations.
27692768 SmallVector<UnrealizedConversionCastOp> allCastOps;
2770- const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite * >
2769+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo >
27712770 &materializations = rewriterImpl.unresolvedMaterializations ;
27722771 for (auto it : materializations)
27732772 allCastOps.push_back (it.first );
@@ -2784,7 +2783,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
27842783 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
27852784 auto it = materializations.find (castOp);
27862785 assert (it != materializations.end () && " inconsistent state" );
2787- if (failed (legalizeUnresolvedMaterialization (rewriter, it->second )))
2786+ if (failed (
2787+ legalizeUnresolvedMaterialization (rewriter, castOp, it->second )))
27882788 return failure ();
27892789 }
27902790 }
0 commit comments