diff --git a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt index 58551bb435c86..05a787fa53ec3 100644 --- a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt @@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRControlFlowDialect MLIRControlFlowInterfaces MLIRIR MLIRSideEffectInterfaces + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp index f1da1a125e9ef..d2078d8ab5ca5 100644 --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" @@ -445,6 +446,37 @@ struct CondBranchTruthPropagation : public OpRewritePattern { return success(replaced); } }; + +/// If the destination block of a conditional branch contains only +/// ub.unreachable, unconditionally branch to the other destination. +struct DropUnreachableCondBranch : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CondBranchOp condbr, + PatternRewriter &rewriter) const override { + // If the "true" destination is unreachable, branch to the "false" + // destination. + Block *trueDest = condbr.getTrueDest(); + Block *falseDest = condbr.getFalseDest(); + if (llvm::hasSingleElement(*trueDest) && + isa(trueDest->getTerminator())) { + rewriter.replaceOpWithNewOp(condbr, falseDest, + condbr.getFalseOperands()); + return success(); + } + + // If the "false" destination is unreachable, branch to the "true" + // destination. + if (llvm::hasSingleElement(*falseDest) && + isa(falseDest->getTerminator())) { + rewriter.replaceOpWithNewOp(condbr, trueDest, + condbr.getTrueOperands()); + return success(); + } + + return failure(); + } +}; } // namespace void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -452,7 +484,7 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); + CondBranchTruthPropagation, DropUnreachableCondBranch>(context); } SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) { diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir index 17f7d28ba59fb..21a16784b81b2 100644 --- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir +++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir @@ -634,3 +634,25 @@ func.func @unsimplified_cycle_2(%c : i1) { ^bb7: cf.br ^bb6 } + +// CHECK-LABEL: @drop_unreachable_branch_1 +// CHECK-NEXT: "test.foo"() : () -> () +// CHECK-NEXT: return +func.func @drop_unreachable_branch_1(%c: i1) { + cf.cond_br %c, ^bb1, ^bb2 +^bb1: + "test.foo"() : () -> () + return +^bb2: + ub.unreachable +} + +// CHECK-LABEL: @drop_unreachable_branch_2 +// CHECK-NEXT: ub.unreachable +func.func @drop_unreachable_branch_2(%c: i1) { + cf.cond_br %c, ^bb1, ^bb2 +^bb1: + ub.unreachable +^bb2: + ub.unreachable +}