Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> {
let constructor = "mlir::createForLoopSpecializationPass()";
}

def SCFIfConditionPropagation : Pass<"scf-if-condition-propagation"> {
let summary = "Replace usages of if condition with true/false constants in "
"the conditional regions";
let dependentDialects = ["arith::ArithDialect"];
}

def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> {
let summary = "Fuse adjacent parallel loops";
let constructor = "mlir::createParallelLoopFusionPass()";
Expand Down
64 changes: 2 additions & 62 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2414,65 +2414,6 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
}
};

/// Allow the true region of an if to assume the condition is true
/// and vice versa. For example:
///
/// scf.if %cmp {
/// print(%cmp)
/// }
///
/// becomes
///
/// scf.if %cmp {
/// print(true)
/// }
///
struct ConditionPropagation : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;

LogicalResult matchAndRewrite(IfOp op,
PatternRewriter &rewriter) const override {
// Early exit if the condition is constant since replacing a constant
// in the body with another constant isn't a simplification.
if (matchPattern(op.getCondition(), m_Constant()))
return failure();

bool changed = false;
mlir::Type i1Ty = rewriter.getI1Type();

// These variables serve to prevent creating duplicate constants
// and hold constant true or false values.
Value constantTrue = nullptr;
Value constantFalse = nullptr;

for (OpOperand &use :
llvm::make_early_inc_range(op.getCondition().getUses())) {
if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
changed = true;

if (!constantTrue)
constantTrue = rewriter.create<arith::ConstantOp>(
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));

rewriter.modifyOpInPlace(use.getOwner(),
[&]() { use.set(constantTrue); });
} else if (op.getElseRegion().isAncestor(
use.getOwner()->getParentRegion())) {
changed = true;

if (!constantFalse)
constantFalse = rewriter.create<arith::ConstantOp>(
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));

rewriter.modifyOpInPlace(use.getOwner(),
[&]() { use.set(constantFalse); });
}
}

return success(changed);
}
};

/// Remove any statements from an if that are equivalent to the condition
/// or its negation. For example:
///
Expand Down Expand Up @@ -2854,9 +2795,8 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {

void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
RemoveStaticCondition, RemoveUnusedResults,
results.add<CombineIfs, CombineNestedIfs, ConvertTrivialIfToSelect,
RemoveEmptyElseBranch, RemoveStaticCondition, RemoveUnusedResults,
ReplaceIfYieldWithConditionOrValue>(context);
}

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
ForallToFor.cpp
ForallToParallel.cpp
ForToWhile.cpp
IfConditionPropagation.cpp
LoopCanonicalization.cpp
LoopPipelining.cpp
LoopRangeFolding.cpp
Expand Down
98 changes: 98 additions & 0 deletions mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
//===- IfConditionPropagation.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a pass for constant propagation of the condition of an
// `scf.if` into its then and else regions as true and false respectively.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"

using namespace mlir;

namespace mlir {
#define GEN_PASS_DEF_SCFIFCONDITIONPROPAGATION
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir

/// Traverses the IR recursively (on region tree) and updates the uses of a
/// value also as the condition of an `scf.if` to either `true` or `false`
/// constants in the `then` and `else regions. This is done as a single
/// post-order sweep over the IR (without `walk`) for efficiency reasons. While
/// traversing, the function maintains the set of visited regions to quickly
/// identify whether the value belong to a region that is known to be nested in
/// the `then` or `else` branch of a specific loop.
static void propagateIfConditionsImpl(Operation *root,
llvm::SmallPtrSet<Region *, 8> &visited) {
if (auto scfIf = dyn_cast<scf::IfOp>(root)) {
llvm::SmallPtrSet<Region *, 8> thenChildren, elseChildren;
// Visit the "then" region, collect children.
for (Block &block : scfIf.getThenRegion()) {
for (Operation &op : block) {
propagateIfConditionsImpl(&op, thenChildren);
}
}

// Visit the "else" region, collect children.
for (Block &block : scfIf.getElseRegion()) {
for (Operation &op : block) {
propagateIfConditionsImpl(&op, elseChildren);
}
}

// Update uses to point to constants instead.
OpBuilder builder(scfIf);
Value trueValue = arith::ConstantIntOp::create(builder, scfIf.getLoc(),
/*value=*/true, /*width=*/1);
Value falseValue =
arith::ConstantIntOp::create(builder, scfIf.getLoc(),
/*value=*/false, /*width=*/1);

for (OpOperand &use : scfIf.getCondition().getUses()) {
if (thenChildren.contains(use.getOwner()->getParentRegion()))
use.set(trueValue);
else if (elseChildren.contains(use.getOwner()->getParentRegion()))
use.set(falseValue);
}
if (trueValue.getUses().empty())
trueValue.getDefiningOp()->erase();
if (falseValue.getUses().empty())
falseValue.getDefiningOp()->erase();

// Append the two lists of children and return them.
visited.insert_range(thenChildren);
visited.insert_range(elseChildren);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite understand the logic for the visited set, the comment on the function says:

/// While
/// traversing, the function maintains the set of visited regions to quickly
/// identify whether the value belong to a region that is known to be nested in
/// the `then` or `else` branch of a specific loop.

But that does not seem to be the case, the visited set is only ever used for inserting here, never checked as far as I can see?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see the other recursions above, where "visited" will populate the thenChildren and elseChildren.

But if we hit a ifOp above, why do we need to recurse below again?

return;
}

for (Region &region : root->getRegions()) {
for (Block &block : region) {
for (Operation &op : block) {
propagateIfConditionsImpl(&op, visited);
}
}
}
}

/// Traverses the IR recursively (on region tree) and updates the uses of a
/// value also as the condition of an `scf.if` to either `true` or `false`
/// constants in the `then` and `else regions
static void propagateIfConditions(Operation *root) {
llvm::SmallPtrSet<Region *, 8> visited;
propagateIfConditionsImpl(root, visited);
}

namespace {
/// Pass entrypoint.
struct SCFIfConditionPropagationPass
: impl::SCFIfConditionPropagationBase<SCFIfConditionPropagationPass> {
void runOnOperation() override { propagateIfConditions(getOperation()); }
};
} // namespace
35 changes: 0 additions & 35 deletions mlir/test/Dialect/SCF/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -867,41 +867,6 @@ func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {

// -----

// CHECK-LABEL: @cond_prop
func.func @cond_prop(%arg0 : i1) -> index {
%res = scf.if %arg0 -> index {
%res1 = scf.if %arg0 -> index {
%v1 = "test.get_some_value1"() : () -> index
scf.yield %v1 : index
} else {
%v2 = "test.get_some_value2"() : () -> index
scf.yield %v2 : index
}
scf.yield %res1 : index
} else {
%res2 = scf.if %arg0 -> index {
%v3 = "test.get_some_value3"() : () -> index
scf.yield %v3 : index
} else {
%v4 = "test.get_some_value4"() : () -> index
scf.yield %v4 : index
}
scf.yield %res2 : index
}
return %res : index
}
// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (index) {
// CHECK-NEXT: %[[c1:.+]] = "test.get_some_value1"() : () -> index
// CHECK-NEXT: scf.yield %[[c1]] : index
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[c4:.+]] = "test.get_some_value4"() : () -> index
// CHECK-NEXT: scf.yield %[[c4]] : index
// CHECK-NEXT: }
// CHECK-NEXT: return %[[if]] : index
// CHECK-NEXT:}

// -----

// CHECK-LABEL: @replace_if_with_cond1
func.func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
%true = arith.constant true
Expand Down
34 changes: 34 additions & 0 deletions mlir/test/Dialect/SCF/if-cond-prop.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// RUN: mlir-opt %s --scf-if-condition-propagation --allow-unregistered-dialect | FileCheck %s
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove --allow-unregistered-dialect; it shouldn't ever been needed in MLIR tests I believe (the test dialect allows for unregistered op if needed).

Copy link
Collaborator

@joker-eph joker-eph Sep 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ping @ftynse on post-merge comments.


// CHECK-LABEL: @cond_prop
func.func @cond_prop(%arg0 : i1) -> index {
%res = scf.if %arg0 -> index {
%res1 = scf.if %arg0 -> index {
%v1 = "test.get_some_value1"() : () -> index
scf.yield %v1 : index
} else {
%v2 = "test.get_some_value2"() : () -> index
scf.yield %v2 : index
}
scf.yield %res1 : index
} else {
%res2 = scf.if %arg0 -> index {
%v3 = "test.get_some_value3"() : () -> index
scf.yield %v3 : index
} else {
%v4 = "test.get_some_value4"() : () -> index
scf.yield %v4 : index
}
scf.yield %res2 : index
}
return %res : index
}
// CHECK: %[[if:.+]] = scf.if %arg0 -> (index) {
// CHECK: %[[c1:.+]] = "test.get_some_value1"() : () -> index
// CHECK: scf.yield %[[c1]] : index
// CHECK: } else {
// CHECK: %[[c4:.+]] = "test.get_some_value4"() : () -> index
// CHECK: scf.yield %[[c4]] : index
// CHECK: }
// CHECK: return %[[if]] : index
// CHECK:}