@@ -3100,6 +3100,7 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
3100
3100
// ===----------------------------------------------------------------------===//
3101
3101
// OperationConverter
3102
3102
// ===----------------------------------------------------------------------===//
3103
+
3103
3104
namespace {
3104
3105
enum OpConversionMode {
3105
3106
// / In this mode, the conversion will ignore failed conversions to allow
@@ -3117,6 +3118,13 @@ enum OpConversionMode {
3117
3118
} // namespace
3118
3119
3119
3120
namespace mlir {
3121
+
3122
+ // Predeclaration only.
3123
+ static void reconcileUnrealizedCasts (
3124
+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3125
+ &castOps,
3126
+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps);
3127
+
3120
3128
// This class converts operations to a given conversion target via a set of
3121
3129
// rewrite patterns. The conversion behaves differently depending on the
3122
3130
// conversion mode.
@@ -3264,18 +3272,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
3264
3272
// After a successful conversion, apply rewrites.
3265
3273
rewriterImpl.applyRewrites ();
3266
3274
3267
- // Gather all unresolved materializations.
3268
- SmallVector<UnrealizedConversionCastOp> allCastOps;
3269
- const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3270
- &materializations = rewriterImpl.unresolvedMaterializations ;
3271
- for (auto it : materializations)
3272
- allCastOps.push_back (it.first );
3273
-
3274
3275
// Reconcile all UnrealizedConversionCastOps that were inserted by the
3275
- // dialect conversion frameworks. (Not the one that were inserted by
3276
+ // dialect conversion frameworks. (Not the ones that were inserted by
3276
3277
// patterns.)
3278
+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3279
+ &materializations = rewriterImpl.unresolvedMaterializations ;
3277
3280
SmallVector<UnrealizedConversionCastOp> remainingCastOps;
3278
- reconcileUnrealizedCasts (allCastOps , &remainingCastOps);
3281
+ reconcileUnrealizedCasts (materializations , &remainingCastOps);
3279
3282
3280
3283
// Drop markers.
3281
3284
for (UnrealizedConversionCastOp castOp : remainingCastOps)
@@ -3303,20 +3306,19 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
3303
3306
// Reconcile Unrealized Casts
3304
3307
// ===----------------------------------------------------------------------===//
3305
3308
3306
- void mlir::reconcileUnrealizedCasts (
3307
- ArrayRef<UnrealizedConversionCastOp> castOps,
3309
+ // / Try to reconcile all given UnrealizedConversionCastOps and store the
3310
+ // / left-over ops in `remainingCastOps` (if provided). See documentation in
3311
+ // / DialectConversion.h for more details.
3312
+ // / The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
3313
+ // / algorithm may visit an operand (or user) which is a cast op, but will not
3314
+ // / try to reconcile it if not in the filtered set.
3315
+ template <typename RangeT>
3316
+ static void reconcileUnrealizedCastsImpl (
3317
+ RangeT castOps,
3318
+ function_ref<bool (UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3308
3319
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3320
+ // A worklist of cast ops to process.
3309
3321
SetVector<UnrealizedConversionCastOp> worklist (llvm::from_range, castOps);
3310
- // This set is maintained only if `remainingCastOps` is provided.
3311
- DenseSet<Operation *> erasedOps;
3312
-
3313
- // Helper function that adds all operands to the worklist that are an
3314
- // unrealized_conversion_cast op result.
3315
- auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
3316
- for (Value v : castOp.getInputs ())
3317
- if (auto inputCastOp = v.getDefiningOp <UnrealizedConversionCastOp>())
3318
- worklist.insert (inputCastOp);
3319
- };
3320
3322
3321
3323
// Helper function that return the unrealized_conversion_cast op that
3322
3324
// defines all inputs of the given op (in the same order). Return "nullptr"
@@ -3337,39 +3339,110 @@ void mlir::reconcileUnrealizedCasts(
3337
3339
// Process ops in the worklist bottom-to-top.
3338
3340
while (!worklist.empty ()) {
3339
3341
UnrealizedConversionCastOp castOp = worklist.pop_back_val ();
3340
- if (castOp->use_empty ()) {
3341
- // DCE: If the op has no users, erase it. Add the operands to the
3342
- // worklist to find additional DCE opportunities.
3343
- enqueueOperands (castOp);
3344
- if (remainingCastOps)
3345
- erasedOps.insert (castOp.getOperation ());
3346
- castOp->erase ();
3347
- continue ;
3348
- }
3349
3342
3350
3343
// Traverse the chain of input cast ops to see if an op with the same
3351
3344
// input types can be found.
3352
3345
UnrealizedConversionCastOp nextCast = castOp;
3353
3346
while (nextCast) {
3354
3347
if (nextCast.getInputs ().getTypes () == castOp.getResultTypes ()) {
3348
+ if (llvm::any_of (nextCast.getInputs (), [&](Value v) {
3349
+ return v.getDefiningOp () == castOp;
3350
+ })) {
3351
+ // Ran into a cycle.
3352
+ break ;
3353
+ }
3354
+
3355
3355
// Found a cast where the input types match the output types of the
3356
- // matched op. We can directly use those inputs and the matched op can
3357
- // be removed.
3358
- enqueueOperands (castOp);
3356
+ // matched op. We can directly use those inputs.
3359
3357
castOp.replaceAllUsesWith (nextCast.getInputs ());
3360
- if (remainingCastOps)
3361
- erasedOps.insert (castOp.getOperation ());
3362
- castOp->erase ();
3363
3358
break ;
3364
3359
}
3365
3360
nextCast = getInputCast (nextCast);
3366
3361
}
3367
3362
}
3368
3363
3369
- if (remainingCastOps)
3370
- for (UnrealizedConversionCastOp op : castOps)
3371
- if (!erasedOps.contains (op.getOperation ()))
3364
+ // A set of all alive cast ops. I.e., ops whose results are (transitively)
3365
+ // used by an op that is not a cast op.
3366
+ DenseSet<Operation *> liveOps;
3367
+
3368
+ // Helper function that marks the given op and transitively reachable input
3369
+ // cast ops as alive.
3370
+ auto markOpLive = [&](Operation *rootOp) {
3371
+ SmallVector<Operation *> worklist;
3372
+ worklist.push_back (rootOp);
3373
+ while (!worklist.empty ()) {
3374
+ Operation *op = worklist.pop_back_val ();
3375
+ if (liveOps.insert (op).second ) {
3376
+ // Successfully inserted: process reachable input cast ops.
3377
+ for (Value v : op->getOperands ())
3378
+ if (auto castOp = v.getDefiningOp <UnrealizedConversionCastOp>())
3379
+ if (isCastOpOfInterestFn (castOp))
3380
+ worklist.push_back (castOp);
3381
+ }
3382
+ }
3383
+ };
3384
+
3385
+ // Find all alive cast ops.
3386
+ for (UnrealizedConversionCastOp op : castOps) {
3387
+ // The op may have been marked live already as being an operand of another
3388
+ // live cast op.
3389
+ if (liveOps.contains (op.getOperation ()))
3390
+ continue ;
3391
+ // If any of the users is not a cast op, mark the current op (and its
3392
+ // input ops) as live.
3393
+ if (llvm::any_of (op->getUsers (), [&](Operation *user) {
3394
+ auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3395
+ return !castOp || !isCastOpOfInterestFn (castOp);
3396
+ }))
3397
+ markOpLive (op);
3398
+ }
3399
+
3400
+ // Erase all dead cast ops.
3401
+ for (UnrealizedConversionCastOp op : castOps) {
3402
+ if (liveOps.contains (op)) {
3403
+ // Op is alive and was not erased. Add it to the remaining cast ops.
3404
+ if (remainingCastOps)
3372
3405
remainingCastOps->push_back (op);
3406
+ continue ;
3407
+ }
3408
+
3409
+ // Op is dead. Erase it.
3410
+ op->dropAllUses ();
3411
+ op->erase ();
3412
+ }
3413
+ }
3414
+
3415
+ void mlir::reconcileUnrealizedCasts (
3416
+ ArrayRef<UnrealizedConversionCastOp> castOps,
3417
+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3418
+ // Set of all cast ops for faster lookups.
3419
+ DenseSet<UnrealizedConversionCastOp> castOpSet;
3420
+ for (UnrealizedConversionCastOp op : castOps)
3421
+ castOpSet.insert (op);
3422
+ reconcileUnrealizedCasts (castOpSet, remainingCastOps);
3423
+ }
3424
+
3425
+ void mlir::reconcileUnrealizedCasts (
3426
+ const DenseSet<UnrealizedConversionCastOp> &castOps,
3427
+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3428
+ reconcileUnrealizedCastsImpl (
3429
+ llvm::make_range (castOps.begin (), castOps.end ()),
3430
+ [&](UnrealizedConversionCastOp castOp) {
3431
+ return castOps.contains (castOp);
3432
+ },
3433
+ remainingCastOps);
3434
+ }
3435
+
3436
+ static void mlir::reconcileUnrealizedCasts (
3437
+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3438
+ &castOps,
3439
+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3440
+ reconcileUnrealizedCastsImpl (
3441
+ castOps.keys (),
3442
+ [&](UnrealizedConversionCastOp castOp) {
3443
+ return castOps.contains (castOp);
3444
+ },
3445
+ remainingCastOps);
3373
3446
}
3374
3447
3375
3448
// ===----------------------------------------------------------------------===//
0 commit comments