@@ -3097,6 +3097,151 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
30973097 return minDepth;
30983098}
30993099
3100+ // ===----------------------------------------------------------------------===//
3101+ // Reconcile Unrealized Casts
3102+ // ===----------------------------------------------------------------------===//
3103+
3104+ // / Try to reconcile all given UnrealizedConversionCastOps and store the
3105+ // / left-over ops in `remainingCastOps` (if provided). See documentation in
3106+ // / DialectConversion.h for more details.
3107+ // / The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
3108+ // / algorithm may visit an operand (or user) which is a cast op, but will not
3109+ // / try to reconcile it if not in the filtered set.
3110+ template <typename RangeT>
3111+ static void reconcileUnrealizedCastsImpl (
3112+ RangeT castOps,
3113+ function_ref<bool (UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3114+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3115+ // A worklist of cast ops to process.
3116+ SetVector<UnrealizedConversionCastOp> worklist (llvm::from_range, castOps);
3117+
3118+ // Helper function that return the unrealized_conversion_cast op that
3119+ // defines all inputs of the given op (in the same order). Return "nullptr"
3120+ // if there is no such op.
3121+ auto getInputCast =
3122+ [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3123+ if (castOp.getInputs ().empty ())
3124+ return {};
3125+ auto inputCastOp =
3126+ castOp.getInputs ().front ().getDefiningOp <UnrealizedConversionCastOp>();
3127+ if (!inputCastOp)
3128+ return {};
3129+ if (inputCastOp.getOutputs () != castOp.getInputs ())
3130+ return {};
3131+ return inputCastOp;
3132+ };
3133+
3134+ // Process ops in the worklist bottom-to-top.
3135+ while (!worklist.empty ()) {
3136+ UnrealizedConversionCastOp castOp = worklist.pop_back_val ();
3137+
3138+ // Traverse the chain of input cast ops to see if an op with the same
3139+ // input types can be found.
3140+ UnrealizedConversionCastOp nextCast = castOp;
3141+ while (nextCast) {
3142+ if (nextCast.getInputs ().getTypes () == castOp.getResultTypes ()) {
3143+ if (llvm::any_of (nextCast.getInputs (), [&](Value v) {
3144+ return v.getDefiningOp () == castOp;
3145+ })) {
3146+ // Ran into a cycle.
3147+ break ;
3148+ }
3149+
3150+ // Found a cast where the input types match the output types of the
3151+ // matched op. We can directly use those inputs.
3152+ castOp.replaceAllUsesWith (nextCast.getInputs ());
3153+ break ;
3154+ }
3155+ nextCast = getInputCast (nextCast);
3156+ }
3157+ }
3158+
3159+ // A set of all alive cast ops. I.e., ops whose results are (transitively)
3160+ // used by an op that is not a cast op.
3161+ DenseSet<Operation *> liveOps;
3162+
3163+ // Helper function that marks the given op and transitively reachable input
3164+ // cast ops as alive.
3165+ auto markOpLive = [&](Operation *rootOp) {
3166+ SmallVector<Operation *> worklist;
3167+ worklist.push_back (rootOp);
3168+ while (!worklist.empty ()) {
3169+ Operation *op = worklist.pop_back_val ();
3170+ if (liveOps.insert (op).second ) {
3171+ // Successfully inserted: process reachable input cast ops.
3172+ for (Value v : op->getOperands ())
3173+ if (auto castOp = v.getDefiningOp <UnrealizedConversionCastOp>())
3174+ if (isCastOpOfInterestFn (castOp))
3175+ worklist.push_back (castOp);
3176+ }
3177+ }
3178+ };
3179+
3180+ // Find all alive cast ops.
3181+ for (UnrealizedConversionCastOp op : castOps) {
3182+ // The op may have been marked live already as being an operand of another
3183+ // live cast op.
3184+ if (liveOps.contains (op.getOperation ()))
3185+ continue ;
3186+ // If any of the users is not a cast op, mark the current op (and its
3187+ // input ops) as live.
3188+ if (llvm::any_of (op->getUsers (), [&](Operation *user) {
3189+ auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3190+ return !castOp || !isCastOpOfInterestFn (castOp);
3191+ }))
3192+ markOpLive (op);
3193+ }
3194+
3195+ // Erase all dead cast ops.
3196+ for (UnrealizedConversionCastOp op : castOps) {
3197+ if (liveOps.contains (op)) {
3198+ // Op is alive and was not erased. Add it to the remaining cast ops.
3199+ if (remainingCastOps)
3200+ remainingCastOps->push_back (op);
3201+ continue ;
3202+ }
3203+
3204+ // Op is dead. Erase it.
3205+ op->dropAllUses ();
3206+ op->erase ();
3207+ }
3208+ }
3209+
3210+ void mlir::reconcileUnrealizedCasts (
3211+ ArrayRef<UnrealizedConversionCastOp> castOps,
3212+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3213+ // Set of all cast ops for faster lookups.
3214+ DenseSet<UnrealizedConversionCastOp> castOpSet;
3215+ for (UnrealizedConversionCastOp op : castOps)
3216+ castOpSet.insert (op);
3217+ reconcileUnrealizedCasts (castOpSet, remainingCastOps);
3218+ }
3219+
3220+ void mlir::reconcileUnrealizedCasts (
3221+ const DenseSet<UnrealizedConversionCastOp> &castOps,
3222+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3223+ reconcileUnrealizedCastsImpl (
3224+ llvm::make_range (castOps.begin (), castOps.end ()),
3225+ [&](UnrealizedConversionCastOp castOp) {
3226+ return castOps.contains (castOp);
3227+ },
3228+ remainingCastOps);
3229+ }
3230+
3231+ namespace mlir {
3232+ static void reconcileUnrealizedCasts (
3233+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3234+ &castOps,
3235+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3236+ reconcileUnrealizedCastsImpl (
3237+ castOps.keys (),
3238+ [&](UnrealizedConversionCastOp castOp) {
3239+ return castOps.contains (castOp);
3240+ },
3241+ remainingCastOps);
3242+ }
3243+ } // namespace mlir
3244+
31003245// ===----------------------------------------------------------------------===//
31013246// OperationConverter
31023247// ===----------------------------------------------------------------------===//
@@ -3118,13 +3263,6 @@ enum OpConversionMode {
31183263} // namespace
31193264
31203265namespace mlir {
3121-
3122- // Predeclaration only.
3123- static void reconcileUnrealizedCasts (
3124- const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3125- &castOps,
3126- SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps);
3127-
31283266// This class converts operations to a given conversion target via a set of
31293267// rewrite patterns. The conversion behaves differently depending on the
31303268// conversion mode.
@@ -3302,149 +3440,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
33023440 return success ();
33033441}
33043442
3305- // ===----------------------------------------------------------------------===//
3306- // Reconcile Unrealized Casts
3307- // ===----------------------------------------------------------------------===//
3308-
3309- // / Try to reconcile all given UnrealizedConversionCastOps and store the
3310- // / left-over ops in `remainingCastOps` (if provided). See documentation in
3311- // / DialectConversion.h for more details.
3312- // / The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
3313- // / algorithm may visit an operand (or user) which is a cast op, but will not
3314- // / try to reconcile it if not in the filtered set.
3315- template <typename RangeT>
3316- static void reconcileUnrealizedCastsImpl (
3317- RangeT castOps,
3318- function_ref<bool (UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3319- SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3320- // A worklist of cast ops to process.
3321- SetVector<UnrealizedConversionCastOp> worklist (llvm::from_range, castOps);
3322-
3323- // Helper function that return the unrealized_conversion_cast op that
3324- // defines all inputs of the given op (in the same order). Return "nullptr"
3325- // if there is no such op.
3326- auto getInputCast =
3327- [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3328- if (castOp.getInputs ().empty ())
3329- return {};
3330- auto inputCastOp =
3331- castOp.getInputs ().front ().getDefiningOp <UnrealizedConversionCastOp>();
3332- if (!inputCastOp)
3333- return {};
3334- if (inputCastOp.getOutputs () != castOp.getInputs ())
3335- return {};
3336- return inputCastOp;
3337- };
3338-
3339- // Process ops in the worklist bottom-to-top.
3340- while (!worklist.empty ()) {
3341- UnrealizedConversionCastOp castOp = worklist.pop_back_val ();
3342-
3343- // Traverse the chain of input cast ops to see if an op with the same
3344- // input types can be found.
3345- UnrealizedConversionCastOp nextCast = castOp;
3346- while (nextCast) {
3347- if (nextCast.getInputs ().getTypes () == castOp.getResultTypes ()) {
3348- if (llvm::any_of (nextCast.getInputs (), [&](Value v) {
3349- return v.getDefiningOp () == castOp;
3350- })) {
3351- // Ran into a cycle.
3352- break ;
3353- }
3354-
3355- // Found a cast where the input types match the output types of the
3356- // matched op. We can directly use those inputs.
3357- castOp.replaceAllUsesWith (nextCast.getInputs ());
3358- break ;
3359- }
3360- nextCast = getInputCast (nextCast);
3361- }
3362- }
3363-
3364- // A set of all alive cast ops. I.e., ops whose results are (transitively)
3365- // used by an op that is not a cast op.
3366- DenseSet<Operation *> liveOps;
3367-
3368- // Helper function that marks the given op and transitively reachable input
3369- // cast ops as alive.
3370- auto markOpLive = [&](Operation *rootOp) {
3371- SmallVector<Operation *> worklist;
3372- worklist.push_back (rootOp);
3373- while (!worklist.empty ()) {
3374- Operation *op = worklist.pop_back_val ();
3375- if (liveOps.insert (op).second ) {
3376- // Successfully inserted: process reachable input cast ops.
3377- for (Value v : op->getOperands ())
3378- if (auto castOp = v.getDefiningOp <UnrealizedConversionCastOp>())
3379- if (isCastOpOfInterestFn (castOp))
3380- worklist.push_back (castOp);
3381- }
3382- }
3383- };
3384-
3385- // Find all alive cast ops.
3386- for (UnrealizedConversionCastOp op : castOps) {
3387- // The op may have been marked live already as being an operand of another
3388- // live cast op.
3389- if (liveOps.contains (op.getOperation ()))
3390- continue ;
3391- // If any of the users is not a cast op, mark the current op (and its
3392- // input ops) as live.
3393- if (llvm::any_of (op->getUsers (), [&](Operation *user) {
3394- auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3395- return !castOp || !isCastOpOfInterestFn (castOp);
3396- }))
3397- markOpLive (op);
3398- }
3399-
3400- // Erase all dead cast ops.
3401- for (UnrealizedConversionCastOp op : castOps) {
3402- if (liveOps.contains (op)) {
3403- // Op is alive and was not erased. Add it to the remaining cast ops.
3404- if (remainingCastOps)
3405- remainingCastOps->push_back (op);
3406- continue ;
3407- }
3408-
3409- // Op is dead. Erase it.
3410- op->dropAllUses ();
3411- op->erase ();
3412- }
3413- }
3414-
3415- void mlir::reconcileUnrealizedCasts (
3416- ArrayRef<UnrealizedConversionCastOp> castOps,
3417- SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3418- // Set of all cast ops for faster lookups.
3419- DenseSet<UnrealizedConversionCastOp> castOpSet;
3420- for (UnrealizedConversionCastOp op : castOps)
3421- castOpSet.insert (op);
3422- reconcileUnrealizedCasts (castOpSet, remainingCastOps);
3423- }
3424-
3425- void mlir::reconcileUnrealizedCasts (
3426- const DenseSet<UnrealizedConversionCastOp> &castOps,
3427- SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3428- reconcileUnrealizedCastsImpl (
3429- llvm::make_range (castOps.begin (), castOps.end ()),
3430- [&](UnrealizedConversionCastOp castOp) {
3431- return castOps.contains (castOp);
3432- },
3433- remainingCastOps);
3434- }
3435-
3436- static void mlir::reconcileUnrealizedCasts (
3437- const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3438- &castOps,
3439- SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3440- reconcileUnrealizedCastsImpl (
3441- castOps.keys (),
3442- [&](UnrealizedConversionCastOp castOp) {
3443- return castOps.contains (castOp);
3444- },
3445- remainingCastOps);
3446- }
3447-
34483443// ===----------------------------------------------------------------------===//
34493444// Type Conversion
34503445// ===----------------------------------------------------------------------===//
0 commit comments