@@ -862,8 +862,11 @@ static bool hasRewrite(R &&rewrites, Block *block) {
862862// ===----------------------------------------------------------------------===//
863863// ConversionPatternRewriterImpl
864864// ===----------------------------------------------------------------------===//
865+
865866namespace mlir {
866867namespace detail {
868+ class OperationLegalizer ;
869+
867870struct ConversionPatternRewriterImpl : public RewriterBase ::Listener {
868871 explicit ConversionPatternRewriterImpl (ConversionPatternRewriter &rewriter,
869872 const ConversionConfig &config)
@@ -915,6 +918,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
915918 // / Return "true" if the given operation was replaced or erased.
916919 bool wasOpReplaced (Operation *op) const ;
917920
921+ // / Set the operation legalizer to use for recursive legalization.
922+ void setOperationLegalizer (OperationLegalizer *legalizer) {
923+ opLegalizer = legalizer;
924+ }
925+
918926 // / Lookup the most recently mapped values with the desired types in the
919927 // / mapping, taking into account only replacements. Perform a best-effort
920928 // / search for existing materializations with the desired types.
@@ -1121,6 +1129,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
11211129 // / converting the arguments of blocks within that region.
11221130 DenseMap<Region *, const TypeConverter *> regionToConverter;
11231131
1132+ // / The operation legalizer to use for recursive legalization. This is set by
1133+ // / the OperationConverter when the rewriter is created.
1134+ OperationLegalizer *opLegalizer = nullptr ;
1135+
11241136 // / Dialect conversion configuration.
11251137 const ConversionConfig &config;
11261138
@@ -2357,7 +2369,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
23572369// OperationLegalizer
23582370// ===----------------------------------------------------------------------===//
23592371
2360- namespace {
2372+ namespace mlir ::detail {
23612373// / A set of rewrite patterns that can be used to legalize a given operation.
23622374using LegalizationPatterns = SmallVector<const Pattern *, 1 >;
23632375
@@ -2454,7 +2466,7 @@ class OperationLegalizer {
24542466 // / The pattern applicator to use for conversions.
24552467 PatternApplicator applicator;
24562468};
2457- } // namespace
2469+ } // namespace mlir::detail
24582470
24592471OperationLegalizer::OperationLegalizer (ConversionPatternRewriter &rewriter,
24602472 const ConversionTarget &targetInfo,
@@ -2854,6 +2866,41 @@ LogicalResult OperationLegalizer::legalizePatternRootUpdates(
28542866 return success ();
28552867}
28562868
2869+ LogicalResult ConversionPatternRewriter::legalize (Operation *op) {
2870+ return impl->opLegalizer ->legalize (op);
2871+ }
2872+
2873+ LogicalResult ConversionPatternRewriter::legalize (Region *r) {
2874+ // Fast path: If the region is empty, there is nothing to legalize.
2875+ if (r->empty ())
2876+ return success ();
2877+
2878+ // Gather a list of all operations to legalize. This is done before
2879+ // converting the entry block signature because unrealized_conversion_cast
2880+ // ops should not be included.
2881+ SmallVector<Operation *> ops;
2882+ for (Block &b : *r)
2883+ for (Operation &op : b)
2884+ ops.push_back (&op);
2885+
2886+ // If the current pattern runs with a type converter, convert the entry block
2887+ // signature.
2888+ if (const TypeConverter *converter = impl->currentTypeConverter ) {
2889+ std::optional<TypeConverter::SignatureConversion> conversion =
2890+ converter->convertBlockSignature (&r->front ());
2891+ if (!conversion)
2892+ return failure ();
2893+ applySignatureConversion (&r->front (), *conversion, converter);
2894+ }
2895+
2896+ // Legalize all operations in the region.
2897+ for (Operation *op : ops)
2898+ if (failed (legalize (op)))
2899+ return failure ();
2900+
2901+ return success ();
2902+ }
2903+
28572904// ===----------------------------------------------------------------------===//
28582905// Cost Model
28592906// ===----------------------------------------------------------------------===//
@@ -3218,7 +3265,10 @@ struct OperationConverter {
32183265 const ConversionConfig &config,
32193266 OpConversionMode mode)
32203267 : rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
3221- mode(mode) {}
3268+ mode(mode) {
3269+ // Set the legalizer in the rewriter so patterns can recursively legalize.
3270+ rewriter.getImpl ().setOperationLegalizer (&opLegalizer);
3271+ }
32223272
32233273 // / Converts the given operations to the conversion target.
32243274 LogicalResult convertOperations (ArrayRef<Operation *> ops);
0 commit comments