@@ -3100,7 +3100,6 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
31003100// ===----------------------------------------------------------------------===//
31013101// OperationConverter
31023102// ===----------------------------------------------------------------------===//
3103-
31043103namespace {
31053104enum OpConversionMode {
31063105 // / In this mode, the conversion will ignore failed conversions to allow
@@ -3118,13 +3117,6 @@ enum OpConversionMode {
31183117} // namespace
31193118
31203119namespace mlir {
3121-
3122- // Predeclaration only.
3123- static void reconcileUnrealizedCasts (
3124- const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3125- &castOps,
3126- SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps);
3127-
31283120// This class converts operations to a given conversion target via a set of
31293121// rewrite patterns. The conversion behaves differently depending on the
31303122// conversion mode.
@@ -3272,13 +3264,18 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
32723264 // After a successful conversion, apply rewrites.
32733265 rewriterImpl.applyRewrites ();
32743266
3275- // Reconcile all UnrealizedConversionCastOps that were inserted by the
3276- // dialect conversion frameworks. (Not the ones that were inserted by
3277- // patterns.)
3267+ // Gather all unresolved materializations.
3268+ SmallVector<UnrealizedConversionCastOp> allCastOps;
32783269 const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
32793270 &materializations = rewriterImpl.unresolvedMaterializations ;
3271+ for (auto it : materializations)
3272+ allCastOps.push_back (it.first );
3273+
3274+ // Reconcile all UnrealizedConversionCastOps that were inserted by the
3275+ // dialect conversion frameworks. (Not the one that were inserted by
3276+ // patterns.)
32803277 SmallVector<UnrealizedConversionCastOp> remainingCastOps;
3281- reconcileUnrealizedCasts (materializations , &remainingCastOps);
3278+ reconcileUnrealizedCasts (allCastOps , &remainingCastOps);
32823279
32833280 // Drop markers.
32843281 for (UnrealizedConversionCastOp castOp : remainingCastOps)
@@ -3306,19 +3303,20 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
33063303// Reconcile Unrealized Casts
33073304// ===----------------------------------------------------------------------===//
33083305
3309- // / Try to reconcile all given UnrealizedConversionCastOps and store the
3310- // / left-over ops in `remainingCastOps` (if provided). See documentation in
3311- // / DialectConversion.h for more details.
3312- // / The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
3313- // / algorithm may visit an operand (or user) which is a cast op, but will not
3314- // / try to reconcile it if not in the filtered set.
3315- template <typename RangeT>
3316- static void reconcileUnrealizedCastsImpl (
3317- RangeT castOps,
3318- function_ref<bool (UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3306+ void mlir::reconcileUnrealizedCasts (
3307+ ArrayRef<UnrealizedConversionCastOp> castOps,
33193308 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3320- // A worklist of cast ops to process.
33213309 SetVector<UnrealizedConversionCastOp> worklist (llvm::from_range, castOps);
3310+ // This set is maintained only if `remainingCastOps` is provided.
3311+ DenseSet<Operation *> erasedOps;
3312+
3313+ // Helper function that adds all operands to the worklist that are an
3314+ // unrealized_conversion_cast op result.
3315+ auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
3316+ for (Value v : castOp.getInputs ())
3317+ if (auto inputCastOp = v.getDefiningOp <UnrealizedConversionCastOp>())
3318+ worklist.insert (inputCastOp);
3319+ };
33223320
33233321 // Helper function that return the unrealized_conversion_cast op that
33243322 // defines all inputs of the given op (in the same order). Return "nullptr"
@@ -3339,110 +3337,39 @@ static void reconcileUnrealizedCastsImpl(
33393337 // Process ops in the worklist bottom-to-top.
33403338 while (!worklist.empty ()) {
33413339 UnrealizedConversionCastOp castOp = worklist.pop_back_val ();
3340+ if (castOp->use_empty ()) {
3341+ // DCE: If the op has no users, erase it. Add the operands to the
3342+ // worklist to find additional DCE opportunities.
3343+ enqueueOperands (castOp);
3344+ if (remainingCastOps)
3345+ erasedOps.insert (castOp.getOperation ());
3346+ castOp->erase ();
3347+ continue ;
3348+ }
33423349
33433350 // Traverse the chain of input cast ops to see if an op with the same
33443351 // input types can be found.
33453352 UnrealizedConversionCastOp nextCast = castOp;
33463353 while (nextCast) {
33473354 if (nextCast.getInputs ().getTypes () == castOp.getResultTypes ()) {
3348- if (llvm::any_of (nextCast.getInputs (), [&](Value v) {
3349- return v.getDefiningOp () == castOp;
3350- })) {
3351- // Ran into a cycle.
3352- break ;
3353- }
3354-
33553355 // Found a cast where the input types match the output types of the
3356- // matched op. We can directly use those inputs.
3356+ // matched op. We can directly use those inputs and the matched op can
3357+ // be removed.
3358+ enqueueOperands (castOp);
33573359 castOp.replaceAllUsesWith (nextCast.getInputs ());
3360+ if (remainingCastOps)
3361+ erasedOps.insert (castOp.getOperation ());
3362+ castOp->erase ();
33583363 break ;
33593364 }
33603365 nextCast = getInputCast (nextCast);
33613366 }
33623367 }
33633368
3364- // A set of all alive cast ops. I.e., ops whose results are (transitively)
3365- // used by an op that is not a cast op.
3366- DenseSet<Operation *> liveOps;
3367-
3368- // Helper function that marks the given op and transitively reachable input
3369- // cast ops as alive.
3370- auto markOpLive = [&](Operation *rootOp) {
3371- SmallVector<Operation *> worklist;
3372- worklist.push_back (rootOp);
3373- while (!worklist.empty ()) {
3374- Operation *op = worklist.pop_back_val ();
3375- if (liveOps.insert (op).second ) {
3376- // Successfully inserted: process reachable input cast ops.
3377- for (Value v : op->getOperands ())
3378- if (auto castOp = v.getDefiningOp <UnrealizedConversionCastOp>())
3379- if (isCastOpOfInterestFn (castOp))
3380- worklist.push_back (castOp);
3381- }
3382- }
3383- };
3384-
3385- // Find all alive cast ops.
3386- for (UnrealizedConversionCastOp op : castOps) {
3387- // The op may have been marked live already as being an operand of another
3388- // live cast op.
3389- if (liveOps.contains (op.getOperation ()))
3390- continue ;
3391- // If any of the users is not a cast op, mark the current op (and its
3392- // input ops) as live.
3393- if (llvm::any_of (op->getUsers (), [&](Operation *user) {
3394- auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3395- return !castOp || !isCastOpOfInterestFn (castOp);
3396- }))
3397- markOpLive (op);
3398- }
3399-
3400- // Erase all dead cast ops.
3401- for (UnrealizedConversionCastOp op : castOps) {
3402- if (liveOps.contains (op)) {
3403- // Op is alive and was not erased. Add it to the remaining cast ops.
3404- if (remainingCastOps)
3369+ if (remainingCastOps)
3370+ for (UnrealizedConversionCastOp op : castOps)
3371+ if (!erasedOps.contains (op.getOperation ()))
34053372 remainingCastOps->push_back (op);
3406- continue ;
3407- }
3408-
3409- // Op is dead. Erase it.
3410- op->dropAllUses ();
3411- op->erase ();
3412- }
3413- }
3414-
3415- void mlir::reconcileUnrealizedCasts (
3416- ArrayRef<UnrealizedConversionCastOp> castOps,
3417- SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3418- // Set of all cast ops for faster lookups.
3419- DenseSet<UnrealizedConversionCastOp> castOpSet;
3420- for (UnrealizedConversionCastOp op : castOps)
3421- castOpSet.insert (op);
3422- reconcileUnrealizedCasts (castOpSet, remainingCastOps);
3423- }
3424-
3425- void mlir::reconcileUnrealizedCasts (
3426- const DenseSet<UnrealizedConversionCastOp> &castOps,
3427- SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3428- reconcileUnrealizedCastsImpl (
3429- llvm::make_range (castOps.begin (), castOps.end ()),
3430- [&](UnrealizedConversionCastOp castOp) {
3431- return castOps.contains (castOp);
3432- },
3433- remainingCastOps);
3434- }
3435-
3436- static void mlir::reconcileUnrealizedCasts (
3437- const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3438- &castOps,
3439- SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3440- reconcileUnrealizedCastsImpl (
3441- castOps.keys (),
3442- [&](UnrealizedConversionCastOp castOp) {
3443- return castOps.contains (castOp);
3444- },
3445- remainingCastOps);
34463373}
34473374
34483375// ===----------------------------------------------------------------------===//
0 commit comments