@ftynse (Member) commented Jul 23, 2025

This offers a significant speedup over running this as a canonicalization pattern: up to 10x when running on large inputs (>100k operations) coming from Polygeist.

It is also not clear whether this transformation is a reasonable canonicalization as it performs non-local rewrites.
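To make the rewrite concrete, here is a toy, self-contained C++ model of what the removed `ConditionPropagation` pattern does. It is not MLIR code: `ToyOp`, `ToyRegion`, `append`, `collectUses`, and `propagateCondition` are hypothetical names, and values are plain `int`s with the assumed `kFalse = 0` / `kTrue = 1` standing in for `i1` constants. Like the pattern, it visits every use of the condition and walks up the parent chain (as `Region::isAncestor` does) to decide whether the use sits in the `then` or `else` region. Each such check therefore costs time proportional to the nesting depth, and the rewritten operands may be arbitrarily deep inside the `scf.if`, which is the non-local aspect mentioned above.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical toy IR for illustration only (not MLIR's API).
// Values are ints; kFalse/kTrue stand in for i1 constants.
constexpr int kFalse = 0;
constexpr int kTrue = 1;

struct ToyRegion;

struct ToyOp {
  bool isIf = false;                               // toy `scf.if`?
  std::vector<int> operands;                       // if: operands[0] = condition
  std::vector<std::unique_ptr<ToyRegion>> regions; // if: [0] = then, [1] = else
  ToyRegion *parent = nullptr;                     // region holding this op
};

struct ToyRegion {
  std::vector<std::unique_ptr<ToyOp>> ops;
  ToyOp *parentOp = nullptr; // op owning this region (null at top level)
};

// Appends an op to `region`; an `if` gets empty then/else regions.
inline ToyOp *append(ToyRegion *region, bool isIf, std::vector<int> operands) {
  auto op = std::make_unique<ToyOp>();
  op->isIf = isIf;
  op->operands = std::move(operands);
  op->parent = region;
  for (int i = 0; isIf && i < 2; ++i) {
    op->regions.push_back(std::make_unique<ToyRegion>());
    op->regions.back()->parentOp = op.get();
  }
  region->ops.push_back(std::move(op));
  return region->ops.back().get();
}

// Like Region::isAncestor: true if `ancestor` is `r` or encloses it.
inline bool isAncestor(const ToyRegion *ancestor, const ToyRegion *r) {
  for (; r; r = r->parentOp ? r->parentOp->parent : nullptr)
    if (r == ancestor)
      return true;
  return false;
}

struct Use {
  ToyOp *owner;      // op holding the operand
  std::size_t index; // operand slot
};

// Stand-in for a use list: every operand slot under `region` equal to `value`.
inline void collectUses(ToyRegion *region, int value, std::vector<Use> &uses) {
  for (auto &op : region->ops) {
    for (std::size_t i = 0; i < op->operands.size(); ++i)
      if (op->operands[i] == value)
        uses.push_back({op.get(), i});
    for (auto &r : op->regions)
      collectUses(r.get(), value, uses);
  }
}

// Pattern-style rewrite of one `if`: per use, an ancestor walk decides whether
// to substitute kTrue (then) or kFalse (else). Returns the number of rewrites.
inline int propagateCondition(ToyRegion *top, ToyOp *ifOp) {
  const int cond = ifOp->operands[0];
  if (cond == kTrue || cond == kFalse)
    return 0; // constant condition: nothing to simplify
  std::vector<Use> uses;
  collectUses(top, cond, uses); // collect first, then mutate
  int changed = 0;
  for (const Use &u : uses) {
    if (isAncestor(ifOp->regions[0].get(), u.owner->parent)) {
      u.owner->operands[u.index] = kTrue;
      ++changed;
    } else if (isAncestor(ifOp->regions[1].get(), u.owner->parent)) {
      u.owner->operands[u.index] = kFalse;
      ++changed;
    }
  }
  return changed;
}
```

For a structure like the `@cond_prop` test (an outer `if` on `arg0` with a nested `if` on `arg0` in each branch), calling `propagateCondition(&top, outer)` turns the nested conditions into `kTrue` and `kFalse`; folding the now-constant nested `if`s is left to other canonicalizations.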

@llvmbot (Member) commented Jul 23, 2025

@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

Changes

This offers a significant speedup over running this as a canonicalization pattern: up to 10x when running on large inputs (>100k operations) coming from Polygeist.

It is also not clear whether this transformation is a reasonable canonicalization as it performs non-local rewrites.


Full diff: https://github.com/llvm/llvm-project/pull/150278.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.td (+6)
  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+2-62)
  • (modified) mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp (+96)
  • (modified) mlir/test/Dialect/SCF/canonicalize.mlir (-35)
  • (added) mlir/test/Dialect/SCF/if-cond-prop.mlir (+34)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 6e5ef96c450aa..ca2510bb53af9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -41,6 +41,12 @@ def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> {
   let constructor = "mlir::createForLoopSpecializationPass()";
 }
 
+def SCFIfConditionPropagation : Pass<"scf-if-condition-propagation"> {
+  let summary = "Replace usages of if condition with true/false constants in "
+                "the conditional regions";
+  let dependentDialects = ["arith::ArithDialect"];
+}
+
 def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> {
   let summary = "Fuse adjacent parallel loops";
   let constructor = "mlir::createParallelLoopFusionPass()";
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 00c31a1500e17..6cb61900928d6 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2412,65 +2412,6 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
   }
 };
 
-/// Allow the true region of an if to assume the condition is true
-/// and vice versa. For example:
-///
-///   scf.if %cmp {
-///      print(%cmp)
-///   }
-///
-///  becomes
-///
-///   scf.if %cmp {
-///      print(true)
-///   }
-///
-struct ConditionPropagation : public OpRewritePattern<IfOp> {
-  using OpRewritePattern<IfOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(IfOp op,
-                                PatternRewriter &rewriter) const override {
-    // Early exit if the condition is constant since replacing a constant
-    // in the body with another constant isn't a simplification.
-    if (matchPattern(op.getCondition(), m_Constant()))
-      return failure();
-
-    bool changed = false;
-    mlir::Type i1Ty = rewriter.getI1Type();
-
-    // These variables serve to prevent creating duplicate constants
-    // and hold constant true or false values.
-    Value constantTrue = nullptr;
-    Value constantFalse = nullptr;
-
-    for (OpOperand &use :
-         llvm::make_early_inc_range(op.getCondition().getUses())) {
-      if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
-        changed = true;
-
-        if (!constantTrue)
-          constantTrue = rewriter.create<arith::ConstantOp>(
-              op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
-
-        rewriter.modifyOpInPlace(use.getOwner(),
-                                 [&]() { use.set(constantTrue); });
-      } else if (op.getElseRegion().isAncestor(
-                     use.getOwner()->getParentRegion())) {
-        changed = true;
-
-        if (!constantFalse)
-          constantFalse = rewriter.create<arith::ConstantOp>(
-              op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
-
-        rewriter.modifyOpInPlace(use.getOwner(),
-                                 [&]() { use.set(constantFalse); });
-      }
-    }
-
-    return success(changed);
-  }
-};
-
 /// Remove any statements from an if that are equivalent to the condition
 /// or its negation. For example:
 ///
@@ -2852,9 +2793,8 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
 
 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
-  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
-              ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
-              RemoveStaticCondition, RemoveUnusedResults,
+  results.add<CombineIfs, CombineNestedIfs, ConvertTrivialIfToSelect,
+              RemoveEmptyElseBranch, RemoveStaticCondition, RemoveUnusedResults,
               ReplaceIfYieldWithConditionOrValue>(context);
 }
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 84dd992bec53a..6d3bafbbc90e4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   ForallToFor.cpp
   ForallToParallel.cpp
   ForToWhile.cpp
+  IfConditionPropagation.cpp
   LoopCanonicalization.cpp
   LoopPipelining.cpp
   LoopRangeFolding.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
new file mode 100644
index 0000000000000..be8d0e805a7a4
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
@@ -0,0 +1,96 @@
+//===- IfConditionPropagation.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a pass for constant propagation of the condition of an
+// `scf.if` into its then and else regions as true and false respectively.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFIFCONDITIONPROPAGATION
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+/// Traverses the IR recursively (on the region tree) and replaces uses of a
+/// value that is also the condition of an `scf.if` with `true` or `false`
+/// constants in the `then` and `else` regions, respectively. This is done as
+/// a single post-order sweep over the IR (without `walk`) for efficiency
+/// reasons. While traversing, the function maintains the set of visited
+/// regions to quickly identify whether a use of the value belongs to a region
+/// known to be nested in the `then` or `else` branch of a specific `scf.if`.
+static void propagateIfConditionsImpl(Operation *root,
+                                      llvm::SmallPtrSet<Region *, 8> &visited) {
+  if (auto scfIf = dyn_cast<scf::IfOp>(root)) {
+    llvm::SmallPtrSet<Region *, 8> thenChildren, elseChildren;
+    // Visit the "then" region, collect children.
+    for (Block &block : scfIf.getThenRegion()) {
+      for (Operation &op : block) {
+        propagateIfConditionsImpl(&op, thenChildren);
+      }
+    }
+
+    // Visit the "else" region, collect children.
+    for (Block &block : scfIf.getElseRegion()) {
+      for (Operation &op : block) {
+        propagateIfConditionsImpl(&op, elseChildren);
+      }
+    }
+
+    // Update uses to point to constants instead.
+    OpBuilder builder(scfIf);
+    Value trueValue = arith::ConstantIntOp::create(builder, scfIf.getLoc(),
+                                                   builder.getBoolAttr(true));
+    Value falseValue = arith::ConstantIntOp::create(builder, scfIf.getLoc(),
+                                                    builder.getBoolAttr(false));
+    for (OpOperand &use :
+         llvm::make_early_inc_range(scfIf.getCondition().getUses())) {
+      if (thenChildren.contains(use.getOwner()->getParentRegion()))
+        use.set(trueValue);
+      else if (elseChildren.contains(use.getOwner()->getParentRegion()))
+        use.set(falseValue);
+    }
+    if (trueValue.getUses().empty())
+      trueValue.getDefiningOp()->erase();
+    if (falseValue.getUses().empty())
+      falseValue.getDefiningOp()->erase();
+
+    // Append the two lists of children and return them.
+    visited.insert_range(thenChildren);
+    visited.insert_range(elseChildren);
+    return;
+  }
+
+  for (Region &region : root->getRegions()) {
+    for (Block &block : region) {
+      for (Operation &op : block) {
+        propagateIfConditionsImpl(&op, visited);
+      }
+    }
+  }
+}
+
+/// Traverses the IR recursively (on the region tree) and replaces uses of a
+/// value that is also the condition of an `scf.if` with `true` or `false`
+/// constants in the `then` and `else` regions, respectively.
+static void propagateIfConditions(Operation *root) {
+  llvm::SmallPtrSet<Region *, 8> visited;
+  propagateIfConditionsImpl(root, visited);
+}
+
+namespace {
+/// Pass entrypoint.
+struct SCFIfConditionPropagationPass
+    : impl::SCFIfConditionPropagationBase<SCFIfConditionPropagationPass> {
+  void runOnOperation() override { propagateIfConditions(getOperation()); }
+};
+} // namespace
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8ba8013d008a0..12d30e17f4a8f 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -867,41 +867,6 @@ func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
 
 // -----
 
-// CHECK-LABEL: @cond_prop
-func.func @cond_prop(%arg0 : i1) -> index {
-  %res = scf.if %arg0 -> index {
-    %res1 = scf.if %arg0 -> index {
-      %v1 = "test.get_some_value1"() : () -> index
-      scf.yield %v1 : index
-    } else {
-      %v2 = "test.get_some_value2"() : () -> index
-      scf.yield %v2 : index
-    }
-    scf.yield %res1 : index
-  } else {
-    %res2 = scf.if %arg0 -> index {
-      %v3 = "test.get_some_value3"() : () -> index
-      scf.yield %v3 : index
-    } else {
-      %v4 = "test.get_some_value4"() : () -> index
-      scf.yield %v4 : index
-    }
-    scf.yield %res2 : index
-  }
-  return %res : index
-}
-// CHECK-NEXT:  %[[if:.+]] = scf.if %arg0 -> (index) {
-// CHECK-NEXT:    %[[c1:.+]] = "test.get_some_value1"() : () -> index
-// CHECK-NEXT:    scf.yield %[[c1]] : index
-// CHECK-NEXT:  } else {
-// CHECK-NEXT:    %[[c4:.+]] = "test.get_some_value4"() : () -> index
-// CHECK-NEXT:    scf.yield %[[c4]] : index
-// CHECK-NEXT:  }
-// CHECK-NEXT:  return %[[if]] : index
-// CHECK-NEXT:}
-
-// -----
-
 // CHECK-LABEL: @replace_if_with_cond1
 func.func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
   %true = arith.constant true
diff --git a/mlir/test/Dialect/SCF/if-cond-prop.mlir b/mlir/test/Dialect/SCF/if-cond-prop.mlir
new file mode 100644
index 0000000000000..99d113f672014
--- /dev/null
+++ b/mlir/test/Dialect/SCF/if-cond-prop.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s --scf-if-condition-propagation --allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: @cond_prop
+func.func @cond_prop(%arg0 : i1) -> index {
+  %res = scf.if %arg0 -> index {
+    %res1 = scf.if %arg0 -> index {
+      %v1 = "test.get_some_value1"() : () -> index
+      scf.yield %v1 : index
+    } else {
+      %v2 = "test.get_some_value2"() : () -> index
+      scf.yield %v2 : index
+    }
+    scf.yield %res1 : index
+  } else {
+    %res2 = scf.if %arg0 -> index {
+      %v3 = "test.get_some_value3"() : () -> index
+      scf.yield %v3 : index
+    } else {
+      %v4 = "test.get_some_value4"() : () -> index
+      scf.yield %v4 : index
+    }
+    scf.yield %res2 : index
+  }
+  return %res : index
+}
+// CHECK:  %[[if:.+]] = scf.if %arg0 -> (index) {
+// CHECK:    %[[c1:.+]] = "test.get_some_value1"() : () -> index
+// CHECK:    scf.yield %[[c1]] : index
+// CHECK:  } else {
+// CHECK:    %[[c4:.+]] = "test.get_some_value4"() : () -> index
+// CHECK:    scf.yield %[[c4]] : index
+// CHECK:  }
+// CHECK:  return %[[if]] : index
+// CHECK:}
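The new file's comment describes a single post-order sweep that records, for each `scf.if`, which regions are nested in its `then` and `else` branches, so each use of the condition can be classified with a set lookup instead of an ancestor walk. Below is a toy, self-contained C++ model of that idea, not the code from this diff: `ToyOp`, `ToyRegion`, `append`, `buildUses`, `sweep`, and `propagateIfConditions` are hypothetical names, and an explicit use map stands in for MLIR's per-value use lists. Two details are deliberate choices of this sketch. First, each `if` seeds its sets with its own `then`/`else` regions (and every other op records its regions), so uses sitting directly in a branch are found. Second, uses already rewritten by a nested `if` on the same condition are skipped, analogous to such a use having moved to the constant's use list in MLIR.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical toy IR for illustration only (not MLIR's API).
constexpr int kFalse = 0; // stands in for `arith.constant false`
constexpr int kTrue = 1;  // stands in for `arith.constant true`

struct ToyRegion;
struct ToyOp {
  bool isIf = false;                               // toy `scf.if`?
  std::vector<int> operands;                       // if: operands[0] = condition
  std::vector<std::unique_ptr<ToyRegion>> regions; // if: [0] = then, [1] = else
  ToyRegion *parent = nullptr;                     // region holding this op
};
struct ToyRegion {
  std::vector<std::unique_ptr<ToyOp>> ops;
  ToyOp *parentOp = nullptr; // op owning this region (null at top level)
};

// Appends an op to `region`; an `if` gets empty then/else regions.
inline ToyOp *append(ToyRegion *region, bool isIf, std::vector<int> operands) {
  auto op = std::make_unique<ToyOp>();
  op->isIf = isIf;
  op->operands = std::move(operands);
  op->parent = region;
  for (int i = 0; isIf && i < 2; ++i) {
    op->regions.push_back(std::make_unique<ToyRegion>());
    op->regions.back()->parentOp = op.get();
  }
  region->ops.push_back(std::move(op));
  return region->ops.back().get();
}

struct Use {
  ToyOp *owner;      // op holding the operand
  std::size_t index; // operand slot
};
using UseMap = std::unordered_map<int, std::vector<Use>>;
using RegionSet = std::unordered_set<const ToyRegion *>;

// Builds use lists once; MLIR keeps these on every Value already.
inline void buildUses(ToyRegion *region, UseMap &uses) {
  for (auto &op : region->ops) {
    for (std::size_t i = 0; i < op->operands.size(); ++i)
      uses[op->operands[i]].push_back({op.get(), i});
    for (auto &r : op->regions)
      buildUses(r.get(), uses);
  }
}

// Post-order sweep. On return, `nested` also holds every region nested at any
// depth under `op` (op's own regions included), for O(1) membership checks.
inline void sweep(ToyOp *op, const UseMap &uses, RegionSet &nested) {
  if (!op->isIf) {
    for (auto &r : op->regions) {
      nested.insert(r.get());
      for (auto &child : r->ops)
        sweep(child.get(), uses, nested);
    }
    return;
  }
  // Seed each branch set with the branch region itself, then add descendants.
  RegionSet thenSet{op->regions[0].get()};
  RegionSet elseSet{op->regions[1].get()};
  for (auto &child : op->regions[0]->ops)
    sweep(child.get(), uses, thenSet);
  for (auto &child : op->regions[1]->ops)
    sweep(child.get(), uses, elseSet);
  const int cond = op->operands[0];
  auto it = uses.find(cond);
  if (cond != kTrue && cond != kFalse && it != uses.end()) {
    for (const Use &u : it->second) {
      if (u.owner->operands[u.index] != cond)
        continue; // already rewritten by a nested `if` on the same condition
      if (thenSet.count(u.owner->parent))
        u.owner->operands[u.index] = kTrue;
      else if (elseSet.count(u.owner->parent))
        u.owner->operands[u.index] = kFalse;
    }
  }
  nested.insert(thenSet.begin(), thenSet.end());
  nested.insert(elseSet.begin(), elseSet.end());
}

// Entry point: one sweep over all top-level ops.
inline void propagateIfConditions(ToyRegion *top) {
  UseMap uses;
  buildUses(top, uses);
  RegionSet nested;
  for (auto &op : top->ops)
    sweep(op.get(), uses, nested);
}
```

On the `@cond_prop` shape, `propagateIfConditions(&top)` leaves the outer condition alone and rewrites the two inner conditions to `kTrue` and `kFalse`. A use placed in the `else` region of the inner `if` that sits in the outer `then` branch becomes `kFalse`, because in a post-order sweep the innermost enclosing `if` is processed first.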

@llvmbot (Member) commented Jul 23, 2025

@llvm/pr-subscribers-mlir-scf

Author: Oleksandr "Alex" Zinenko (ftynse)

Changes

This offers a significant speedup over running this as a canonicalizaiton pattern, up to 10x improvement when running on large (>100k operations) inputs coming from Polygeist.

It is also not clear whether this transformation is a reasonable canonicalization as it performs non-local rewrites.


Full diff: https://github.com/llvm/llvm-project/pull/150278.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.td (+6)
  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+2-62)
  • (modified) mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp (+96)
  • (modified) mlir/test/Dialect/SCF/canonicalize.mlir (-35)
  • (added) mlir/test/Dialect/SCF/if-cond-prop.mlir (+34)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 6e5ef96c450aa..ca2510bb53af9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -41,6 +41,12 @@ def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> {
   let constructor = "mlir::createForLoopSpecializationPass()";
 }
 
+def SCFIfConditionPropagation : Pass<"scf-if-condition-propagation"> {
+  let summary = "Replace usages of if condition with true/false constants in "
+                "the conditional regions";
+  let dependentDialects = ["arith::ArithDialect"];
+}
+
 def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> {
   let summary = "Fuse adjacent parallel loops";
   let constructor = "mlir::createParallelLoopFusionPass()";
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 00c31a1500e17..6cb61900928d6 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2412,65 +2412,6 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
   }
 };
 
-/// Allow the true region of an if to assume the condition is true
-/// and vice versa. For example:
-///
-///   scf.if %cmp {
-///      print(%cmp)
-///   }
-///
-///  becomes
-///
-///   scf.if %cmp {
-///      print(true)
-///   }
-///
-struct ConditionPropagation : public OpRewritePattern<IfOp> {
-  using OpRewritePattern<IfOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(IfOp op,
-                                PatternRewriter &rewriter) const override {
-    // Early exit if the condition is constant since replacing a constant
-    // in the body with another constant isn't a simplification.
-    if (matchPattern(op.getCondition(), m_Constant()))
-      return failure();
-
-    bool changed = false;
-    mlir::Type i1Ty = rewriter.getI1Type();
-
-    // These variables serve to prevent creating duplicate constants
-    // and hold constant true or false values.
-    Value constantTrue = nullptr;
-    Value constantFalse = nullptr;
-
-    for (OpOperand &use :
-         llvm::make_early_inc_range(op.getCondition().getUses())) {
-      if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
-        changed = true;
-
-        if (!constantTrue)
-          constantTrue = rewriter.create<arith::ConstantOp>(
-              op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
-
-        rewriter.modifyOpInPlace(use.getOwner(),
-                                 [&]() { use.set(constantTrue); });
-      } else if (op.getElseRegion().isAncestor(
-                     use.getOwner()->getParentRegion())) {
-        changed = true;
-
-        if (!constantFalse)
-          constantFalse = rewriter.create<arith::ConstantOp>(
-              op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
-
-        rewriter.modifyOpInPlace(use.getOwner(),
-                                 [&]() { use.set(constantFalse); });
-      }
-    }
-
-    return success(changed);
-  }
-};
-
 /// Remove any statements from an if that are equivalent to the condition
 /// or its negation. For example:
 ///
@@ -2852,9 +2793,8 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
 
 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
-  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
-              ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
-              RemoveStaticCondition, RemoveUnusedResults,
+  results.add<CombineIfs, CombineNestedIfs, ConvertTrivialIfToSelect,
+              RemoveEmptyElseBranch, RemoveStaticCondition, RemoveUnusedResults,
               ReplaceIfYieldWithConditionOrValue>(context);
 }
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 84dd992bec53a..6d3bafbbc90e4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   ForallToFor.cpp
   ForallToParallel.cpp
   ForToWhile.cpp
+  IfConditionPropagation.cpp
   LoopCanonicalization.cpp
   LoopPipelining.cpp
   LoopRangeFolding.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
new file mode 100644
index 0000000000000..be8d0e805a7a4
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
@@ -0,0 +1,96 @@
+//===- IfConditionPropagation.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a pass for constant propagation of the condition of an
+// `scf.if` into its then and else regions as true and false respectively.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFIFCONDITIONPROPAGATION
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+/// Traverses the IR recursively (on region tree) and updates the uses of a
+/// value also as the condition of an `scf.if` to either `true` or `false`
+/// constants in the `then` and `else regions. This is done as a single
+/// post-order sweep over the IR (without `walk`) for efficiency reasons. While
+/// traversing, the function maintains the set of visited regions to quickly
+/// identify whether the value belong to a region that is known to be nested in
+/// the `then` or `else` branch of a specific loop.
+static void propagateIfConditionsImpl(Operation *root,
+                                      llvm::SmallPtrSet<Region *, 8> &visited) {
+  if (auto scfIf = dyn_cast<scf::IfOp>(root)) {
+    llvm::SmallPtrSet<Region *, 8> thenChildren, elseChildren;
+    // Visit the "then" region, collect children.
+    for (Block &block : scfIf.getThenRegion()) {
+      for (Operation &op : block) {
+        propagateIfConditionsImpl(&op, thenChildren);
+      }
+    }
+
+    // Visit the "else" region, collect children.
+    for (Block &block : scfIf.getElseRegion()) {
+      for (Operation &op : block) {
+        propagateIfConditionsImpl(&op, elseChildren);
+      }
+    }
+
+    // Update uses to point to constants instead.
+    OpBuilder builder(scfIf);
+    Value trueValue = arith::ConstantIntOp::create(builder, scfIf.getLoc(),
+                                                   builder.getBoolAttr(true));
+    Value falseValue = arith::ConstantIntOp::create(builder, scfIf.getLoc(),
+                                                    builder.getBoolAttr(false));
+
+    for (OpOperand &use : scfIf.getCondition().getUses()) {
+      if (thenChildren.contains(use.getOwner()->getParentRegion()))
+        use.set(trueValue);
+      else if (elseChildren.contains(use.getOwner()->getParentRegion()))
+        use.set(falseValue);
+    }
+    if (trueValue.getUses().empty())
+      trueValue.getDefiningOp()->erase();
+    if (falseValue.getUses().empty())
+      falseValue.getDefiningOp()->erase();
+
+    // Append the two lists of children and return them.
+    visited.insert_range(thenChildren);
+    visited.insert_range(elseChildren);
+    return;
+  }
+
+  for (Region &region : root->getRegions()) {
+    for (Block &block : region) {
+      for (Operation &op : block) {
+        propagateIfConditionsImpl(&op, visited);
+      }
+    }
+  }
+}
+
+/// Traverses the IR recursively (on region tree) and updates the uses of a
+/// value also as the condition of an `scf.if` to either `true` or `false`
+/// constants in the `then` and `else regions
+static void propagateIfConditions(Operation *root) {
+  llvm::SmallPtrSet<Region *, 8> visited;
+  propagateIfConditionsImpl(root, visited);
+}
+
+namespace {
+/// Pass entrypoint.
+struct SCFIfConditionPropagationPass
+    : impl::SCFIfConditionPropagationBase<SCFIfConditionPropagationPass> {
+  void runOnOperation() override { propagateIfConditions(getOperation()); }
+};
+} // namespace
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8ba8013d008a0..12d30e17f4a8f 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -867,41 +867,6 @@ func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
 
 // -----
 
-// CHECK-LABEL: @cond_prop
-func.func @cond_prop(%arg0 : i1) -> index {
-  %res = scf.if %arg0 -> index {
-    %res1 = scf.if %arg0 -> index {
-      %v1 = "test.get_some_value1"() : () -> index
-      scf.yield %v1 : index
-    } else {
-      %v2 = "test.get_some_value2"() : () -> index
-      scf.yield %v2 : index
-    }
-    scf.yield %res1 : index
-  } else {
-    %res2 = scf.if %arg0 -> index {
-      %v3 = "test.get_some_value3"() : () -> index
-      scf.yield %v3 : index
-    } else {
-      %v4 = "test.get_some_value4"() : () -> index
-      scf.yield %v4 : index
-    }
-    scf.yield %res2 : index
-  }
-  return %res : index
-}
-// CHECK-NEXT:  %[[if:.+]] = scf.if %arg0 -> (index) {
-// CHECK-NEXT:    %[[c1:.+]] = "test.get_some_value1"() : () -> index
-// CHECK-NEXT:    scf.yield %[[c1]] : index
-// CHECK-NEXT:  } else {
-// CHECK-NEXT:    %[[c4:.+]] = "test.get_some_value4"() : () -> index
-// CHECK-NEXT:    scf.yield %[[c4]] : index
-// CHECK-NEXT:  }
-// CHECK-NEXT:  return %[[if]] : index
-// CHECK-NEXT:}
-
-// -----
-
 // CHECK-LABEL: @replace_if_with_cond1
 func.func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
   %true = arith.constant true
diff --git a/mlir/test/Dialect/SCF/if-cond-prop.mlir b/mlir/test/Dialect/SCF/if-cond-prop.mlir
new file mode 100644
index 0000000000000..99d113f672014
--- /dev/null
+++ b/mlir/test/Dialect/SCF/if-cond-prop.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s --scf-if-condition-propagation --allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: @cond_prop
+func.func @cond_prop(%arg0 : i1) -> index {
+  %res = scf.if %arg0 -> index {
+    %res1 = scf.if %arg0 -> index {
+      %v1 = "test.get_some_value1"() : () -> index
+      scf.yield %v1 : index
+    } else {
+      %v2 = "test.get_some_value2"() : () -> index
+      scf.yield %v2 : index
+    }
+    scf.yield %res1 : index
+  } else {
+    %res2 = scf.if %arg0 -> index {
+      %v3 = "test.get_some_value3"() : () -> index
+      scf.yield %v3 : index
+    } else {
+      %v4 = "test.get_some_value4"() : () -> index
+      scf.yield %v4 : index
+    }
+    scf.yield %res2 : index
+  }
+  return %res : index
+}
+// CHECK:  %[[if:.+]] = scf.if %arg0 -> (index) {
+// CHECK:    %[[c1:.+]] = "test.get_some_value1"() : () -> index
+// CHECK:    scf.yield %[[c1]] : index
+// CHECK:  } else {
+// CHECK:    %[[c4:.+]] = "test.get_some_value4"() : () -> index
+// CHECK:    scf.yield %[[c4]] : index
+// CHECK:  }
+// CHECK:  return %[[if]] : index
+// CHECK:}

This offers a significant speedup over running this as a canonicalizaiton
pattern, up to 10x improvement when running on large (>100k operations) inputs
coming from Polygeist.

It is also not clear whether this transformation is a reasonable
canonicalization as it performs non-local rewrites.
@ftynse merged commit 9d11acc into llvm:main Jul 23, 2025
4 checks passed
@ftynse deleted the if-cond-prop branch July 23, 2025 19:02
@@ -0,0 +1,34 @@
// RUN: mlir-opt %s --scf-if-condition-propagation --allow-unregistered-dialect | FileCheck %s
Collaborator

Please remove --allow-unregistered-dialect; I believe it should never be needed in MLIR tests (the test dialect allows for unregistered ops if needed).

Collaborator

@joker-eph Sep 17, 2025

Ping @ftynse on post-merge comments.

@joker-eph (Collaborator)

It seems quite unfortunate to me to lose this as a canonicalization (I don't understand your comment about locality: this does not seem like an issue to me at all, actually).

Can you expand on the source of the slow-down (or speed-up)?
Can you also provide a synthetic test case to illustrate the problem?

mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jul 28, 2025
@joker-eph (Collaborator)

@ftynse ping on post-merge comments!
Seems to me that this was hastily merged, which is OK only because we expect post-merge comments to be addressed promptly (or the PR reverted while we discuss).

joker-eph added a commit that referenced this pull request Sep 17, 2025

// Append the two lists of children and return them.
visited.insert_range(thenChildren);
visited.insert_range(elseChildren);
Collaborator

I don't quite understand the logic for the visited set; the comment on the function says:

/// While
/// traversing, the function maintains the set of visited regions to quickly
/// identify whether the value belong to a region that is known to be nested in
/// the `then` or `else` branch of a specific loop.

But that does not seem to be the case: the visited set is only ever inserted into here and, as far as I can see, never checked?

Collaborator

Oh I see the other recursions above, where "visited" will populate the thenChildren and elseChildren.

But if we hit an ifOp above, why do we need to recurse below again?
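The question about the visited set can be checked on a toy model of just the set bookkeeping. In the sketch below (hypothetical `ToyOp`, `ToyRegion`, `collect`, and `makeOp` names, not MLIR code), `seed == false` follows the structure of the posted `propagateIfConditionsImpl`: an `if` only forwards what its children collected, and other ops pass the set through without inserting anything. Under that reading nothing ever adds a first element, so every set stays empty and the `contains` checks could not succeed. With `seed == true`, each op also records its own regions; this is one possible fix, assumed here rather than taken from the PR.

```cpp
#include <cassert>
#include <memory>
#include <unordered_set>
#include <vector>

// Hypothetical toy region tree (not MLIR's API): ops own regions, regions own
// ops, and an op with `isIf` set models `scf.if` with regions [then, else].
struct ToyRegion;
struct ToyOp {
  bool isIf = false;
  std::vector<std::unique_ptr<ToyRegion>> regions;
};
struct ToyRegion {
  std::vector<std::unique_ptr<ToyOp>> ops;
};
using RegionSet = std::unordered_set<const ToyRegion *>;

// seed == false: mirrors the posted bookkeeping, where an `if` only forwards
// what its children collected and other ops insert nothing themselves.
// seed == true: each op additionally records its own regions.
inline void collect(const ToyOp &op, RegionSet &visited, bool seed) {
  if (op.isIf) {
    RegionSet thenChildren, elseChildren;
    if (seed) {
      thenChildren.insert(op.regions[0].get());
      elseChildren.insert(op.regions[1].get());
    }
    for (const auto &child : op.regions[0]->ops)
      collect(*child, thenChildren, seed);
    for (const auto &child : op.regions[1]->ops)
      collect(*child, elseChildren, seed);
    visited.insert(thenChildren.begin(), thenChildren.end());
    visited.insert(elseChildren.begin(), elseChildren.end());
    return;
  }
  for (const auto &region : op.regions) {
    if (seed)
      visited.insert(region.get());
    for (const auto &child : region->ops)
      collect(*child, visited, seed);
  }
}

// Builds an op with `numRegions` empty regions (2 for an `if`).
inline std::unique_ptr<ToyOp> makeOp(bool isIf, int numRegions) {
  auto op = std::make_unique<ToyOp>();
  op->isIf = isIf;
  for (int i = 0; i < numRegions; ++i)
    op->regions.push_back(std::make_unique<ToyRegion>());
  return op;
}
```

For an outer `if` whose `then` branch holds another `if` and whose `else` branch holds a single-region op, the unseeded variant collects nothing, while the seeded one collects all five regions.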

joker-eph added a commit that referenced this pull request Sep 18, 2025
…159535)

Reverts #150278

Multiple post-merge comments remained unaddressed, and some more
fundamental issues were also reported in #159165
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Sep 18, 2025
…ne pass" (#159535)

Reverts llvm/llvm-project#150278

Multiple post-merge comments remained unaddressed, and some more
fundamental issues were also reported in #159165