From 88b91b47063e8b8aea0c6a002d853e24c774ab8f Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 20 Apr 2025 15:37:56 +0200 Subject: [PATCH 1/2] [mlir][Transforms] Dialect conversion: Erase materialized constants instead of rollback --- mlir/include/mlir/IR/Builders.h | 8 +++++-- mlir/lib/IR/Builders.cpp | 9 ++++++-- .../Transforms/Utils/DialectConversion.cpp | 21 ++++++++----------- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index cd8d3ee0af72b..8f13705fac96d 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -564,9 +564,13 @@ class OpBuilder : public Builder { /// Attempts to fold the given operation and places new results within /// `results`. Returns success if the operation was folded, failure otherwise. - /// If the fold was in-place, `results` will not be filled. + /// If the fold was in-place, `results` will not be filled. Optionally, newly + /// materialized constant operations can be returned to the caller. + /// /// Note: This function does not erase the operation on a successful fold. - LogicalResult tryFold(Operation *op, SmallVectorImpl &results); + LogicalResult + tryFold(Operation *op, SmallVectorImpl &results, + SmallVector *materializedConstants = nullptr); /// Creates a deep copy of the specified operation, remapping any operands /// that use values outside of the operation using the map that is provided diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 16bd8201ad50a..9450ef7738fa0 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -465,8 +465,9 @@ Operation *OpBuilder::create(Location loc, StringAttr opName, return create(state); } -LogicalResult OpBuilder::tryFold(Operation *op, - SmallVectorImpl &results) { +LogicalResult +OpBuilder::tryFold(Operation *op, SmallVectorImpl &results, + SmallVector *materializedConstants) { assert(results.empty() && "expected empty results"); ResultRange opResults = op->getResults(); @@ -528,6 +529,10 @@ LogicalResult OpBuilder::tryFold(Operation *op, for (Operation *cst : generatedConstants) insert(cst); + // Return materialized constant operations. + if (materializedConstants) + *materializedConstants = std::move(generatedConstants); + return success(); } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 962207059c8aa..3059b35865bf2 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2090,8 +2090,6 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op, ConversionPatternRewriter &rewriter) { auto &rewriterImpl = rewriter.getImpl(); - RewriterState curState = rewriterImpl.getCurrentState(); - LLVM_DEBUG({ rewriterImpl.logger.startLine() << "* Fold {\n"; rewriterImpl.logger.indent(); @@ -2099,28 +2097,27 @@ OperationLegalizer::legalizeWithFold(Operation *op, // Try to fold the operation. SmallVector replacementValues; + SmallVector newOps; rewriter.setInsertionPoint(op); - if (failed(rewriter.tryFold(op, replacementValues))) { + if (failed(rewriter.tryFold(op, replacementValues, &newOps))) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold")); return failure(); } + // An empty list of replacement values indicates that the fold was in-place. // As the operation changed, a new legalization needs to be attempted. if (replacementValues.empty()) return legalize(op, rewriter); // Recursively legalize any new constant operations. - for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size(); - i != e; ++i) { - auto *createOp = - dyn_cast(rewriterImpl.rewrites[i].get()); - if (!createOp) - continue; - if (failed(legalize(createOp->getOperation(), rewriter))) { + for (Operation *newOp : newOps) { + if (failed(legalize(newOp, rewriter))) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "failed to legalize generated constant '{0}'", - createOp->getOperation()->getName())); - rewriterImpl.resetState(curState); + newOp->getName())); + // Legalization failed: erase all materialized constants. + for (Operation *op : newOps) + rewriter.eraseOp(op); return failure(); } } From 8bf5ca017d06c7a3f4fe9eca0aa512969912adcb Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 22 Apr 2025 08:55:20 +0200 Subject: [PATCH 2/2] address comments --- mlir/include/mlir/IR/Builders.h | 2 +- mlir/lib/IR/Builders.cpp | 2 +- mlir/lib/Transforms/Utils/DialectConversion.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 8f13705fac96d..96dd14f142328 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -570,7 +570,7 @@ class OpBuilder : public Builder { /// Note: This function does not erase the operation on a successful fold. LogicalResult tryFold(Operation *op, SmallVectorImpl &results, - SmallVector *materializedConstants = nullptr); + SmallVectorImpl *materializedConstants = nullptr); /// Creates a deep copy of the specified operation, remapping any operands /// that use values outside of the operation using the map that is provided diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 9450ef7738fa0..89102115cdc40 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -467,7 +467,7 @@ Operation *OpBuilder::create(Location loc, StringAttr opName, LogicalResult OpBuilder::tryFold(Operation *op, SmallVectorImpl &results, - SmallVector *materializedConstants) { + SmallVectorImpl *materializedConstants) { assert(results.empty() && "expected empty results"); ResultRange opResults = op->getResults(); diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 3059b35865bf2..4d250329c6f45 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2097,7 +2097,7 @@ OperationLegalizer::legalizeWithFold(Operation *op, // Try to fold the operation. SmallVector replacementValues; - SmallVector newOps; + SmallVector newOps; rewriter.setInsertionPoint(op); if (failed(rewriter.tryFold(op, replacementValues, &newOps))) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));