diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index f6437657c9a93..4e651a0489899 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -728,6 +728,9 @@ class ConversionPatternRewriter final : public PatternRewriter { public: ~ConversionPatternRewriter() override; + /// Return the configuration of the current dialect conversion. + const ConversionConfig &getConfig() const; + /// Apply a signature conversion to given block. This replaces the block with /// a new block containing the updated signature. The operations of the given /// block are inlined into the newly-created block, which is returned. diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index f23c6197accd5..a55da79455010 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1754,6 +1754,10 @@ ConversionPatternRewriter::ConversionPatternRewriter( ConversionPatternRewriter::~ConversionPatternRewriter() = default; +const ConversionConfig &ConversionPatternRewriter::getConfig() const { + return impl->config; +} + void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) { assert(op && newOp && "expected non-null op"); replaceOp(op, newOp->getResults()); @@ -1895,7 +1899,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, // ops should be moved one-by-one ("slow path"), so that a separate // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is // a bit more efficient, so we try to do that when possible. - bool fastPath = !impl->config.listener; + bool fastPath = !getConfig().listener; if (fastPath) impl->inlineBlockBefore(source, dest, before); @@ -2018,8 +2022,7 @@ class OperationLegalizer { using LegalizationAction = ConversionTarget::LegalizationAction; OperationLegalizer(const ConversionTarget &targetInfo, - const FrozenRewritePatternSet &patterns, - const ConversionConfig &config); + const FrozenRewritePatternSet &patterns); /// Returns true if the given operation is known to be illegal on the target. bool isIllegal(Operation *op) const; @@ -2116,16 +2119,12 @@ class OperationLegalizer { /// The pattern applicator to use for conversions. PatternApplicator applicator; - - /// Dialect conversion configuration. - const ConversionConfig &config; }; } // namespace OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo, - const FrozenRewritePatternSet &patterns, - const ConversionConfig &config) - : target(targetInfo), applicator(patterns), config(config) { + const FrozenRewritePatternSet &patterns) + : target(targetInfo), applicator(patterns) { // The set of patterns that can be applied to illegal operations to transform // them into legal ones. DenseMap legalizerPatterns; @@ -2286,7 +2285,7 @@ OperationLegalizer::legalizeWithFold(Operation *op, LLVM_DEBUG(logFailure(rewriterImpl.logger, "failed to legalize generated constant '{0}'", newOp->getName())); - if (!config.allowPatternRollback) { + if (!rewriter.getConfig().allowPatternRollback) { // Rolling back a folder is like rolling back a pattern. llvm::report_fatal_error( "op '" + opName + @@ -2306,6 +2305,7 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op, ConversionPatternRewriter &rewriter) { auto &rewriterImpl = rewriter.getImpl(); + const ConversionConfig &config = rewriter.getConfig(); #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS Operation *checkOp; @@ -2749,8 +2749,7 @@ struct OperationConverter { const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode) - : config(config), opLegalizer(target, patterns, this->config), - mode(mode) {} + : config(config), opLegalizer(target, patterns), mode(mode) {} /// Converts the given operations to the conversion target. LogicalResult convertOperations(ArrayRef ops);