From 7bee741431cb969cae16e5e94f0306c3319c09c4 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 26 Oct 2025 23:21:56 +0000 Subject: [PATCH] [mlir][Transforms] Dialect Conversion: Convert entry block only --- .../Transforms/Utils/DialectConversion.cpp | 112 ++++-------------- .../test-legalizer-no-materializations.mlir | 67 +++++++++++ mlir/test/Transforms/test-legalizer.mlir | 39 ------ mlir/test/lib/Dialect/Test/TestPatterns.cpp | 6 +- 4 files changed, 89 insertions(+), 135 deletions(-) create mode 100644 mlir/test/Transforms/test-legalizer-no-materializations.mlir diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 3a23bbfd70eac..2fe06970eb568 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1105,10 +1105,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// A set of operations that were modified by the current pattern. SetVector patternModifiedOps; - /// A set of blocks that were inserted (newly-created blocks or moved blocks) - /// by the current pattern. - SetVector patternInsertedBlocks; - /// A list of unresolved materializations that were created by the current /// pattern. DenseSet patternMaterializations; @@ -2046,8 +2042,6 @@ void ConversionPatternRewriterImpl::notifyBlockInserted( if (!config.allowPatternRollback && config.listener) config.listener->notifyBlockInserted(block, previous, previousIt); - patternInsertedBlocks.insert(block); - if (wasDetached) { // If the block was detached, it is most likely a newly created block. if (config.allowPatternRollback) { @@ -2399,17 +2393,12 @@ class OperationLegalizer { bool canApplyPattern(Operation *op, const Pattern &pattern); /// Legalize the resultant IR after successfully applying the given pattern. - LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern, - const RewriterState &curState, - const SetVector &newOps, - const SetVector &modifiedOps, - const SetVector &insertedBlocks); - - /// Legalizes the actions registered during the execution of a pattern. LogicalResult - legalizePatternBlockRewrites(Operation *op, - const SetVector &insertedBlocks, - const SetVector &newOps); + legalizePatternResult(Operation *op, const Pattern &pattern, + const RewriterState &curState, + const SetVector &newOps, + const SetVector &modifiedOps); + LogicalResult legalizePatternCreatedOperations(const SetVector &newOps); LogicalResult @@ -2608,7 +2597,6 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) { auto cleanup = llvm::make_scope_exit([&]() { rewriterImpl.patternNewOps.clear(); rewriterImpl.patternModifiedOps.clear(); - rewriterImpl.patternInsertedBlocks.clear(); }); // Upon failure, undo all changes made by the folder. @@ -2662,24 +2650,16 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) { static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector &newOps, - const SetVector &modifiedOps, - const SetVector &insertedBlocks) { + const SetVector &modifiedOps) { auto newOpNames = llvm::map_range( newOps, [](Operation *op) { return op->getName().getStringRef(); }); auto modifiedOpNames = llvm::map_range( modifiedOps, [](Operation *op) { return op->getName().getStringRef(); }); - StringRef detachedBlockStr = "(detached block)"; - auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](Block *block) { - if (block->getParentOp()) - return block->getParentOp()->getName().getStringRef(); - return detachedBlockStr; - }); - llvm::report_fatal_error( - "pattern '" + pattern.getDebugName() + - "' produced IR that could not be legalized. " + "new ops: {" + - llvm::join(newOpNames, ", ") + "}, " + "modified ops: {" + - llvm::join(modifiedOpNames, ", ") + "}, " + "inserted block into ops: {" + - llvm::join(insertedBlockNames, ", ") + "}"); + llvm::report_fatal_error("pattern '" + pattern.getDebugName() + + "' produced IR that could not be legalized. " + + "new ops: {" + llvm::join(newOpNames, ", ") + "}, " + + "modified ops: {" + + llvm::join(modifiedOpNames, ", ") + "}"); } LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) { @@ -2743,7 +2723,6 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) { } rewriterImpl.patternNewOps.clear(); rewriterImpl.patternModifiedOps.clear(); - rewriterImpl.patternInsertedBlocks.clear(); LLVM_DEBUG({ logFailure(rewriterImpl.logger, "pattern failed to match"); if (rewriterImpl.config.notifyCallback) { @@ -2777,15 +2756,12 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) { SetVector newOps = moveAndReset(rewriterImpl.patternNewOps); SetVector modifiedOps = moveAndReset(rewriterImpl.patternModifiedOps); - SetVector insertedBlocks = - moveAndReset(rewriterImpl.patternInsertedBlocks); - auto result = legalizePatternResult(op, pattern, curState, newOps, - modifiedOps, insertedBlocks); + auto result = + legalizePatternResult(op, pattern, curState, newOps, modifiedOps); appliedPatterns.erase(&pattern); if (failed(result)) { if (!rewriterImpl.config.allowPatternRollback) - reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps, - insertedBlocks); + reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps); rewriterImpl.resetState(curState, pattern.getDebugName()); } if (config.listener) @@ -2823,8 +2799,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op, LogicalResult OperationLegalizer::legalizePatternResult( Operation *op, const Pattern &pattern, const RewriterState &curState, const SetVector &newOps, - const SetVector &modifiedOps, - const SetVector &insertedBlocks) { + const SetVector &modifiedOps) { [[maybe_unused]] auto &impl = rewriter.getImpl(); assert(impl.pendingRootUpdates.empty() && "dangling root updates"); @@ -2843,8 +2818,7 @@ LogicalResult OperationLegalizer::legalizePatternResult( #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Legalize each of the actions registered during application. - if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) || - failed(legalizePatternRootUpdates(modifiedOps)) || + if (failed(legalizePatternRootUpdates(modifiedOps)) || failed(legalizePatternCreatedOperations(newOps))) { return failure(); } @@ -2853,53 +2827,6 @@ LogicalResult OperationLegalizer::legalizePatternResult( return success(); } -LogicalResult OperationLegalizer::legalizePatternBlockRewrites( - Operation *op, const SetVector &insertedBlocks, - const SetVector &newOps) { - ConversionPatternRewriterImpl &impl = rewriter.getImpl(); - SmallPtrSet alreadyLegalized; - - // If the pattern moved or created any blocks, make sure the types of block - // arguments get legalized. - for (Block *block : insertedBlocks) { - if (impl.erasedBlocks.contains(block)) - continue; - - // Only check blocks outside of the current operation. - Operation *parentOp = block->getParentOp(); - if (!parentOp || parentOp == op || block->getNumArguments() == 0) - continue; - - // If the region of the block has a type converter, try to convert the block - // directly. - if (auto *converter = impl.regionToConverter.lookup(block->getParent())) { - std::optional conversion = - converter->convertBlockSignature(block); - if (!conversion) { - LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " - "block")); - return failure(); - } - impl.applySignatureConversion(block, converter, *conversion); - continue; - } - - // Otherwise, try to legalize the parent operation if it was not generated - // by this pattern. This is because we will attempt to legalize the parent - // operation, and blocks in regions created by this pattern will already be - // legalized later on. - if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) { - if (failed(legalize(parentOp))) { - LLVM_DEBUG(logFailure( - impl.logger, "operation '{0}'({1}) became illegal after rewrite", - parentOp->getName(), parentOp)); - return failure(); - } - } - } - return success(); -} - LogicalResult OperationLegalizer::legalizePatternCreatedOperations( const SetVector &newOps) { for (Operation *op : newOps) { @@ -3800,10 +3727,11 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, TypeConverter::SignatureConversion result(type.getNumInputs()); SmallVector newResults; if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) || - failed(typeConverter.convertTypes(type.getResults(), newResults)) || - failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(), - typeConverter, &result))) + failed(typeConverter.convertTypes(type.getResults(), newResults))) return failure(); + if (!funcOp.getFunctionBody().empty()) + rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result, + &typeConverter); // Update the function signature in-place. auto newType = FunctionType::get(rewriter.getContext(), diff --git a/mlir/test/Transforms/test-legalizer-no-materializations.mlir b/mlir/test/Transforms/test-legalizer-no-materializations.mlir new file mode 100644 index 0000000000000..82dd7422b22b2 --- /dev/null +++ b/mlir/test/Transforms/test-legalizer-no-materializations.mlir @@ -0,0 +1,67 @@ +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=0 attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND + +// CHECK-LABEL: func @dropped_input_in_use +// CHECK-KIND-LABEL: func @dropped_input_in_use +func.func @dropped_input_in_use(%arg: i16, %arg2: i64) { + // CHECK-NEXT: %[[cast:.*]] = "test.cast"() : () -> i16 + // CHECK-NEXT: "work"(%[[cast]]) : (i16) + // CHECK-KIND-NEXT: %[[cast:.*]] = builtin.unrealized_conversion_cast to i16 {__kind__ = "source"} + // CHECK-KIND-NEXT: "work"(%[[cast]]) : (i16) + // expected-remark@+1 {{op 'work' is not legalizable}} + "work"(%arg) : (i16) -> () +} + +// ----- + +// CHECK-KIND-LABEL: func @test_lookup_without_converter +// CHECK-KIND: %[[producer:.*]] = "test.valid_producer"() : () -> i16 +// CHECK-KIND: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[producer]] : i16 to f64 {__kind__ = "target"} +// CHECK-KIND: "test.valid_consumer"(%[[cast]]) : (f64) -> () +// CHECK-KIND: "test.valid_consumer"(%[[producer]]) : (i16) -> () +func.func @test_lookup_without_converter() { + %0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64) + "test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> () + // Make sure that the second "replace_with_valid_consumer" lowering does not + // lookup the materialization that was created for the above op. + "test.replace_with_valid_consumer"(%0) : (i64) -> () + // expected-remark@+1 {{op 'func.return' is not legalizable}} + return +} + +// ----- + +// CHECK-LABEL: func @remap_moved_region_args +func.func @remap_moved_region_args() { + // CHECK-NEXT: return + // CHECK-NEXT: ^bb1(%[[arg0:.*]]: i64, %[[arg1:.*]]: i16, %[[arg2:.*]]: i64, %[[arg3:.*]]: f32): + // CHECK-NEXT: %[[cast1:.*]]:2 = builtin.unrealized_conversion_cast %[[arg3]] : f32 to f16, f16 + // CHECK-NEXT: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[arg2]] : i64 to f64 + // CHECK-NEXT: %[[cast3:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to f64 + // CHECK-NEXT: %[[cast4:.*]] = "test.cast"(%[[cast1]]#0, %[[cast1]]#1) : (f16, f16) -> f32 + // CHECK-NEXT: "test.valid"(%[[cast3]], %[[cast2]], %[[cast4]]) : (f64, f64, f32) + "test.region"() ({ + ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32): + "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> () + }) : () -> () + // expected-remark@+1 {{op 'func.return' is not legalizable}} + return +} + +// ----- + +// CHECK-LABEL: func @remap_cloned_region_args +func.func @remap_cloned_region_args() { + // CHECK-NEXT: return + // CHECK-NEXT: ^bb1(%[[arg0:.*]]: i64, %[[arg1:.*]]: i16, %[[arg2:.*]]: i64, %[[arg3:.*]]: f32): + // CHECK-NEXT: %[[cast1:.*]]:2 = builtin.unrealized_conversion_cast %[[arg3]] : f32 to f16, f16 + // CHECK-NEXT: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[arg2]] : i64 to f64 + // CHECK-NEXT: %[[cast3:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to f64 + // CHECK-NEXT: %[[cast4:.*]] = "test.cast"(%[[cast1]]#0, %[[cast1]]#1) : (f16, f16) -> f32 + // CHECK-NEXT: "test.valid"(%[[cast3]], %[[cast2]], %[[cast4]]) : (f64, f64, f32) + "test.region"() ({ + ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32): + "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> () + }) {legalizer.should_clone} : () -> () + // expected-remark@+1 {{op 'func.return' is not legalizable}} + return +} diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 94c5bb4e93b06..7c43bb7bface0 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -1,7 +1,6 @@ // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics %s | FileCheck %s // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics -profile-actions-to=- %s | FileCheck %s --check-prefix=CHECK-PROFILER // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s -// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=0 attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND // CHECK-PROFILER: "name": "pass-execution", "cat": "PERF", "ph": "B" // CHECK-PROFILER: "name": "apply-conversion", "cat": "PERF", "ph": "B" @@ -146,36 +145,6 @@ func.func @no_remap_nested() { // ----- -// CHECK-LABEL: func @remap_moved_region_args -func.func @remap_moved_region_args() { - // CHECK-NEXT: return - // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16): - // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32 - // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32) - "test.region"() ({ - ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32): - "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> () - }) : () -> () - // expected-remark@+1 {{op 'func.return' is not legalizable}} - return -} - -// ----- - -// CHECK-LABEL: func @remap_cloned_region_args -func.func @remap_cloned_region_args() { - // CHECK-NEXT: return - // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16): - // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32 - // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32) - "test.region"() ({ - ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32): - "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> () - }) {legalizer.should_clone} : () -> () - // expected-remark@+1 {{op 'func.return' is not legalizable}} - return -} - // CHECK-LABEL: func @remap_drop_region func.func @remap_drop_region() { // CHECK-NEXT: return @@ -191,12 +160,9 @@ func.func @remap_drop_region() { // ----- // CHECK-LABEL: func @dropped_input_in_use -// CHECK-KIND-LABEL: func @dropped_input_in_use func.func @dropped_input_in_use(%arg: i16, %arg2: i64) { // CHECK-NEXT: %[[cast:.*]] = "test.cast"() : () -> i16 // CHECK-NEXT: "work"(%[[cast]]) : (i16) - // CHECK-KIND-NEXT: %[[cast:.*]] = builtin.unrealized_conversion_cast to i16 {__kind__ = "source"} - // CHECK-KIND-NEXT: "work"(%[[cast]]) : (i16) // expected-remark@+1 {{op 'work' is not legalizable}} "work"(%arg) : (i16) -> () } @@ -452,11 +418,6 @@ func.func @test_multiple_1_to_n_replacement() { // CHECK: %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64 // CHECK: "test.valid_consumer"(%[[cast]]) : (f64) -> () // CHECK: "test.valid_consumer"(%[[producer]]) : (i16) -> () -// CHECK-KIND-LABEL: func @test_lookup_without_converter -// CHECK-KIND: %[[producer:.*]] = "test.valid_producer"() : () -> i16 -// CHECK-KIND: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[producer]] : i16 to f64 {__kind__ = "target"} -// CHECK-KIND: "test.valid_consumer"(%[[cast]]) : (f64) -> () -// CHECK-KIND: "test.valid_consumer"(%[[producer]]) : (i16) -> () func.func @test_lookup_without_converter() { %0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64) "test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> () diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index fd2b943ff1296..12edecc113495 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1553,8 +1553,7 @@ struct TestLegalizePatternDriver [](Type type) { return type.isF32(); }); }); target.addDynamicallyLegalOp([&](func::FuncOp op) { - return converter.isSignatureLegal(op.getFunctionType()) && - converter.isLegal(&op.getBody()); + return converter.isSignatureLegal(op.getFunctionType()); }); target.addDynamicallyLegalOp( [&](func::CallOp op) { return converter.isLegal(op); }); @@ -2156,8 +2155,7 @@ struct TestTypeConversionDriver recursiveType.getName() == "outer_converted_type"); }); target.addDynamicallyLegalOp([&](func::FuncOp op) { - return converter.isSignatureLegal(op.getFunctionType()) && - converter.isLegal(&op.getBody()); + return converter.isSignatureLegal(op.getFunctionType()); }); target.addDynamicallyLegalOp([&](TestCastOp op) { // Allow casts from F64 to F32.