diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index e341c3744f1d8..41410a0a56aa9 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -130,7 +130,7 @@ FailureOr mlir::outlineSingleBlockRegion(RewriterBase &rewriter, // Outline before current function. OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(region.getParentOfType()); + rewriter.setInsertionPoint(region.getParentOfType()); SetVector captures; getUsedValuesDefinedAbove(region, captures); diff --git a/mlir/test/Transforms/scf-if-utils.mlir b/mlir/test/Transforms/scf-if-utils.mlir index 449be18483741..fd59f5e9295e6 100644 --- a/mlir/test/Transforms/scf-if-utils.mlir +++ b/mlir/test/Transforms/scf-if-utils.mlir @@ -73,3 +73,30 @@ func.func @outline_empty_if_else(%cond: i1, %a: index, %b: memref, %c: i8 } return } + +// ----- + +// This test checks scf utils can work on llvm func. + +// CHECK: func @outlined_then0() { +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK: func @outlined_else0(%{{.*}}: i1, %{{.*}}: i32) { +// CHECK-NEXT: "some_op"(%{{.*}}, %{{.*}}) : (i1, i32) -> () +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK: llvm.func @llvm_func(%{{.*}}: i1, %{{.*}}: i32) { +// CHECK-NEXT: scf.if %{{.*}} { +// CHECK-NEXT: func.call @outlined_then0() : () -> () +// CHECK-NEXT: } else { +// CHECK-NEXT: func.call @outlined_else0(%{{.*}}, %{{.*}}) : (i1, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } +llvm.func @llvm_func(%cond: i1, %a: i32) { + scf.if %cond { + } else { + "some_op"(%cond, %a) : (i1, i32) -> () + } + llvm.return +} diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp index 3ff7f9966e93d..a3be1f94fa28a 100644 --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -79,6 +79,10 @@ struct TestSCFIfUtilsPass StringRef getDescription() const final { return "test scf.if utils"; } explicit TestSCFIfUtilsPass() = default; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { int count = 0; getOperation().walk([&](scf::IfOp ifOp) {