@@ -3100,6 +3100,7 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
31003100// ===----------------------------------------------------------------------===//
31013101// OperationConverter
31023102// ===----------------------------------------------------------------------===//
3103+
31033104namespace {
31043105enum OpConversionMode {
31053106 // / In this mode, the conversion will ignore failed conversions to allow
@@ -3117,6 +3118,13 @@ enum OpConversionMode {
31173118} // namespace
31183119
31193120namespace mlir {
3121+
3122+ // Predeclaration only.
3123+ static void reconcileUnrealizedCasts (
3124+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3125+ &castOps,
3126+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps);
3127+
31203128// This class converts operations to a given conversion target via a set of
31213129// rewrite patterns. The conversion behaves differently depending on the
31223130// conversion mode.
@@ -3264,18 +3272,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
32643272 // After a successful conversion, apply rewrites.
32653273 rewriterImpl.applyRewrites ();
32663274
3267- // Gather all unresolved materializations.
3268- SmallVector<UnrealizedConversionCastOp> allCastOps;
3269- const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3270- &materializations = rewriterImpl.unresolvedMaterializations ;
3271- for (auto it : materializations)
3272- allCastOps.push_back (it.first );
3273-
32743275 // Reconcile all UnrealizedConversionCastOps that were inserted by the
3275- // dialect conversion frameworks. (Not the one that were inserted by
3276+ // dialect conversion frameworks. (Not the ones that were inserted by
32763277 // patterns.)
3278+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3279+ &materializations = rewriterImpl.unresolvedMaterializations ;
32773280 SmallVector<UnrealizedConversionCastOp> remainingCastOps;
3278- reconcileUnrealizedCasts (allCastOps , &remainingCastOps);
3281+ reconcileUnrealizedCasts (materializations , &remainingCastOps);
32793282
32803283 // Drop markers.
32813284 for (UnrealizedConversionCastOp castOp : remainingCastOps)
@@ -3303,20 +3306,19 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
33033306// Reconcile Unrealized Casts
33043307// ===----------------------------------------------------------------------===//
33053308
3306- void mlir::reconcileUnrealizedCasts (
3307- ArrayRef<UnrealizedConversionCastOp> castOps,
3309+ // / Try to reconcile all given UnrealizedConversionCastOps and store the
3310+ // / left-over ops in `remainingCastOps` (if provided). See documentation in
3311+ // / DialectConversion.h for more details.
3312+ // / The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
3313+ // / algorithm may visit an operand (or user) which is a cast op, but will not
3314+ // / try to reconcile it if not in the filtered set.
3315+ template <typename RangeT>
3316+ static void reconcileUnrealizedCastsImpl (
3317+ RangeT castOps,
3318+ function_ref<bool (UnrealizedConversionCastOp)> isCastOpOfInterestFn,
33083319 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3320+ // A worklist of cast ops to process.
33093321 SetVector<UnrealizedConversionCastOp> worklist (llvm::from_range, castOps);
3310- // This set is maintained only if `remainingCastOps` is provided.
3311- DenseSet<Operation *> erasedOps;
3312-
3313- // Helper function that adds all operands to the worklist that are an
3314- // unrealized_conversion_cast op result.
3315- auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
3316- for (Value v : castOp.getInputs ())
3317- if (auto inputCastOp = v.getDefiningOp <UnrealizedConversionCastOp>())
3318- worklist.insert (inputCastOp);
3319- };
33203322
33213323 // Helper function that return the unrealized_conversion_cast op that
33223324 // defines all inputs of the given op (in the same order). Return "nullptr"
@@ -3337,39 +3339,110 @@ void mlir::reconcileUnrealizedCasts(
33373339 // Process ops in the worklist bottom-to-top.
33383340 while (!worklist.empty ()) {
33393341 UnrealizedConversionCastOp castOp = worklist.pop_back_val ();
3340- if (castOp->use_empty ()) {
3341- // DCE: If the op has no users, erase it. Add the operands to the
3342- // worklist to find additional DCE opportunities.
3343- enqueueOperands (castOp);
3344- if (remainingCastOps)
3345- erasedOps.insert (castOp.getOperation ());
3346- castOp->erase ();
3347- continue ;
3348- }
33493342
33503343 // Traverse the chain of input cast ops to see if an op with the same
33513344 // input types can be found.
33523345 UnrealizedConversionCastOp nextCast = castOp;
33533346 while (nextCast) {
33543347 if (nextCast.getInputs ().getTypes () == castOp.getResultTypes ()) {
3348+ if (llvm::any_of (nextCast.getInputs (), [&](Value v) {
3349+ return v.getDefiningOp () == castOp;
3350+ })) {
3351+ // Ran into a cycle.
3352+ break ;
3353+ }
3354+
33553355 // Found a cast where the input types match the output types of the
3356- // matched op. We can directly use those inputs and the matched op can
3357- // be removed.
3358- enqueueOperands (castOp);
3356+ // matched op. We can directly use those inputs.
33593357 castOp.replaceAllUsesWith (nextCast.getInputs ());
3360- if (remainingCastOps)
3361- erasedOps.insert (castOp.getOperation ());
3362- castOp->erase ();
33633358 break ;
33643359 }
33653360 nextCast = getInputCast (nextCast);
33663361 }
33673362 }
33683363
3369- if (remainingCastOps)
3370- for (UnrealizedConversionCastOp op : castOps)
3371- if (!erasedOps.contains (op.getOperation ()))
3364+ // A set of all alive cast ops. I.e., ops whose results are (transitively)
3365+ // used by an op that is not a cast op.
3366+ DenseSet<Operation *> liveOps;
3367+
3368+ // Helper function that marks the given op and transitively reachable input
3369+ // cast ops as alive.
3370+ auto markOpLive = [&](Operation *rootOp) {
3371+ SmallVector<Operation *> worklist;
3372+ worklist.push_back (rootOp);
3373+ while (!worklist.empty ()) {
3374+ Operation *op = worklist.pop_back_val ();
3375+ if (liveOps.insert (op).second ) {
3376+ // Successfully inserted: process reachable input cast ops.
3377+ for (Value v : op->getOperands ())
3378+ if (auto castOp = v.getDefiningOp <UnrealizedConversionCastOp>())
3379+ if (isCastOpOfInterestFn (castOp))
3380+ worklist.push_back (castOp);
3381+ }
3382+ }
3383+ };
3384+
3385+ // Find all alive cast ops.
3386+ for (UnrealizedConversionCastOp op : castOps) {
3387+ // The op may have been marked live already as being an operand of another
3388+ // live cast op.
3389+ if (liveOps.contains (op.getOperation ()))
3390+ continue ;
3391+ // If any of the users is not a cast op, mark the current op (and its
3392+ // input ops) as live.
3393+ if (llvm::any_of (op->getUsers (), [&](Operation *user) {
3394+ auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3395+ return !castOp || !isCastOpOfInterestFn (castOp);
3396+ }))
3397+ markOpLive (op);
3398+ }
3399+
3400+ // Erase all dead cast ops.
3401+ for (UnrealizedConversionCastOp op : castOps) {
3402+ if (liveOps.contains (op)) {
3403+ // Op is alive and was not erased. Add it to the remaining cast ops.
3404+ if (remainingCastOps)
33723405 remainingCastOps->push_back (op);
3406+ continue ;
3407+ }
3408+
3409+ // Op is dead. Erase it.
3410+ op->dropAllUses ();
3411+ op->erase ();
3412+ }
3413+ }
3414+
3415+ void mlir::reconcileUnrealizedCasts (
3416+ ArrayRef<UnrealizedConversionCastOp> castOps,
3417+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3418+ // Set of all cast ops for faster lookups.
3419+ DenseSet<UnrealizedConversionCastOp> castOpSet;
3420+ for (UnrealizedConversionCastOp op : castOps)
3421+ castOpSet.insert (op);
3422+ reconcileUnrealizedCasts (castOpSet, remainingCastOps);
3423+ }
3424+
3425+ void mlir::reconcileUnrealizedCasts (
3426+ const DenseSet<UnrealizedConversionCastOp> &castOps,
3427+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3428+ reconcileUnrealizedCastsImpl (
3429+ llvm::make_range (castOps.begin (), castOps.end ()),
3430+ [&](UnrealizedConversionCastOp castOp) {
3431+ return castOps.contains (castOp);
3432+ },
3433+ remainingCastOps);
3434+ }
3435+
3436+ static void mlir::reconcileUnrealizedCasts (
3437+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3438+ &castOps,
3439+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3440+ reconcileUnrealizedCastsImpl (
3441+ castOps.keys (),
3442+ [&](UnrealizedConversionCastOp castOp) {
3443+ return castOps.contains (castOp);
3444+ },
3445+ remainingCastOps);
33733446}
33743447
33753448// ===----------------------------------------------------------------------===//
0 commit comments