Skip to content

Commit da82d72

Browse files
[mlir][Transforms] Fix crash in reconcile-unrealized-casts (#158298)
The `reconcile-unrealized-casts` pass used to crash when the input contains circular chains of `unrealized_conversion_cast` ops.

Furthermore, the `reconcileUnrealizedCasts` helper functions used to erase ops that were not passed via the `castOps` operand. Such ops are now preserved. That's why some integration tests had to be changed.

Also avoid copying the set of all unresolved materializations in `convertOperations`.

This commit is in preparation for turning `RewriterBase::replaceOp` into a non-virtual function.

This is a re-upload of #158067, which was reverted due to CI failures.

Note for LLVM integration: If you are seeing tests that are failing with `error: LLVM Translation failed for operation: builtin.unrealized_conversion_cast`, you may have to add the `-reconcile-unrealized-casts` pass to your pass pipeline. (Or switch to the `-convert-to-llvm` pass instead of combining the various `-convert-*-to-llvm` passes.)

---------

Co-authored-by: Mehdi Amini <[email protected]>
1 parent acd0899 commit da82d72

File tree

7 files changed

+173
-42
lines changed

7 files changed

+173
-42
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,6 +1428,9 @@ struct ConversionConfig {
14281428
///
14291429
/// In the above example, %0 can be used instead of %3 and all cast ops are
14301430
/// folded away.
1431+
void reconcileUnrealizedCasts(
1432+
const DenseSet<UnrealizedConversionCastOp> &castOps,
1433+
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
14311434
void reconcileUnrealizedCasts(
14321435
ArrayRef<UnrealizedConversionCastOp> castOps,
14331436
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 112 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3100,6 +3100,7 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
31003100
//===----------------------------------------------------------------------===//
31013101
// OperationConverter
31023102
//===----------------------------------------------------------------------===//
3103+
31033104
namespace {
31043105
enum OpConversionMode {
31053106
/// In this mode, the conversion will ignore failed conversions to allow
@@ -3117,6 +3118,13 @@ enum OpConversionMode {
31173118
} // namespace
31183119

31193120
namespace mlir {
3121+
3122+
// Forward declaration only.
3123+
static void reconcileUnrealizedCasts(
3124+
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3125+
&castOps,
3126+
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps);
3127+
31203128
// This class converts operations to a given conversion target via a set of
31213129
// rewrite patterns. The conversion behaves differently depending on the
31223130
// conversion mode.
@@ -3264,18 +3272,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
32643272
// After a successful conversion, apply rewrites.
32653273
rewriterImpl.applyRewrites();
32663274

3267-
// Gather all unresolved materializations.
3268-
SmallVector<UnrealizedConversionCastOp> allCastOps;
3269-
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3270-
&materializations = rewriterImpl.unresolvedMaterializations;
3271-
for (auto it : materializations)
3272-
allCastOps.push_back(it.first);
3273-
32743275
// Reconcile all UnrealizedConversionCastOps that were inserted by the
3275-
// dialect conversion frameworks. (Not the one that were inserted by
3276+
// dialect conversion framework. (Not the ones that were inserted by
32763277
// patterns.)
3278+
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3279+
&materializations = rewriterImpl.unresolvedMaterializations;
32773280
SmallVector<UnrealizedConversionCastOp> remainingCastOps;
3278-
reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
3281+
reconcileUnrealizedCasts(materializations, &remainingCastOps);
32793282

32803283
// Drop markers.
32813284
for (UnrealizedConversionCastOp castOp : remainingCastOps)
@@ -3303,20 +3306,19 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
33033306
// Reconcile Unrealized Casts
33043307
//===----------------------------------------------------------------------===//
33053308

3306-
void mlir::reconcileUnrealizedCasts(
3307-
ArrayRef<UnrealizedConversionCastOp> castOps,
3309+
/// Try to reconcile all given UnrealizedConversionCastOps and store the
3310+
/// left-over ops in `remainingCastOps` (if provided). See documentation in
3311+
/// DialectConversion.h for more details.
3312+
/// The `isCastOpOfInterestFn` is used to filter the cast ops to process: the
3313+
/// algorithm may visit an operand (or user) which is a cast op, but will not
3314+
/// try to reconcile it if not in the filtered set.
3315+
template <typename RangeT>
3316+
static void reconcileUnrealizedCastsImpl(
3317+
RangeT castOps,
3318+
function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
33083319
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3320+
// A worklist of cast ops to process.
33093321
SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
3310-
// This set is maintained only if `remainingCastOps` is provided.
3311-
DenseSet<Operation *> erasedOps;
3312-
3313-
// Helper function that adds all operands to the worklist that are an
3314-
// unrealized_conversion_cast op result.
3315-
auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
3316-
for (Value v : castOp.getInputs())
3317-
if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
3318-
worklist.insert(inputCastOp);
3319-
};
33203322

33213323
// Helper function that returns the unrealized_conversion_cast op that
33223324
// defines all inputs of the given op (in the same order). Return "nullptr"
@@ -3337,39 +3339,110 @@ void mlir::reconcileUnrealizedCasts(
33373339
// Process ops in the worklist bottom-to-top.
33383340
while (!worklist.empty()) {
33393341
UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3340-
if (castOp->use_empty()) {
3341-
// DCE: If the op has no users, erase it. Add the operands to the
3342-
// worklist to find additional DCE opportunities.
3343-
enqueueOperands(castOp);
3344-
if (remainingCastOps)
3345-
erasedOps.insert(castOp.getOperation());
3346-
castOp->erase();
3347-
continue;
3348-
}
33493342

33503343
// Traverse the chain of input cast ops to see if an op with the same
33513344
// input types can be found.
33523345
UnrealizedConversionCastOp nextCast = castOp;
33533346
while (nextCast) {
33543347
if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3348+
if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
3349+
return v.getDefiningOp() == castOp;
3350+
})) {
3351+
// Ran into a cycle.
3352+
break;
3353+
}
3354+
33553355
// Found a cast where the input types match the output types of the
3356-
// matched op. We can directly use those inputs and the matched op can
3357-
// be removed.
3358-
enqueueOperands(castOp);
3356+
// matched op. We can directly use those inputs.
33593357
castOp.replaceAllUsesWith(nextCast.getInputs());
3360-
if (remainingCastOps)
3361-
erasedOps.insert(castOp.getOperation());
3362-
castOp->erase();
33633358
break;
33643359
}
33653360
nextCast = getInputCast(nextCast);
33663361
}
33673362
}
33683363

3369-
if (remainingCastOps)
3370-
for (UnrealizedConversionCastOp op : castOps)
3371-
if (!erasedOps.contains(op.getOperation()))
3364+
// A set of all alive cast ops. I.e., ops whose results are (transitively)
3365+
// used by an op that is not a cast op.
3366+
DenseSet<Operation *> liveOps;
3367+
3368+
// Helper function that marks the given op and transitively reachable input
3369+
// cast ops as alive.
3370+
auto markOpLive = [&](Operation *rootOp) {
3371+
SmallVector<Operation *> worklist;
3372+
worklist.push_back(rootOp);
3373+
while (!worklist.empty()) {
3374+
Operation *op = worklist.pop_back_val();
3375+
if (liveOps.insert(op).second) {
3376+
// Successfully inserted: process reachable input cast ops.
3377+
for (Value v : op->getOperands())
3378+
if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
3379+
if (isCastOpOfInterestFn(castOp))
3380+
worklist.push_back(castOp);
3381+
}
3382+
}
3383+
};
3384+
3385+
// Find all alive cast ops.
3386+
for (UnrealizedConversionCastOp op : castOps) {
3387+
// The op may have been marked live already as being an operand of another
3388+
// live cast op.
3389+
if (liveOps.contains(op.getOperation()))
3390+
continue;
3391+
// If any of the users is not a cast op, mark the current op (and its
3392+
// input ops) as live.
3393+
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
3394+
auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3395+
return !castOp || !isCastOpOfInterestFn(castOp);
3396+
}))
3397+
markOpLive(op);
3398+
}
3399+
3400+
// Erase all dead cast ops.
3401+
for (UnrealizedConversionCastOp op : castOps) {
3402+
if (liveOps.contains(op)) {
3403+
// Op is alive and was not erased. Add it to the remaining cast ops.
3404+
if (remainingCastOps)
33723405
remainingCastOps->push_back(op);
3406+
continue;
3407+
}
3408+
3409+
// Op is dead. Erase it.
3410+
op->dropAllUses();
3411+
op->erase();
3412+
}
3413+
}
3414+
3415+
void mlir::reconcileUnrealizedCasts(
3416+
ArrayRef<UnrealizedConversionCastOp> castOps,
3417+
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3418+
// Set of all cast ops for faster lookups.
3419+
DenseSet<UnrealizedConversionCastOp> castOpSet;
3420+
for (UnrealizedConversionCastOp op : castOps)
3421+
castOpSet.insert(op);
3422+
reconcileUnrealizedCasts(castOpSet, remainingCastOps);
3423+
}
3424+
3425+
void mlir::reconcileUnrealizedCasts(
3426+
const DenseSet<UnrealizedConversionCastOp> &castOps,
3427+
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3428+
reconcileUnrealizedCastsImpl(
3429+
llvm::make_range(castOps.begin(), castOps.end()),
3430+
[&](UnrealizedConversionCastOp castOp) {
3431+
return castOps.contains(castOp);
3432+
},
3433+
remainingCastOps);
3434+
}
3435+
3436+
static void mlir::reconcileUnrealizedCasts(
3437+
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3438+
&castOps,
3439+
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3440+
reconcileUnrealizedCastsImpl(
3441+
castOps.keys(),
3442+
[&](UnrealizedConversionCastOp castOp) {
3443+
return castOps.contains(castOp);
3444+
},
3445+
remainingCastOps);
33733446
}
33743447

33753448
//===----------------------------------------------------------------------===//

mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,53 @@ func.func @emptyCast() -> index {
194194
%0 = builtin.unrealized_conversion_cast to index
195195
return %0 : index
196196
}
197+
198+
// -----
199+
200+
// CHECK-LABEL: test.graph_region
201+
// CHECK-NEXT: "test.return"() : () -> ()
202+
test.graph_region {
203+
%0 = builtin.unrealized_conversion_cast %2 : i32 to i64
204+
%1 = builtin.unrealized_conversion_cast %0 : i64 to i16
205+
%2 = builtin.unrealized_conversion_cast %1 : i16 to i32
206+
"test.return"() : () -> ()
207+
}
208+
209+
// -----
210+
211+
// CHECK-LABEL: test.graph_region
212+
// CHECK-NEXT: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[cast2:.*]] : i32 to i64
213+
// CHECK-NEXT: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i64 to i16
214+
// CHECK-NEXT: %[[cast2]] = builtin.unrealized_conversion_cast %[[cast1]] : i16 to i32
215+
// CHECK-NEXT: "test.user"(%[[cast2]]) : (i32) -> ()
216+
// CHECK-NEXT: "test.return"() : () -> ()
217+
test.graph_region {
218+
%0 = builtin.unrealized_conversion_cast %2 : i32 to i64
219+
%1 = builtin.unrealized_conversion_cast %0 : i64 to i16
220+
%2 = builtin.unrealized_conversion_cast %1 : i16 to i32
221+
"test.user"(%2) : (i32) -> ()
222+
"test.return"() : () -> ()
223+
}
224+
225+
// -----
226+
227+
// CHECK-LABEL: test.graph_region
228+
// CHECK-NEXT: "test.return"() : () -> ()
229+
test.graph_region {
230+
%0 = builtin.unrealized_conversion_cast %0 : i32 to i32
231+
"test.return"() : () -> ()
232+
}
233+
234+
// -----
235+
236+
// CHECK-LABEL: test.graph_region
237+
// CHECK-NEXT: %[[c0:.*]] = arith.constant
238+
// CHECK-NEXT: %[[cast:.*]]:2 = builtin.unrealized_conversion_cast %[[c0]], %[[cast]]#1 : i32, i32 to i32, i32
239+
// CHECK-NEXT: "test.user"(%[[cast]]#0) : (i32) -> ()
240+
// CHECK-NEXT: "test.return"() : () -> ()
241+
test.graph_region {
242+
%cst = arith.constant 0 : i32
243+
%0, %1 = builtin.unrealized_conversion_cast %cst, %1 : i32, i32 to i32, i32
244+
"test.user"(%0) : (i32) -> ()
245+
"test.return"() : () -> ()
246+
}

mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
// RUN: mlir-opt %s -generate-runtime-verification \
22
// RUN: -expand-strided-metadata \
33
// RUN: -test-cf-assert \
4-
// RUN: -convert-to-llvm | \
4+
// RUN: -convert-to-llvm \
5+
// RUN: -reconcile-unrealized-casts | \
56
// RUN: mlir-runner -e main -entry-point-result=void \
67
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
78
// RUN: FileCheck %s

mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: mlir-opt %s -generate-runtime-verification \
22
// RUN: -test-cf-assert \
3-
// RUN: -convert-to-llvm | \
3+
// RUN: -convert-to-llvm \
4+
// RUN: -reconcile-unrealized-casts | \
45
// RUN: mlir-runner -e main -entry-point-result=void \
56
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
67
// RUN: FileCheck %s

mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: mlir-opt %s -generate-runtime-verification \
22
// RUN: -test-cf-assert \
3-
// RUN: -convert-to-llvm | \
3+
// RUN: -convert-to-llvm \
4+
// RUN: -reconcile-unrealized-casts | \
45
// RUN: mlir-runner -e main -entry-point-result=void \
56
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
67
// RUN: FileCheck %s

mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
1414
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
1515
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
16+
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
1617
#include "mlir/Dialect/Func/IR/FuncOps.h"
1718
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1819
#include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -73,6 +74,7 @@ void buildTestVulkanRunnerPipeline(OpPassManager &passManager,
7374
opt.kernelBarePtrCallConv = true;
7475
opt.kernelIntersperseSizeCallConv = true;
7576
passManager.addPass(createGpuToLLVMConversionPass(opt));
77+
passManager.addPass(createReconcileUnrealizedCastsPass());
7678
}
7779

7880
} // namespace

0 commit comments

Comments
 (0)