From 3869c4477d12fc66a72818d7d03d6be0e502f60b Mon Sep 17 00:00:00 2001 From: donald chen Date: Fri, 20 Dec 2024 15:12:40 +0800 Subject: [PATCH] [mlir] fix crash when scf utils work on llvm.func fixed https://github.com/llvm/llvm-project/issues/119378 --- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 2 +- mlir/test/Transforms/scf-if-utils.mlir | 27 ++++++++++++++++++++++ mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp | 4 ++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index e341c3744f1d8..41410a0a56aa9 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -130,7 +130,7 @@ FailureOr mlir::outlineSingleBlockRegion(RewriterBase &rewriter, // Outline before current function. OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(region.getParentOfType()); + rewriter.setInsertionPoint(region.getParentOfType()); SetVector captures; getUsedValuesDefinedAbove(region, captures); diff --git a/mlir/test/Transforms/scf-if-utils.mlir b/mlir/test/Transforms/scf-if-utils.mlir index 449be18483741..fd59f5e9295e6 100644 --- a/mlir/test/Transforms/scf-if-utils.mlir +++ b/mlir/test/Transforms/scf-if-utils.mlir @@ -73,3 +73,30 @@ func.func @outline_empty_if_else(%cond: i1, %a: index, %b: memref, %c: i8 } return } + +// ----- + +// This test checks scf utils can work on llvm func. + +// CHECK: func @outlined_then0() { +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK: func @outlined_else0(%{{.*}}: i1, %{{.*}}: i32) { +// CHECK-NEXT: "some_op"(%{{.*}}, %{{.*}}) : (i1, i32) -> () +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK: llvm.func @llvm_func(%{{.*}}: i1, %{{.*}}: i32) { +// CHECK-NEXT: scf.if %{{.*}} { +// CHECK-NEXT: func.call @outlined_then0() : () -> () +// CHECK-NEXT: } else { +// CHECK-NEXT: func.call @outlined_else0(%{{.*}}, %{{.*}}) : (i1, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } +llvm.func @llvm_func(%cond: i1, %a: i32) { + scf.if %cond { + } else { + "some_op"(%cond, %a) : (i1, i32) -> () + } + llvm.return +} diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp index 3ff7f9966e93d..a3be1f94fa28a 100644 --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -79,6 +79,10 @@ struct TestSCFIfUtilsPass StringRef getDescription() const final { return "test scf.if utils"; } explicit TestSCFIfUtilsPass() = default; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { int count = 0; getOperation().walk([&](scf::IfOp ifOp) {