Skip to content

[mlir][Transforms] ConversionPatternRewriter: Add config getter #152310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,9 @@ class ConversionPatternRewriter final : public PatternRewriter {
public:
~ConversionPatternRewriter() override;

/// Return the configuration of the current dialect conversion.
const ConversionConfig &getConfig() const;

/// Apply a signature conversion to given block. This replaces the block with
/// a new block containing the updated signature. The operations of the given
/// block are inlined into the newly-created block, which is returned.
Expand Down
23 changes: 11 additions & 12 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1754,6 +1754,10 @@ ConversionPatternRewriter::ConversionPatternRewriter(

ConversionPatternRewriter::~ConversionPatternRewriter() = default;

const ConversionConfig &ConversionPatternRewriter::getConfig() const {
return impl->config;
}

void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
assert(op && newOp && "expected non-null op");
replaceOp(op, newOp->getResults());
Expand Down Expand Up @@ -1895,7 +1899,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
// ops should be moved one-by-one ("slow path"), so that a separate
// `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
// a bit more efficient, so we try to do that when possible.
bool fastPath = !impl->config.listener;
bool fastPath = !getConfig().listener;

if (fastPath)
impl->inlineBlockBefore(source, dest, before);
Expand Down Expand Up @@ -2018,8 +2022,7 @@ class OperationLegalizer {
using LegalizationAction = ConversionTarget::LegalizationAction;

OperationLegalizer(const ConversionTarget &targetInfo,
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config);
const FrozenRewritePatternSet &patterns);

/// Returns true if the given operation is known to be illegal on the target.
bool isIllegal(Operation *op) const;
Expand Down Expand Up @@ -2116,16 +2119,12 @@ class OperationLegalizer {

/// The pattern applicator to use for conversions.
PatternApplicator applicator;

/// Dialect conversion configuration.
const ConversionConfig &config;
};
} // namespace

OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config)
: target(targetInfo), applicator(patterns), config(config) {
const FrozenRewritePatternSet &patterns)
: target(targetInfo), applicator(patterns) {
// The set of patterns that can be applied to illegal operations to transform
// them into legal ones.
DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
Expand Down Expand Up @@ -2286,7 +2285,7 @@ OperationLegalizer::legalizeWithFold(Operation *op,
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
newOp->getName()));
if (!config.allowPatternRollback) {
if (!rewriter.getConfig().allowPatternRollback) {
// Rolling back a folder is like rolling back a pattern.
llvm::report_fatal_error(
"op '" + opName +
Expand All @@ -2306,6 +2305,7 @@ LogicalResult
OperationLegalizer::legalizeWithPattern(Operation *op,
ConversionPatternRewriter &rewriter) {
auto &rewriterImpl = rewriter.getImpl();
const ConversionConfig &config = rewriter.getConfig();

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
Operation *checkOp;
Expand Down Expand Up @@ -2749,8 +2749,7 @@ struct OperationConverter {
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config,
OpConversionMode mode)
: config(config), opLegalizer(target, patterns, this->config),
mode(mode) {}
: config(config), opLegalizer(target, patterns), mode(mode) {}

/// Converts the given operations to the conversion target.
LogicalResult convertOperations(ArrayRef<Operation *> ops);
Expand Down
Loading