Skip to content

Commit 3bd5963

Browse files
authored
[MLIR][SCF] Speed up ConditionPropagation (#166080)
Introduce a cache to avoid looking up then/else region nesting through `isAncestor` calls repeatedly. This gets expensive for large inputs with lots of pointer chasing. Fixes #166039
1 parent 28d3194 commit 3bd5963

File tree

1 file changed

+44
-3
lines changed

1 file changed

+44
-3
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2565,6 +2565,39 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
25652565
struct ConditionPropagation : public OpRewritePattern<IfOp> {
25662566
using OpRewritePattern<IfOp>::OpRewritePattern;
25672567

2568+
/// Kind of parent region in the ancestor cache.
2569+
enum class Parent { Then, Else, None };
2570+
2571+
/// Returns the kind of region ("then", "else", or "none") of the
2572+
/// IfOp that the given region is transitively nested in. Updates
2573+
/// the cache accordingly.
2574+
static Parent getParentType(Region *toCheck, IfOp op,
2575+
DenseMap<Region *, Parent> &cache,
2576+
Region *endRegion) {
2577+
SmallVector<Region *> seen;
2578+
while (toCheck != endRegion) {
2579+
auto found = cache.find(toCheck);
2580+
if (found != cache.end())
2581+
return found->second;
2582+
seen.push_back(toCheck);
2583+
if (&op.getThenRegion() == toCheck) {
2584+
for (Region *region : seen)
2585+
cache[region] = Parent::Then;
2586+
return Parent::Then;
2587+
}
2588+
if (&op.getElseRegion() == toCheck) {
2589+
for (Region *region : seen)
2590+
cache[region] = Parent::Else;
2591+
return Parent::Else;
2592+
}
2593+
toCheck = toCheck->getParentRegion();
2594+
}
2595+
2596+
for (Region *region : seen)
2597+
cache[region] = Parent::None;
2598+
return Parent::None;
2599+
}
2600+
25682601
LogicalResult matchAndRewrite(IfOp op,
25692602
PatternRewriter &rewriter) const override {
25702603
// Early exit if the condition is constant since replacing a constant
@@ -2580,9 +2613,12 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
25802613
Value constantTrue = nullptr;
25812614
Value constantFalse = nullptr;
25822615

2616+
DenseMap<Region *, Parent> cache;
25832617
for (OpOperand &use :
25842618
llvm::make_early_inc_range(op.getCondition().getUses())) {
2585-
if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
2619+
switch (getParentType(use.getOwner()->getParentRegion(), op, cache,
2620+
op.getCondition().getParentRegion())) {
2621+
case Parent::Then: {
25862622
changed = true;
25872623

25882624
if (!constantTrue)
@@ -2591,8 +2627,9 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
25912627

25922628
rewriter.modifyOpInPlace(use.getOwner(),
25932629
[&]() { use.set(constantTrue); });
2594-
} else if (op.getElseRegion().isAncestor(
2595-
use.getOwner()->getParentRegion())) {
2630+
break;
2631+
}
2632+
case Parent::Else: {
25962633
changed = true;
25972634

25982635
if (!constantFalse)
@@ -2601,6 +2638,10 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
26012638

26022639
rewriter.modifyOpInPlace(use.getOwner(),
26032640
[&]() { use.set(constantFalse); });
2641+
break;
2642+
}
2643+
case Parent::None:
2644+
break;
26042645
}
26052646
}
26062647

0 commit comments

Comments
 (0)