Skip to content

Commit 0a72e6d

Browse files
[mlir][Transforms] ConversionPatternRewriter: Add config getter (#152310)
Add a helper function to `ConversionPatternRewriter` that returns the dialect conversion configuration. This flag is useful when migrating conversion patterns to the new One-Shot Conversion Driver: patterns can check if they are running in rollback mode or not. They can then work around API changes and makes sure that the pattern keeps working with both the old and new driver. Also remove the `config` field from `OperationLegalizer`. That field was never needed.
1 parent 0abf497 commit 0a72e6d

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,9 @@ class ConversionPatternRewriter final : public PatternRewriter {
728728
public:
729729
~ConversionPatternRewriter() override;
730730

731+
/// Return the configuration of the current dialect conversion.
732+
const ConversionConfig &getConfig() const;
733+
731734
/// Apply a signature conversion to given block. This replaces the block with
732735
/// a new block containing the updated signature. The operations of the given
733736
/// block are inlined into the newly-created block, which is returned.

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1754,6 +1754,10 @@ ConversionPatternRewriter::ConversionPatternRewriter(
17541754

17551755
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
17561756

1757+
const ConversionConfig &ConversionPatternRewriter::getConfig() const {
1758+
return impl->config;
1759+
}
1760+
17571761
void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
17581762
assert(op && newOp && "expected non-null op");
17591763
replaceOp(op, newOp->getResults());
@@ -1895,7 +1899,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
18951899
// ops should be moved one-by-one ("slow path"), so that a separate
18961900
// `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
18971901
// a bit more efficient, so we try to do that when possible.
1898-
bool fastPath = !impl->config.listener;
1902+
bool fastPath = !getConfig().listener;
18991903

19001904
if (fastPath)
19011905
impl->inlineBlockBefore(source, dest, before);
@@ -2018,8 +2022,7 @@ class OperationLegalizer {
20182022
using LegalizationAction = ConversionTarget::LegalizationAction;
20192023

20202024
OperationLegalizer(const ConversionTarget &targetInfo,
2021-
const FrozenRewritePatternSet &patterns,
2022-
const ConversionConfig &config);
2025+
const FrozenRewritePatternSet &patterns);
20232026

20242027
/// Returns true if the given operation is known to be illegal on the target.
20252028
bool isIllegal(Operation *op) const;
@@ -2116,16 +2119,12 @@ class OperationLegalizer {
21162119

21172120
/// The pattern applicator to use for conversions.
21182121
PatternApplicator applicator;
2119-
2120-
/// Dialect conversion configuration.
2121-
const ConversionConfig &config;
21222122
};
21232123
} // namespace
21242124

21252125
OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
2126-
const FrozenRewritePatternSet &patterns,
2127-
const ConversionConfig &config)
2128-
: target(targetInfo), applicator(patterns), config(config) {
2126+
const FrozenRewritePatternSet &patterns)
2127+
: target(targetInfo), applicator(patterns) {
21292128
// The set of patterns that can be applied to illegal operations to transform
21302129
// them into legal ones.
21312130
DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2286,7 +2285,7 @@ OperationLegalizer::legalizeWithFold(Operation *op,
22862285
LLVM_DEBUG(logFailure(rewriterImpl.logger,
22872286
"failed to legalize generated constant '{0}'",
22882287
newOp->getName()));
2289-
if (!config.allowPatternRollback) {
2288+
if (!rewriter.getConfig().allowPatternRollback) {
22902289
// Rolling back a folder is like rolling back a pattern.
22912290
llvm::report_fatal_error(
22922291
"op '" + opName +
@@ -2306,6 +2305,7 @@ LogicalResult
23062305
OperationLegalizer::legalizeWithPattern(Operation *op,
23072306
ConversionPatternRewriter &rewriter) {
23082307
auto &rewriterImpl = rewriter.getImpl();
2308+
const ConversionConfig &config = rewriter.getConfig();
23092309

23102310
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
23112311
Operation *checkOp;
@@ -2749,8 +2749,7 @@ struct OperationConverter {
27492749
const FrozenRewritePatternSet &patterns,
27502750
const ConversionConfig &config,
27512751
OpConversionMode mode)
2752-
: config(config), opLegalizer(target, patterns, this->config),
2753-
mode(mode) {}
2752+
: config(config), opLegalizer(target, patterns), mode(mode) {}
27542753

27552754
/// Converts the given operations to the conversion target.
27562755
LogicalResult convertOperations(ArrayRef<Operation *> ops);

0 commit comments

Comments
 (0)