From db67d6aa9438f09f72c8574264e1d6c98bfc2d5f Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 12 Sep 2025 14:50:53 +0100 Subject: [PATCH] Revert "[mlir][Transforms] Fix crash in `reconcile-unrealized-casts` (#158067)" This reverts commit 03e3ce82b926a4c138e6e0bacfcd1d5572c3e380. --- .../mlir/Transforms/DialectConversion.h | 3 - .../Transforms/Utils/DialectConversion.cpp | 151 +++++------------- .../reconcile-unrealized-casts.mlir | 50 ------ ...assume-alignment-runtime-verification.mlir | 3 +- .../atomic-rmw-runtime-verification.mlir | 3 +- .../MemRef/store-runtime-verification.mlir | 3 +- 6 files changed, 42 insertions(+), 171 deletions(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index f8caae3ce9995..a096f82a4cfd8 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -1428,9 +1428,6 @@ struct ConversionConfig { /// /// In the above example, %0 can be used instead of %3 and all cast ops are /// folded away. -void reconcileUnrealizedCasts( - const DenseSet &castOps, - SmallVectorImpl *remainingCastOps = nullptr); void reconcileUnrealizedCasts( ArrayRef castOps, SmallVectorImpl *remainingCastOps = nullptr); diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index d53e1e78f2027..df9700f11200f 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -3100,7 +3100,6 @@ unsigned OperationLegalizer::applyCostModelToPatterns( //===----------------------------------------------------------------------===// // OperationConverter //===----------------------------------------------------------------------===// - namespace { enum OpConversionMode { /// In this mode, the conversion will ignore failed conversions to allow @@ -3118,13 +3117,6 @@ enum OpConversionMode { } // namespace namespace mlir { - -// Predeclaration only. -static void reconcileUnrealizedCasts( - const DenseMap - &castOps, - SmallVectorImpl *remainingCastOps); - // This class converts operations to a given conversion target via a set of // rewrite patterns. The conversion behaves differently depending on the // conversion mode. @@ -3272,13 +3264,18 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { // After a successful conversion, apply rewrites. rewriterImpl.applyRewrites(); - // Reconcile all UnrealizedConversionCastOps that were inserted by the - // dialect conversion frameworks. (Not the ones that were inserted by - // patterns.) + // Gather all unresolved materializations. + SmallVector allCastOps; const DenseMap &materializations = rewriterImpl.unresolvedMaterializations; + for (auto it : materializations) + allCastOps.push_back(it.first); + + // Reconcile all UnrealizedConversionCastOps that were inserted by the + // dialect conversion frameworks. (Not the one that were inserted by + // patterns.) SmallVector remainingCastOps; - reconcileUnrealizedCasts(materializations, &remainingCastOps); + reconcileUnrealizedCasts(allCastOps, &remainingCastOps); // Drop markers. for (UnrealizedConversionCastOp castOp : remainingCastOps) @@ -3306,19 +3303,20 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { // Reconcile Unrealized Casts //===----------------------------------------------------------------------===// -/// Try to reconcile all given UnrealizedConversionCastOps and store the -/// left-over ops in `remainingCastOps` (if provided). See documentation in -/// DialectConversion.h for more details. -/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the -/// algorithm may visit an operand (or user) which is a cast op, but will not -/// try to reconcile it if not in the filtered set. -template -static void reconcileUnrealizedCastsImpl( - RangeT castOps, - function_ref isCastOpOfInterestFn, +void mlir::reconcileUnrealizedCasts( + ArrayRef castOps, SmallVectorImpl *remainingCastOps) { - // A worklist of cast ops to process. SetVector worklist(llvm::from_range, castOps); + // This set is maintained only if `remainingCastOps` is provided. + DenseSet erasedOps; + + // Helper function that adds all operands to the worklist that are an + // unrealized_conversion_cast op result. + auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) { + for (Value v : castOp.getInputs()) + if (auto inputCastOp = v.getDefiningOp()) + worklist.insert(inputCastOp); + }; // Helper function that return the unrealized_conversion_cast op that // defines all inputs of the given op (in the same order). Return "nullptr" @@ -3339,110 +3337,39 @@ static void reconcileUnrealizedCastsImpl( // Process ops in the worklist bottom-to-top. while (!worklist.empty()) { UnrealizedConversionCastOp castOp = worklist.pop_back_val(); + if (castOp->use_empty()) { + // DCE: If the op has no users, erase it. Add the operands to the + // worklist to find additional DCE opportunities. + enqueueOperands(castOp); + if (remainingCastOps) + erasedOps.insert(castOp.getOperation()); + castOp->erase(); + continue; + } // Traverse the chain of input cast ops to see if an op with the same // input types can be found. UnrealizedConversionCastOp nextCast = castOp; while (nextCast) { if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) { - if (llvm::any_of(nextCast.getInputs(), [&](Value v) { - return v.getDefiningOp() == castOp; - })) { - // Ran into a cycle. - break; - } - // Found a cast where the input types match the output types of the - // matched op. We can directly use those inputs. + // matched op. We can directly use those inputs and the matched op can + // be removed. + enqueueOperands(castOp); castOp.replaceAllUsesWith(nextCast.getInputs()); + if (remainingCastOps) + erasedOps.insert(castOp.getOperation()); + castOp->erase(); break; } nextCast = getInputCast(nextCast); } } - // A set of all alive cast ops. I.e., ops whose results are (transitively) - // used by an op that is not a cast op. - DenseSet liveOps; - - // Helper function that marks the given op and transitively reachable input - // cast ops as alive. - auto markOpLive = [&](Operation *rootOp) { - SmallVector worklist; - worklist.push_back(rootOp); - while (!worklist.empty()) { - Operation *op = worklist.pop_back_val(); - if (liveOps.insert(op).second) { - // Successfully inserted: process reachable input cast ops. - for (Value v : op->getOperands()) - if (auto castOp = v.getDefiningOp()) - if (isCastOpOfInterestFn(castOp)) - worklist.push_back(castOp); - } - } - }; - - // Find all alive cast ops. - for (UnrealizedConversionCastOp op : castOps) { - // The op may have been marked live already as being an operand of another - // live cast op. - if (liveOps.contains(op.getOperation())) - continue; - // If any of the users is not a cast op, mark the current op (and its - // input ops) as live. - if (llvm::any_of(op->getUsers(), [&](Operation *user) { - auto castOp = dyn_cast(user); - return !castOp || !isCastOpOfInterestFn(castOp); - })) - markOpLive(op); - } - - // Erase all dead cast ops. - for (UnrealizedConversionCastOp op : castOps) { - if (liveOps.contains(op)) { - // Op is alive and was not erased. Add it to the remaining cast ops. - if (remainingCastOps) + if (remainingCastOps) + for (UnrealizedConversionCastOp op : castOps) + if (!erasedOps.contains(op.getOperation())) remainingCastOps->push_back(op); - continue; - } - - // Op is dead. Erase it. - op->dropAllUses(); - op->erase(); - } -} - -void mlir::reconcileUnrealizedCasts( - ArrayRef castOps, - SmallVectorImpl *remainingCastOps) { - // Set of all cast ops for faster lookups. - DenseSet castOpSet; - for (UnrealizedConversionCastOp op : castOps) - castOpSet.insert(op); - reconcileUnrealizedCasts(castOpSet, remainingCastOps); -} - -void mlir::reconcileUnrealizedCasts( - const DenseSet &castOps, - SmallVectorImpl *remainingCastOps) { - reconcileUnrealizedCastsImpl( - llvm::make_range(castOps.begin(), castOps.end()), - [&](UnrealizedConversionCastOp castOp) { - return castOps.contains(castOp); - }, - remainingCastOps); -} - -static void mlir::reconcileUnrealizedCasts( - const DenseMap - &castOps, - SmallVectorImpl *remainingCastOps) { - reconcileUnrealizedCastsImpl( - castOps.keys(), - [&](UnrealizedConversionCastOp castOp) { - return castOps.contains(castOp); - }, - remainingCastOps); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir index ac5ca321c066f..3573114f5e038 100644 --- a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir +++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir @@ -194,53 +194,3 @@ func.func @emptyCast() -> index { %0 = builtin.unrealized_conversion_cast to index return %0 : index } - -// ----- - -// CHECK-LABEL: test.graph_region -// CHECK-NEXT: "test.return"() : () -> () -test.graph_region { - %0 = builtin.unrealized_conversion_cast %2 : i32 to i64 - %1 = builtin.unrealized_conversion_cast %0 : i64 to i16 - %2 = builtin.unrealized_conversion_cast %1 : i16 to i32 - "test.return"() : () -> () -} - -// ----- - -// CHECK-LABEL: test.graph_region -// CHECK-NEXT: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[cast2:.*]] : i32 to i64 -// CHECK-NEXT: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i64 to i16 -// CHECK-NEXT: %[[cast2]] = builtin.unrealized_conversion_cast %[[cast1]] : i16 to i32 -// CHECK-NEXT: "test.user"(%[[cast2]]) : (i32) -> () -// CHECK-NEXT: "test.return"() : () -> () -test.graph_region { - %0 = builtin.unrealized_conversion_cast %2 : i32 to i64 - %1 = builtin.unrealized_conversion_cast %0 : i64 to i16 - %2 = builtin.unrealized_conversion_cast %1 : i16 to i32 - "test.user"(%2) : (i32) -> () - "test.return"() : () -> () -} - -// ----- - -// CHECK-LABEL: test.graph_region -// CHECK-NEXT: "test.return"() : () -> () -test.graph_region { - %0 = builtin.unrealized_conversion_cast %0 : i32 to i32 - "test.return"() : () -> () -} - -// ----- - -// CHECK-LABEL: test.graph_region -// CHECK-NEXT: %[[c0:.*]] = arith.constant -// CHECK-NEXT: %[[cast:.*]]:2 = builtin.unrealized_conversion_cast %[[c0]], %[[cast]]#1 : i32, i32 to i32, i32 -// CHECK-NEXT: "test.user"(%[[cast]]#0) : (i32) -> () -// CHECK-NEXT: "test.return"() : () -> () -test.graph_region { - %cst = arith.constant 0 : i32 - %0, %1 = builtin.unrealized_conversion_cast %cst, %1 : i32, i32 to i32, i32 - "test.user"(%0) : (i32) -> () - "test.return"() : () -> () -} diff --git a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir index 01a826a638606..25a338df8d790 100644 --- a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir @@ -1,8 +1,7 @@ // RUN: mlir-opt %s -generate-runtime-verification \ // RUN: -expand-strided-metadata \ // RUN: -test-cf-assert \ -// RUN: -convert-to-llvm \ -// RUN: -reconcile-unrealized-casts | \ +// RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s diff --git a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir index 1144a7caf36e8..4c6a48d577a6c 100644 --- a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir @@ -1,7 +1,6 @@ // RUN: mlir-opt %s -generate-runtime-verification \ // RUN: -test-cf-assert \ -// RUN: -convert-to-llvm \ -// RUN: -reconcile-unrealized-casts | \ +// RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s diff --git a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir index 82e63805cd027..dd000c6904bcb 100644 --- a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir @@ -1,7 +1,6 @@ // RUN: mlir-opt %s -generate-runtime-verification \ // RUN: -test-cf-assert \ -// RUN: -convert-to-llvm \ -// RUN: -reconcile-unrealized-casts | \ +// RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s