Skip to content

Commit 9d11acc

Browse files
authored
[mlir] move if-condition propagation to a standalone pass (#150278)
This offers a significant speedup over running this as a canonicalizaiton pattern, up to 10x improvement when running on large (>100k operations) inputs coming from Polygeist. It is also not clear whether this transformation is a reasonable canonicalization as it performs non-local rewrites.
1 parent bc1f85d commit 9d11acc

File tree

6 files changed

+141
-97
lines changed

6 files changed

+141
-97
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/Passes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> {
4141
let constructor = "mlir::createForLoopSpecializationPass()";
4242
}
4343

44+
def SCFIfConditionPropagation : Pass<"scf-if-condition-propagation"> {
45+
let summary = "Replace usages of if condition with true/false constants in "
46+
"the conditional regions";
47+
let dependentDialects = ["arith::ArithDialect"];
48+
}
49+
4450
def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> {
4551
let summary = "Fuse adjacent parallel loops";
4652
let constructor = "mlir::createParallelLoopFusionPass()";

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 2 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2414,65 +2414,6 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
24142414
}
24152415
};
24162416

2417-
/// Allow the true region of an if to assume the condition is true
2418-
/// and vice versa. For example:
2419-
///
2420-
/// scf.if %cmp {
2421-
/// print(%cmp)
2422-
/// }
2423-
///
2424-
/// becomes
2425-
///
2426-
/// scf.if %cmp {
2427-
/// print(true)
2428-
/// }
2429-
///
2430-
struct ConditionPropagation : public OpRewritePattern<IfOp> {
2431-
using OpRewritePattern<IfOp>::OpRewritePattern;
2432-
2433-
LogicalResult matchAndRewrite(IfOp op,
2434-
PatternRewriter &rewriter) const override {
2435-
// Early exit if the condition is constant since replacing a constant
2436-
// in the body with another constant isn't a simplification.
2437-
if (matchPattern(op.getCondition(), m_Constant()))
2438-
return failure();
2439-
2440-
bool changed = false;
2441-
mlir::Type i1Ty = rewriter.getI1Type();
2442-
2443-
// These variables serve to prevent creating duplicate constants
2444-
// and hold constant true or false values.
2445-
Value constantTrue = nullptr;
2446-
Value constantFalse = nullptr;
2447-
2448-
for (OpOperand &use :
2449-
llvm::make_early_inc_range(op.getCondition().getUses())) {
2450-
if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
2451-
changed = true;
2452-
2453-
if (!constantTrue)
2454-
constantTrue = rewriter.create<arith::ConstantOp>(
2455-
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2456-
2457-
rewriter.modifyOpInPlace(use.getOwner(),
2458-
[&]() { use.set(constantTrue); });
2459-
} else if (op.getElseRegion().isAncestor(
2460-
use.getOwner()->getParentRegion())) {
2461-
changed = true;
2462-
2463-
if (!constantFalse)
2464-
constantFalse = rewriter.create<arith::ConstantOp>(
2465-
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2466-
2467-
rewriter.modifyOpInPlace(use.getOwner(),
2468-
[&]() { use.set(constantFalse); });
2469-
}
2470-
}
2471-
2472-
return success(changed);
2473-
}
2474-
};
2475-
24762417
/// Remove any statements from an if that are equivalent to the condition
24772418
/// or its negation. For example:
24782419
///
@@ -2854,9 +2795,8 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
28542795

28552796
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
28562797
MLIRContext *context) {
2857-
results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2858-
ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2859-
RemoveStaticCondition, RemoveUnusedResults,
2798+
results.add<CombineIfs, CombineNestedIfs, ConvertTrivialIfToSelect,
2799+
RemoveEmptyElseBranch, RemoveStaticCondition, RemoveUnusedResults,
28602800
ReplaceIfYieldWithConditionOrValue>(context);
28612801
}
28622802

mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
44
ForallToFor.cpp
55
ForallToParallel.cpp
66
ForToWhile.cpp
7+
IfConditionPropagation.cpp
78
LoopCanonicalization.cpp
89
LoopPipelining.cpp
910
LoopRangeFolding.cpp
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
//===- IfConditionPropagation.cpp -----------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains a pass for constant propagation of the condition of an
10+
// `scf.if` into its then and else regions as true and false respectively.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Arith/IR/Arith.h"
15+
#include "mlir/Dialect/SCF/IR/SCF.h"
16+
#include "mlir/Dialect/SCF/Transforms/Passes.h"
17+
18+
using namespace mlir;
19+
20+
namespace mlir {
21+
#define GEN_PASS_DEF_SCFIFCONDITIONPROPAGATION
22+
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
23+
} // namespace mlir
24+
25+
/// Traverses the IR recursively (on region tree) and updates the uses of a
26+
/// value also as the condition of an `scf.if` to either `true` or `false`
27+
/// constants in the `then` and `else regions. This is done as a single
28+
/// post-order sweep over the IR (without `walk`) for efficiency reasons. While
29+
/// traversing, the function maintains the set of visited regions to quickly
30+
/// identify whether the value belong to a region that is known to be nested in
31+
/// the `then` or `else` branch of a specific loop.
32+
static void propagateIfConditionsImpl(Operation *root,
33+
llvm::SmallPtrSet<Region *, 8> &visited) {
34+
if (auto scfIf = dyn_cast<scf::IfOp>(root)) {
35+
llvm::SmallPtrSet<Region *, 8> thenChildren, elseChildren;
36+
// Visit the "then" region, collect children.
37+
for (Block &block : scfIf.getThenRegion()) {
38+
for (Operation &op : block) {
39+
propagateIfConditionsImpl(&op, thenChildren);
40+
}
41+
}
42+
43+
// Visit the "else" region, collect children.
44+
for (Block &block : scfIf.getElseRegion()) {
45+
for (Operation &op : block) {
46+
propagateIfConditionsImpl(&op, elseChildren);
47+
}
48+
}
49+
50+
// Update uses to point to constants instead.
51+
OpBuilder builder(scfIf);
52+
Value trueValue = arith::ConstantIntOp::create(builder, scfIf.getLoc(),
53+
/*value=*/true, /*width=*/1);
54+
Value falseValue =
55+
arith::ConstantIntOp::create(builder, scfIf.getLoc(),
56+
/*value=*/false, /*width=*/1);
57+
58+
for (OpOperand &use : scfIf.getCondition().getUses()) {
59+
if (thenChildren.contains(use.getOwner()->getParentRegion()))
60+
use.set(trueValue);
61+
else if (elseChildren.contains(use.getOwner()->getParentRegion()))
62+
use.set(falseValue);
63+
}
64+
if (trueValue.getUses().empty())
65+
trueValue.getDefiningOp()->erase();
66+
if (falseValue.getUses().empty())
67+
falseValue.getDefiningOp()->erase();
68+
69+
// Append the two lists of children and return them.
70+
visited.insert_range(thenChildren);
71+
visited.insert_range(elseChildren);
72+
return;
73+
}
74+
75+
for (Region &region : root->getRegions()) {
76+
for (Block &block : region) {
77+
for (Operation &op : block) {
78+
propagateIfConditionsImpl(&op, visited);
79+
}
80+
}
81+
}
82+
}
83+
84+
/// Traverses the IR recursively (on region tree) and updates the uses of a
85+
/// value also as the condition of an `scf.if` to either `true` or `false`
86+
/// constants in the `then` and `else regions
87+
static void propagateIfConditions(Operation *root) {
88+
llvm::SmallPtrSet<Region *, 8> visited;
89+
propagateIfConditionsImpl(root, visited);
90+
}
91+
92+
namespace {
93+
/// Pass entrypoint.
94+
struct SCFIfConditionPropagationPass
95+
: impl::SCFIfConditionPropagationBase<SCFIfConditionPropagationPass> {
96+
void runOnOperation() override { propagateIfConditions(getOperation()); }
97+
};
98+
} // namespace

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -867,41 +867,6 @@ func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
867867

868868
// -----
869869

870-
// CHECK-LABEL: @cond_prop
871-
func.func @cond_prop(%arg0 : i1) -> index {
872-
%res = scf.if %arg0 -> index {
873-
%res1 = scf.if %arg0 -> index {
874-
%v1 = "test.get_some_value1"() : () -> index
875-
scf.yield %v1 : index
876-
} else {
877-
%v2 = "test.get_some_value2"() : () -> index
878-
scf.yield %v2 : index
879-
}
880-
scf.yield %res1 : index
881-
} else {
882-
%res2 = scf.if %arg0 -> index {
883-
%v3 = "test.get_some_value3"() : () -> index
884-
scf.yield %v3 : index
885-
} else {
886-
%v4 = "test.get_some_value4"() : () -> index
887-
scf.yield %v4 : index
888-
}
889-
scf.yield %res2 : index
890-
}
891-
return %res : index
892-
}
893-
// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (index) {
894-
// CHECK-NEXT: %[[c1:.+]] = "test.get_some_value1"() : () -> index
895-
// CHECK-NEXT: scf.yield %[[c1]] : index
896-
// CHECK-NEXT: } else {
897-
// CHECK-NEXT: %[[c4:.+]] = "test.get_some_value4"() : () -> index
898-
// CHECK-NEXT: scf.yield %[[c4]] : index
899-
// CHECK-NEXT: }
900-
// CHECK-NEXT: return %[[if]] : index
901-
// CHECK-NEXT:}
902-
903-
// -----
904-
905870
// CHECK-LABEL: @replace_if_with_cond1
906871
func.func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
907872
%true = arith.constant true
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: mlir-opt %s --scf-if-condition-propagation --allow-unregistered-dialect | FileCheck %s
2+
3+
// CHECK-LABEL: @cond_prop
4+
func.func @cond_prop(%arg0 : i1) -> index {
5+
%res = scf.if %arg0 -> index {
6+
%res1 = scf.if %arg0 -> index {
7+
%v1 = "test.get_some_value1"() : () -> index
8+
scf.yield %v1 : index
9+
} else {
10+
%v2 = "test.get_some_value2"() : () -> index
11+
scf.yield %v2 : index
12+
}
13+
scf.yield %res1 : index
14+
} else {
15+
%res2 = scf.if %arg0 -> index {
16+
%v3 = "test.get_some_value3"() : () -> index
17+
scf.yield %v3 : index
18+
} else {
19+
%v4 = "test.get_some_value4"() : () -> index
20+
scf.yield %v4 : index
21+
}
22+
scf.yield %res2 : index
23+
}
24+
return %res : index
25+
}
26+
// CHECK: %[[if:.+]] = scf.if %arg0 -> (index) {
27+
// CHECK: %[[c1:.+]] = "test.get_some_value1"() : () -> index
28+
// CHECK: scf.yield %[[c1]] : index
29+
// CHECK: } else {
30+
// CHECK: %[[c4:.+]] = "test.get_some_value4"() : () -> index
31+
// CHECK: scf.yield %[[c4]] : index
32+
// CHECK: }
33+
// CHECK: return %[[if]] : index
34+
// CHECK:}

0 commit comments

Comments
 (0)