diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index cd8d3ee0af72b..96dd14f142328 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -564,9 +564,13 @@ class OpBuilder : public Builder { /// Attempts to fold the given operation and places new results within /// `results`. Returns success if the operation was folded, failure otherwise. - /// If the fold was in-place, `results` will not be filled. + /// If the fold was in-place, `results` will not be filled. Optionally, newly + /// materialized constant operations can be returned to the caller. + /// /// Note: This function does not erase the operation on a successful fold. - LogicalResult tryFold(Operation *op, SmallVectorImpl &results); + LogicalResult + tryFold(Operation *op, SmallVectorImpl &results, + SmallVectorImpl *materializedConstants = nullptr); /// Creates a deep copy of the specified operation, remapping any operands /// that use values outside of the operation using the map that is provided diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 16bd8201ad50a..89102115cdc40 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -465,8 +465,9 @@ Operation *OpBuilder::create(Location loc, StringAttr opName, return create(state); } -LogicalResult OpBuilder::tryFold(Operation *op, - SmallVectorImpl &results) { +LogicalResult +OpBuilder::tryFold(Operation *op, SmallVectorImpl &results, + SmallVectorImpl *materializedConstants) { assert(results.empty() && "expected empty results"); ResultRange opResults = op->getResults(); @@ -528,6 +529,10 @@ LogicalResult OpBuilder::tryFold(Operation *op, for (Operation *cst : generatedConstants) insert(cst); + // Return materialized constant operations. + if (materializedConstants) + *materializedConstants = std::move(generatedConstants); + return success(); } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 962207059c8aa..4d250329c6f45 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2090,8 +2090,6 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op, ConversionPatternRewriter &rewriter) { auto &rewriterImpl = rewriter.getImpl(); - RewriterState curState = rewriterImpl.getCurrentState(); - LLVM_DEBUG({ rewriterImpl.logger.startLine() << "* Fold {\n"; rewriterImpl.logger.indent(); @@ -2099,28 +2097,27 @@ OperationLegalizer::legalizeWithFold(Operation *op, // Try to fold the operation. SmallVector replacementValues; + SmallVector newOps; rewriter.setInsertionPoint(op); - if (failed(rewriter.tryFold(op, replacementValues))) { + if (failed(rewriter.tryFold(op, replacementValues, &newOps))) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold")); return failure(); } + // An empty list of replacement values indicates that the fold was in-place. // As the operation changed, a new legalization needs to be attempted. if (replacementValues.empty()) return legalize(op, rewriter); // Recursively legalize any new constant operations. - for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size(); - i != e; ++i) { - auto *createOp = - dyn_cast(rewriterImpl.rewrites[i].get()); - if (!createOp) - continue; - if (failed(legalize(createOp->getOperation(), rewriter))) { + for (Operation *newOp : newOps) { + if (failed(legalize(newOp, rewriter))) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "failed to legalize generated constant '{0}'", - createOp->getOperation()->getName())); - rewriterImpl.resetState(curState); + newOp->getName())); + // Legalization failed: erase all materialized constants. + for (Operation *op : newOps) + rewriter.eraseOp(op); return failure(); } }