-
Couldn't load subscription status.
- Fork 15k
[mlir] move if-condition propagation to a standalone pass #150278
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir Author: Oleksandr "Alex" Zinenko (ftynse) Changes: This offers a significant speedup over running this as a canonicalization pattern, up to 10x improvement when running on large (>100k operations) inputs coming from Polygeist. It is also not clear whether this transformation is a reasonable canonicalization, as it performs non-local rewrites. Full diff: https://github.com/llvm/llvm-project/pull/150278.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 6e5ef96c450aa..ca2510bb53af9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -41,6 +41,12 @@ def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> {
let constructor = "mlir::createForLoopSpecializationPass()";
}
+def SCFIfConditionPropagation : Pass<"scf-if-condition-propagation"> {
+ let summary = "Replace usages of if condition with true/false constants in "
+ "the conditional regions";
+ let dependentDialects = ["arith::ArithDialect"];
+}
+
def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> {
let summary = "Fuse adjacent parallel loops";
let constructor = "mlir::createParallelLoopFusionPass()";
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 00c31a1500e17..6cb61900928d6 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2412,65 +2412,6 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
}
};
-/// Allow the true region of an if to assume the condition is true
-/// and vice versa. For example:
-///
-/// scf.if %cmp {
-/// print(%cmp)
-/// }
-///
-/// becomes
-///
-/// scf.if %cmp {
-/// print(true)
-/// }
-///
-struct ConditionPropagation : public OpRewritePattern<IfOp> {
- using OpRewritePattern<IfOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(IfOp op,
- PatternRewriter &rewriter) const override {
- // Early exit if the condition is constant since replacing a constant
- // in the body with another constant isn't a simplification.
- if (matchPattern(op.getCondition(), m_Constant()))
- return failure();
-
- bool changed = false;
- mlir::Type i1Ty = rewriter.getI1Type();
-
- // These variables serve to prevent creating duplicate constants
- // and hold constant true or false values.
- Value constantTrue = nullptr;
- Value constantFalse = nullptr;
-
- for (OpOperand &use :
- llvm::make_early_inc_range(op.getCondition().getUses())) {
- if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
- changed = true;
-
- if (!constantTrue)
- constantTrue = rewriter.create<arith::ConstantOp>(
- op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
-
- rewriter.modifyOpInPlace(use.getOwner(),
- [&]() { use.set(constantTrue); });
- } else if (op.getElseRegion().isAncestor(
- use.getOwner()->getParentRegion())) {
- changed = true;
-
- if (!constantFalse)
- constantFalse = rewriter.create<arith::ConstantOp>(
- op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
-
- rewriter.modifyOpInPlace(use.getOwner(),
- [&]() { use.set(constantFalse); });
- }
- }
-
- return success(changed);
- }
-};
-
/// Remove any statements from an if that are equivalent to the condition
/// or its negation. For example:
///
@@ -2852,9 +2793,8 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
- ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
- RemoveStaticCondition, RemoveUnusedResults,
+ results.add<CombineIfs, CombineNestedIfs, ConvertTrivialIfToSelect,
+ RemoveEmptyElseBranch, RemoveStaticCondition, RemoveUnusedResults,
ReplaceIfYieldWithConditionOrValue>(context);
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 84dd992bec53a..6d3bafbbc90e4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
ForallToFor.cpp
ForallToParallel.cpp
ForToWhile.cpp
+ IfConditionPropagation.cpp
LoopCanonicalization.cpp
LoopPipelining.cpp
LoopRangeFolding.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
new file mode 100644
index 0000000000000..be8d0e805a7a4
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
@@ -0,0 +1,96 @@
+//===- IfConditionPropagation.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a pass for constant propagation of the condition of an
+// `scf.if` into its then and else regions as true and false respectively.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFIFCONDITIONPROPAGATION
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+/// Traverses the IR recursively (on region tree) and updates the uses of a
+/// value also as the condition of an `scf.if` to either `true` or `false`
+/// constants in the `then` and `else regions. This is done as a single
+/// post-order sweep over the IR (without `walk`) for efficiency reasons. While
+/// traversing, the function maintains the set of visited regions to quickly
+/// identify whether the value belong to a region that is known to be nested in
+/// the `then` or `else` branch of a specific loop.
+static void propagateIfConditionsImpl(Operation *root,
+ llvm::SmallPtrSet<Region *, 8> &visited) {
+ if (auto scfIf = dyn_cast<scf::IfOp>(root)) {
+ llvm::SmallPtrSet<Region *, 8> thenChildren, elseChildren;
+ // Visit the "then" region, collect children.
+ for (Block &block : scfIf.getThenRegion()) {
+ for (Operation &op : block) {
+ propagateIfConditionsImpl(&op, thenChildren);
+ }
+ }
+
+ // Visit the "else" region, collect children.
+ for (Block &block : scfIf.getElseRegion()) {
+ for (Operation &op : block) {
+ propagateIfConditionsImpl(&op, elseChildren);
+ }
+ }
+
+ // Update uses to point to constants instead.
+ OpBuilder builder(scfIf);
+ Value trueValue = arith::ConstantIntOp::create(builder, scfIf.getLoc(),
+ builder.getBoolAttr(true));
+ Value falseValue = arith::ConstantIntOp::create(builder, scfIf.getLoc(),
+ builder.getBoolAttr(false));
+
+ for (OpOperand &use : scfIf.getCondition().getUses()) {
+ if (thenChildren.contains(use.getOwner()->getParentRegion()))
+ use.set(trueValue);
+ else if (elseChildren.contains(use.getOwner()->getParentRegion()))
+ use.set(falseValue);
+ }
+ if (trueValue.getUses().empty())
+ trueValue.getDefiningOp()->erase();
+ if (falseValue.getUses().empty())
+ falseValue.getDefiningOp()->erase();
+
+ // Append the two lists of children and return them.
+ visited.insert_range(thenChildren);
+ visited.insert_range(elseChildren);
+ return;
+ }
+
+ for (Region &region : root->getRegions()) {
+ for (Block &block : region) {
+ for (Operation &op : block) {
+ propagateIfConditionsImpl(&op, visited);
+ }
+ }
+ }
+}
+
+/// Traverses the IR recursively (on region tree) and updates the uses of a
+/// value also as the condition of an `scf.if` to either `true` or `false`
+/// constants in the `then` and `else regions
+static void propagateIfConditions(Operation *root) {
+ llvm::SmallPtrSet<Region *, 8> visited;
+ propagateIfConditionsImpl(root, visited);
+}
+
+namespace {
+/// Pass entrypoint.
+struct SCFIfConditionPropagationPass
+ : impl::SCFIfConditionPropagationBase<SCFIfConditionPropagationPass> {
+ void runOnOperation() override { propagateIfConditions(getOperation()); }
+};
+} // namespace
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8ba8013d008a0..12d30e17f4a8f 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -867,41 +867,6 @@ func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
// -----
-// CHECK-LABEL: @cond_prop
-func.func @cond_prop(%arg0 : i1) -> index {
- %res = scf.if %arg0 -> index {
- %res1 = scf.if %arg0 -> index {
- %v1 = "test.get_some_value1"() : () -> index
- scf.yield %v1 : index
- } else {
- %v2 = "test.get_some_value2"() : () -> index
- scf.yield %v2 : index
- }
- scf.yield %res1 : index
- } else {
- %res2 = scf.if %arg0 -> index {
- %v3 = "test.get_some_value3"() : () -> index
- scf.yield %v3 : index
- } else {
- %v4 = "test.get_some_value4"() : () -> index
- scf.yield %v4 : index
- }
- scf.yield %res2 : index
- }
- return %res : index
-}
-// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (index) {
-// CHECK-NEXT: %[[c1:.+]] = "test.get_some_value1"() : () -> index
-// CHECK-NEXT: scf.yield %[[c1]] : index
-// CHECK-NEXT: } else {
-// CHECK-NEXT: %[[c4:.+]] = "test.get_some_value4"() : () -> index
-// CHECK-NEXT: scf.yield %[[c4]] : index
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[if]] : index
-// CHECK-NEXT:}
-
-// -----
-
// CHECK-LABEL: @replace_if_with_cond1
func.func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
%true = arith.constant true
diff --git a/mlir/test/Dialect/SCF/if-cond-prop.mlir b/mlir/test/Dialect/SCF/if-cond-prop.mlir
new file mode 100644
index 0000000000000..99d113f672014
--- /dev/null
+++ b/mlir/test/Dialect/SCF/if-cond-prop.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s --scf-if-condition-propagation --allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: @cond_prop
+func.func @cond_prop(%arg0 : i1) -> index {
+ %res = scf.if %arg0 -> index {
+ %res1 = scf.if %arg0 -> index {
+ %v1 = "test.get_some_value1"() : () -> index
+ scf.yield %v1 : index
+ } else {
+ %v2 = "test.get_some_value2"() : () -> index
+ scf.yield %v2 : index
+ }
+ scf.yield %res1 : index
+ } else {
+ %res2 = scf.if %arg0 -> index {
+ %v3 = "test.get_some_value3"() : () -> index
+ scf.yield %v3 : index
+ } else {
+ %v4 = "test.get_some_value4"() : () -> index
+ scf.yield %v4 : index
+ }
+ scf.yield %res2 : index
+ }
+ return %res : index
+}
+// CHECK: %[[if:.+]] = scf.if %arg0 -> (index) {
+// CHECK: %[[c1:.+]] = "test.get_some_value1"() : () -> index
+// CHECK: scf.yield %[[c1]] : index
+// CHECK: } else {
+// CHECK: %[[c4:.+]] = "test.get_some_value4"() : () -> index
+// CHECK: scf.yield %[[c4]] : index
+// CHECK: }
+// CHECK: return %[[if]] : index
+// CHECK:}
|
|
@llvm/pr-subscribers-mlir-scf Author: Oleksandr "Alex" Zinenko (ftynse) Changes: This offers a significant speedup over running this as a canonicalization pattern, up to 10x improvement when running on large (>100k operations) inputs coming from Polygeist. It is also not clear whether this transformation is a reasonable canonicalization, as it performs non-local rewrites. Full diff: https://github.com/llvm/llvm-project/pull/150278.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 6e5ef96c450aa..ca2510bb53af9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -41,6 +41,12 @@ def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> {
let constructor = "mlir::createForLoopSpecializationPass()";
}
+def SCFIfConditionPropagation : Pass<"scf-if-condition-propagation"> {
+ let summary = "Replace usages of if condition with true/false constants in "
+ "the conditional regions";
+ let dependentDialects = ["arith::ArithDialect"];
+}
+
def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> {
let summary = "Fuse adjacent parallel loops";
let constructor = "mlir::createParallelLoopFusionPass()";
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 00c31a1500e17..6cb61900928d6 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2412,65 +2412,6 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
}
};
-/// Allow the true region of an if to assume the condition is true
-/// and vice versa. For example:
-///
-/// scf.if %cmp {
-/// print(%cmp)
-/// }
-///
-/// becomes
-///
-/// scf.if %cmp {
-/// print(true)
-/// }
-///
-struct ConditionPropagation : public OpRewritePattern<IfOp> {
- using OpRewritePattern<IfOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(IfOp op,
- PatternRewriter &rewriter) const override {
- // Early exit if the condition is constant since replacing a constant
- // in the body with another constant isn't a simplification.
- if (matchPattern(op.getCondition(), m_Constant()))
- return failure();
-
- bool changed = false;
- mlir::Type i1Ty = rewriter.getI1Type();
-
- // These variables serve to prevent creating duplicate constants
- // and hold constant true or false values.
- Value constantTrue = nullptr;
- Value constantFalse = nullptr;
-
- for (OpOperand &use :
- llvm::make_early_inc_range(op.getCondition().getUses())) {
- if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
- changed = true;
-
- if (!constantTrue)
- constantTrue = rewriter.create<arith::ConstantOp>(
- op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
-
- rewriter.modifyOpInPlace(use.getOwner(),
- [&]() { use.set(constantTrue); });
- } else if (op.getElseRegion().isAncestor(
- use.getOwner()->getParentRegion())) {
- changed = true;
-
- if (!constantFalse)
- constantFalse = rewriter.create<arith::ConstantOp>(
- op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
-
- rewriter.modifyOpInPlace(use.getOwner(),
- [&]() { use.set(constantFalse); });
- }
- }
-
- return success(changed);
- }
-};
-
/// Remove any statements from an if that are equivalent to the condition
/// or its negation. For example:
///
@@ -2852,9 +2793,8 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
- ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
- RemoveStaticCondition, RemoveUnusedResults,
+ results.add<CombineIfs, CombineNestedIfs, ConvertTrivialIfToSelect,
+ RemoveEmptyElseBranch, RemoveStaticCondition, RemoveUnusedResults,
ReplaceIfYieldWithConditionOrValue>(context);
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 84dd992bec53a..6d3bafbbc90e4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
ForallToFor.cpp
ForallToParallel.cpp
ForToWhile.cpp
+ IfConditionPropagation.cpp
LoopCanonicalization.cpp
LoopPipelining.cpp
LoopRangeFolding.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
new file mode 100644
index 0000000000000..be8d0e805a7a4
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
@@ -0,0 +1,96 @@
+//===- IfConditionPropagation.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a pass for constant propagation of the condition of an
+// `scf.if` into its then and else regions as true and false respectively.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFIFCONDITIONPROPAGATION
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+/// Traverses the IR recursively (on region tree) and updates the uses of a
+/// value also as the condition of an `scf.if` to either `true` or `false`
+/// constants in the `then` and `else regions. This is done as a single
+/// post-order sweep over the IR (without `walk`) for efficiency reasons. While
+/// traversing, the function maintains the set of visited regions to quickly
+/// identify whether the value belong to a region that is known to be nested in
+/// the `then` or `else` branch of a specific loop.
+static void propagateIfConditionsImpl(Operation *root,
+ llvm::SmallPtrSet<Region *, 8> &visited) {
+ if (auto scfIf = dyn_cast<scf::IfOp>(root)) {
+ llvm::SmallPtrSet<Region *, 8> thenChildren, elseChildren;
+ // Visit the "then" region, collect children.
+ for (Block &block : scfIf.getThenRegion()) {
+ for (Operation &op : block) {
+ propagateIfConditionsImpl(&op, thenChildren);
+ }
+ }
+
+ // Visit the "else" region, collect children.
+ for (Block &block : scfIf.getElseRegion()) {
+ for (Operation &op : block) {
+ propagateIfConditionsImpl(&op, elseChildren);
+ }
+ }
+
+ // Update uses to point to constants instead.
+ OpBuilder builder(scfIf);
+ Value trueValue = arith::ConstantIntOp::create(builder, scfIf.getLoc(),
+ builder.getBoolAttr(true));
+ Value falseValue = arith::ConstantIntOp::create(builder, scfIf.getLoc(),
+ builder.getBoolAttr(false));
+
+ for (OpOperand &use : scfIf.getCondition().getUses()) {
+ if (thenChildren.contains(use.getOwner()->getParentRegion()))
+ use.set(trueValue);
+ else if (elseChildren.contains(use.getOwner()->getParentRegion()))
+ use.set(falseValue);
+ }
+ if (trueValue.getUses().empty())
+ trueValue.getDefiningOp()->erase();
+ if (falseValue.getUses().empty())
+ falseValue.getDefiningOp()->erase();
+
+ // Append the two lists of children and return them.
+ visited.insert_range(thenChildren);
+ visited.insert_range(elseChildren);
+ return;
+ }
+
+ for (Region &region : root->getRegions()) {
+ for (Block &block : region) {
+ for (Operation &op : block) {
+ propagateIfConditionsImpl(&op, visited);
+ }
+ }
+ }
+}
+
+/// Traverses the IR recursively (on region tree) and updates the uses of a
+/// value also as the condition of an `scf.if` to either `true` or `false`
+/// constants in the `then` and `else regions
+static void propagateIfConditions(Operation *root) {
+ llvm::SmallPtrSet<Region *, 8> visited;
+ propagateIfConditionsImpl(root, visited);
+}
+
+namespace {
+/// Pass entrypoint.
+struct SCFIfConditionPropagationPass
+ : impl::SCFIfConditionPropagationBase<SCFIfConditionPropagationPass> {
+ void runOnOperation() override { propagateIfConditions(getOperation()); }
+};
+} // namespace
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8ba8013d008a0..12d30e17f4a8f 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -867,41 +867,6 @@ func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
// -----
-// CHECK-LABEL: @cond_prop
-func.func @cond_prop(%arg0 : i1) -> index {
- %res = scf.if %arg0 -> index {
- %res1 = scf.if %arg0 -> index {
- %v1 = "test.get_some_value1"() : () -> index
- scf.yield %v1 : index
- } else {
- %v2 = "test.get_some_value2"() : () -> index
- scf.yield %v2 : index
- }
- scf.yield %res1 : index
- } else {
- %res2 = scf.if %arg0 -> index {
- %v3 = "test.get_some_value3"() : () -> index
- scf.yield %v3 : index
- } else {
- %v4 = "test.get_some_value4"() : () -> index
- scf.yield %v4 : index
- }
- scf.yield %res2 : index
- }
- return %res : index
-}
-// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (index) {
-// CHECK-NEXT: %[[c1:.+]] = "test.get_some_value1"() : () -> index
-// CHECK-NEXT: scf.yield %[[c1]] : index
-// CHECK-NEXT: } else {
-// CHECK-NEXT: %[[c4:.+]] = "test.get_some_value4"() : () -> index
-// CHECK-NEXT: scf.yield %[[c4]] : index
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[if]] : index
-// CHECK-NEXT:}
-
-// -----
-
// CHECK-LABEL: @replace_if_with_cond1
func.func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
%true = arith.constant true
diff --git a/mlir/test/Dialect/SCF/if-cond-prop.mlir b/mlir/test/Dialect/SCF/if-cond-prop.mlir
new file mode 100644
index 0000000000000..99d113f672014
--- /dev/null
+++ b/mlir/test/Dialect/SCF/if-cond-prop.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s --scf-if-condition-propagation --allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: @cond_prop
+func.func @cond_prop(%arg0 : i1) -> index {
+ %res = scf.if %arg0 -> index {
+ %res1 = scf.if %arg0 -> index {
+ %v1 = "test.get_some_value1"() : () -> index
+ scf.yield %v1 : index
+ } else {
+ %v2 = "test.get_some_value2"() : () -> index
+ scf.yield %v2 : index
+ }
+ scf.yield %res1 : index
+ } else {
+ %res2 = scf.if %arg0 -> index {
+ %v3 = "test.get_some_value3"() : () -> index
+ scf.yield %v3 : index
+ } else {
+ %v4 = "test.get_some_value4"() : () -> index
+ scf.yield %v4 : index
+ }
+ scf.yield %res2 : index
+ }
+ return %res : index
+}
+// CHECK: %[[if:.+]] = scf.if %arg0 -> (index) {
+// CHECK: %[[c1:.+]] = "test.get_some_value1"() : () -> index
+// CHECK: scf.yield %[[c1]] : index
+// CHECK: } else {
+// CHECK: %[[c4:.+]] = "test.get_some_value4"() : () -> index
+// CHECK: scf.yield %[[c4]] : index
+// CHECK: }
+// CHECK: return %[[if]] : index
+// CHECK:}
|
This offers a significant speedup over running this as a canonicalization pattern, up to 10x improvement when running on large (>100k operations) inputs coming from Polygeist. It is also not clear whether this transformation is a reasonable canonicalization, as it performs non-local rewrites.
| @@ -0,0 +1,34 @@ | |||
| // RUN: mlir-opt %s --scf-if-condition-propagation --allow-unregistered-dialect | FileCheck %s | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove `--allow-unregistered-dialect`; I believe it should never be needed in MLIR tests (the test dialect allows unregistered ops if needed).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ping @ftynse on post-merge comments.
|
It looks quite unfortunate to me to lose this as a canonicalization (I don't understand your comment about locality: it does not seem like an issue to me at all, actually). Can you expand on the source of the slowdown (or speedup)? |
This offers a significant speedup over running this as a canonicalization pattern, up to 10x improvement when running on large (>100k operations) inputs coming from Polygeist. It is also not clear whether this transformation is a reasonable canonicalization, as it performs non-local rewrites.
@ftynse ping on post-merge comments! |
|
|
||
| // Append the two lists of children and return them. | ||
| visited.insert_range(thenChildren); | ||
| visited.insert_range(elseChildren); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't quite understand the logic for the visited set, the comment on the function says:
/// While
/// traversing, the function maintains the set of visited regions to quickly
/// identify whether the value belong to a region that is known to be nested in
/// the `then` or `else` branch of a specific loop.
But that does not seem to be the case, the visited set is only ever used for inserting here, never checked as far as I can see?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I see the other recursions above, where "visited" will populate the thenChildren and elseChildren.
But if we hit an ifOp above, why do we need to recurse below again?
…ne pass" (#159535) Reverts llvm/llvm-project#150278 Multiple post-merge comments remained unaddressed, and some more fundamental issues were also reported in #159165
This offers a significant speedup over running this as a canonicalization pattern, up to 10x improvement when running on large (>100k operations) inputs coming from Polygeist.
It is also not clear whether this transformation is a reasonable canonicalization as it performs non-local rewrites.