diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 2d15544e871b3..0c1c15b85f4c9 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -87,6 +87,9 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [ be accessed inside the op. The op's region can have multiple blocks and the blocks can have multiple distinct terminators. Values returned from this op's region define the op's results. + The optional 'no_inline' flag can be set to request that the ExecuteRegionOp be + preserved as much as possible and not inlined into the parent block until + an explicit lowering step. Example: @@ -98,6 +101,14 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [ } } + // The same as above, but with the no_inline attribute. + scf.for %i = 0 to 128 step %c1 { + %y = scf.execute_region -> i32 no_inline { + %x = load %A[%i] : memref<128xi32> + scf.yield %x : i32 + } + } + affine.for %i = 0 to 100 { "foo"() : () -> () %v = scf.execute_region -> i64 { @@ -119,6 +130,10 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [ ``` }]; + let arguments = (ins + UnitAttr:$no_inline + ); + let results = (outs Variadic); let regions = (region AnyRegion:$region); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 759e58b617578..0262a1b8a3893 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -137,6 +137,9 @@ ParseResult ExecuteRegionOp::parse(OpAsmParser &parser, if (parser.parseOptionalArrowTypeList(result.types)) return failure(); + if (succeeded(parser.parseOptionalKeyword("no_inline"))) + result.addAttribute("no_inline", parser.getBuilder().getUnitAttr()); + // Introduce the body region and parse it. 
Region *body = result.addRegion(); if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) || @@ -148,8 +151,9 @@ ParseResult ExecuteRegionOp::parse(OpAsmParser &parser, void ExecuteRegionOp::print(OpAsmPrinter &p) { p.printOptionalArrowTypeList(getResultTypes()); - p << ' '; + if (getNoInline()) + p << "no_inline "; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); @@ -184,7 +188,7 @@ struct SingleBlockExecuteInliner : public OpRewritePattern { LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override { - if (!op.getRegion().hasOneBlock()) + if (!op.getRegion().hasOneBlock() || op.getNoInline()) return failure(); replaceOpWithRegion(rewriter, op, op.getRegion()); return success(); diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 12d30e17f4a8f..308cf150aa98e 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1440,8 +1440,8 @@ func.func @propagate_into_execute_region() { // ----- -// CHECK-LABEL: func @execute_region_elim -func.func @execute_region_elim() { +// CHECK-LABEL: func @execute_region_inline +func.func @execute_region_inline() { affine.for %i = 0 to 100 { "test.foo"() : () -> () %v = scf.execute_region -> i64 { @@ -1461,8 +1461,30 @@ func.func @execute_region_elim() { // ----- -// CHECK-LABEL: func @func_execute_region_elim -func.func @func_execute_region_elim() { +// CHECK-LABEL: func @execute_region_no_inline +func.func @execute_region_no_inline() { + affine.for %i = 0 to 100 { + "test.foo"() : () -> () + %v = scf.execute_region -> i64 no_inline { + %x = "test.val"() : () -> i64 + scf.yield %x : i64 + } + "test.bar"(%v) : (i64) -> () + } + return +} + +// CHECK-NEXT: affine.for %arg0 = 0 to 100 { +// CHECK-NEXT: "test.foo"() : () -> () +// CHECK-NEXT: scf.execute_region +// CHECK-NEXT: %[[VAL:.*]] = "test.val"() : () -> i64 +// CHECK-NEXT: scf.yield %[[VAL]] : i64 +// 
CHECK-NEXT: } + +// ----- + +// CHECK-LABEL: func @func_execute_region_inline +func.func @func_execute_region_inline() { "test.foo"() : () -> () %v = scf.execute_region -> i64 { %c = "test.cmp"() : () -> i1 @@ -1496,8 +1518,8 @@ func.func @func_execute_region_elim() { // ----- -// CHECK-LABEL: func @func_execute_region_elim_multi_yield -func.func @func_execute_region_elim_multi_yield() { +// CHECK-LABEL: func @func_execute_region_inline_multi_yield +func.func @func_execute_region_inline_multi_yield() { "test.foo"() : () -> () %v = scf.execute_region -> i64 { %c = "test.cmp"() : () -> i1