-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][CF] Add ub.unreachable canonicalization
#169873
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][CF] Add ub.unreachable canonicalization
#169873
Conversation
|
@llvm/pr-subscribers-mlir-ub @llvm/pr-subscribers-mlir-cf Author: Matthias Springer (matthias-springer) ChangesBasic blocks with a Depends on #169872. Full diff: https://github.com/llvm/llvm-project/pull/169873.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
index a441fd82546e3..c9b4da44ffa01 100644
--- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
@@ -22,7 +22,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def ControlFlow_Dialect : Dialect {
let name = "cf";
let cppNamespace = "::mlir::cf";
- let dependentDialects = ["arith::ArithDialect"];
+ let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
let description = [{
This dialect contains low-level, i.e. non-region based, control flow
constructs. These constructs generally represent control flow directly
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
index 21de5cb0c182a..02081e2d6d15f 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
@@ -9,6 +9,10 @@
#ifndef MLIR_DIALECT_UB_IR_OPS_H
#define MLIR_DIALECT_UB_IR_OPS_H
+namespace mlir {
+class PatternRewriter;
+}
+
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.td b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
index 8a354da2db10c..c1d74290ec174 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.td
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
@@ -84,6 +84,7 @@ def UnreachableOp : UB_Op<"unreachable", [Terminator]> {
}];
let assemblyFormat = "attr-dict";
+ let hasCanonicalizeMethod = 1;
}
#endif // MLIR_DIALECT_UB_IR_UBOPS_TD
diff --git a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt
index 58551bb435c86..05a787fa53ec3 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt
@@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRControlFlowDialect
MLIRControlFlowInterfaces
MLIRIR
MLIRSideEffectInterfaces
+ MLIRUBDialect
)
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index f1da1a125e9ef..aabf8930cf78e 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -445,6 +446,35 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
return success(replaced);
}
};
+
+struct DropUnreachableCondBranch : public OpRewritePattern<CondBranchOp> {
+ using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(CondBranchOp condbr,
+ PatternRewriter &rewriter) const override {
+ // If the "true" destination has unreachable an unreachable terminator,
+ // always branch to the "false" destination.
+ Block *trueDest = condbr.getTrueDest();
+ Block *falseDest = condbr.getFalseDest();
+ if (llvm::hasSingleElement(*trueDest) &&
+ isa<ub::UnreachableOp>(trueDest->getTerminator())) {
+ rewriter.replaceOpWithNewOp<BranchOp>(condbr, falseDest,
+ condbr.getFalseOperands());
+ return success();
+ }
+
+ // If the "false" destination has unreachable an unreachable terminator,
+ // always branch to the "true" destination.
+ if (llvm::hasSingleElement(*falseDest) &&
+ isa<ub::UnreachableOp>(falseDest->getTerminator())) {
+ rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest,
+ condbr.getTrueOperands());
+ return success();
+ }
+
+ return failure();
+ }
+};
} // namespace
void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -452,7 +482,7 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
SimplifyCondBranchIdenticalSuccessors,
SimplifyCondBranchFromCondBranchOnSameCondition,
- CondBranchTruthPropagation>(context);
+ CondBranchTruthPropagation, DropUnreachableCondBranch>(context);
}
SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp
index ee523f9522953..419e3f9d76fb2 100644
--- a/mlir/lib/Dialect/UB/IR/UBOps.cpp
+++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/UB/IR/UBOpsDialect.cpp.inc"
@@ -57,8 +58,33 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// PoisonOp
+//===----------------------------------------------------------------------===//
+
OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
+//===----------------------------------------------------------------------===//
+// UnreachableOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult UnreachableOp::canonicalize(UnreachableOp unreachableOp,
+ PatternRewriter &rewriter) {
+ Block *block = unreachableOp->getBlock();
+ if (llvm::hasSingleElement(*block))
+ return rewriter.notifyMatchFailure(
+ unreachableOp, "unreachable op is the only operation in the block");
+
+ // Erase all other operations in the block. They must be dead.
+ for (Operation &op : llvm::make_early_inc_range(*block)) {
+ if (&op == unreachableOp.getOperation())
+ continue;
+ op.dropAllUses();
+ rewriter.eraseOp(&op);
+ }
+ return success();
+}
+
#include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc"
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
index 17f7d28ba59fb..75dec6dacde91 100644
--- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir
+++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
@@ -634,3 +634,28 @@ func.func @unsimplified_cycle_2(%c : i1) {
^bb7:
cf.br ^bb6
}
+
+// CHECK-LABEL: @drop_unreachable_branch_1
+// CHECK-NEXT: "test.foo"() : () -> ()
+// CHECK-NEXT: return
+func.func @drop_unreachable_branch_1(%c: i1) {
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+ "test.foo"() : () -> ()
+ return
+^bb2:
+ "test.bar"() : () -> ()
+ ub.unreachable
+}
+
+// CHECK-LABEL: @drop_unreachable_branch_2
+// CHECK-NEXT: ub.unreachable
+func.func @drop_unreachable_branch_2(%c: i1) {
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+ "test.foo"() : () -> ()
+ ub.unreachable
+^bb2:
+ "test.bar"() : () -> ()
+ ub.unreachable
+}
diff --git a/mlir/test/Dialect/UB/canonicalize.mlir b/mlir/test/Dialect/UB/canonicalize.mlir
index c3f286e49b09b..74ba9f1932384 100644
--- a/mlir/test/Dialect/UB/canonicalize.mlir
+++ b/mlir/test/Dialect/UB/canonicalize.mlir
@@ -9,3 +9,13 @@ func.func @merge_poison() -> (i32, i32) {
%1 = ub.poison : i32
return %0, %1 : i32, i32
}
+
+// -----
+
+// CHECK-LABEL: func @drop_ops_before_unreachable()
+// CHECK-NEXT: ub.unreachable
+func.func @drop_ops_before_unreachable() {
+ "test.foo"() : () -> ()
+ "test.bar"() : () -> ()
+ ub.unreachable
+}
|
bc96208 to
561b6ca
Compare
mlir/lib/Dialect/UB/IR/UBOps.cpp
Outdated
| continue; | ||
| op.dropAllUses(); | ||
| rewriter.eraseOp(&op); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should TODO that:
- this assumes we don't have calls that are "no return".
- this assumes that loops terminates.
- this assumes that nothing interrupts the control-flow (which is fine until early-exit is added).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should have a trait "AlwaysForwardProgress" on operations to allow this kind of transformations?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't think of infinite loops. I am wondering if this canonicalization is actually safe. I'm just merging the cf.cond_br canonicalization for now, so we can discuss further. I feel like this kind of "canonicalization" may be better suited for a separate pass. (Something like -remove-dead-values maybe?) We could also make this optimization customizable (turn off/on via pass option).
561b6ca to
92caacb
Compare
92caacb to
143a40b
Compare
ub.unreachable canonicalizationub.unreachable canonicalization
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/157/builds/42632 Here is the relevant piece of the build log for the reference |
A new dependency was added.
Basic blocks with only a `ub.unreachable` terminator are unreachable. This commit adds a canonicalization pattern that folds to `cf.cond_br` to `cf.br` if one of the destination branches is unreachable.
A new dependency was added.
Basic blocks with only a `ub.unreachable` terminator are unreachable. This commit adds a canonicalization pattern that folds to `cf.cond_br` to `cf.br` if one of the destination branches is unreachable.
A new dependency was added.
Basic blocks with only a `ub.unreachable` terminator are unreachable. This commit adds a canonicalization pattern that folds to `cf.cond_br` to `cf.br` if one of the destination branches is unreachable.
A new dependency was added.
Basic blocks with only a `ub.unreachable` terminator are unreachable. This commit adds a canonicalization pattern that folds to `cf.cond_br` to `cf.br` if one of the destination branches is unreachable.
A new dependency was added.
Basic blocks with only a
ub.unreachableterminator are unreachable. This commit adds a canonicalization pattern that folds tocf.cond_brtocf.brif one of the destination branches is unreachable.