@@ -2565,6 +2565,39 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
25652565struct ConditionPropagation : public OpRewritePattern <IfOp> {
25662566 using OpRewritePattern<IfOp>::OpRewritePattern;
25672567
2568+ // / Kind of parent region in the ancestor cache.
2569+ enum class Parent { Then, Else, None };
2570+
2571+ // / Returns the kind of region ("then", "else", or "none") of the
2572+ // / IfOp that the given region is transitively nested in. Updates
2573+ // / the cache accordingly.
2574+ static Parent getParentType (Region *toCheck, IfOp op,
2575+ DenseMap<Region *, Parent> &cache,
2576+ Region *endRegion) {
2577+ SmallVector<Region *> seen;
2578+ while (toCheck != endRegion) {
2579+ auto found = cache.find (toCheck);
2580+ if (found != cache.end ())
2581+ return found->second ;
2582+ seen.push_back (toCheck);
2583+ if (&op.getThenRegion () == toCheck) {
2584+ for (Region *region : seen)
2585+ cache[region] = Parent::Then;
2586+ return Parent::Then;
2587+ }
2588+ if (&op.getElseRegion () == toCheck) {
2589+ for (Region *region : seen)
2590+ cache[region] = Parent::Else;
2591+ return Parent::Else;
2592+ }
2593+ toCheck = toCheck->getParentRegion ();
2594+ }
2595+
2596+ for (Region *region : seen)
2597+ cache[region] = Parent::None;
2598+ return Parent::None;
2599+ }
2600+
25682601 LogicalResult matchAndRewrite (IfOp op,
25692602 PatternRewriter &rewriter) const override {
25702603 // Early exit if the condition is constant since replacing a constant
@@ -2580,9 +2613,12 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
25802613 Value constantTrue = nullptr ;
25812614 Value constantFalse = nullptr ;
25822615
2616+ DenseMap<Region *, Parent> cache;
25832617 for (OpOperand &use :
25842618 llvm::make_early_inc_range (op.getCondition ().getUses ())) {
2585- if (op.getThenRegion ().isAncestor (use.getOwner ()->getParentRegion ())) {
2619+ switch (getParentType (use.getOwner ()->getParentRegion (), op, cache,
2620+ op.getCondition ().getParentRegion ())) {
2621+ case Parent::Then: {
25862622 changed = true ;
25872623
25882624 if (!constantTrue)
@@ -2591,8 +2627,9 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
25912627
25922628 rewriter.modifyOpInPlace (use.getOwner (),
25932629 [&]() { use.set (constantTrue); });
2594- } else if (op.getElseRegion ().isAncestor (
2595- use.getOwner ()->getParentRegion ())) {
2630+ break ;
2631+ }
2632+ case Parent::Else: {
25962633 changed = true ;
25972634
25982635 if (!constantFalse)
@@ -2601,6 +2638,10 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
26012638
26022639 rewriter.modifyOpInPlace (use.getOwner (),
26032640 [&]() { use.set (constantFalse); });
2641+ break ;
2642+ }
2643+ case Parent::None:
2644+ break ;
26042645 }
26052646 }
26062647
0 commit comments