Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
be accessed inside the op. The op's region can have multiple blocks and the
blocks can have multiple distinct terminators. Values returned from this op's
region define the op's results.
The optional 'no_inline' flag can be set to request the ExecuteRegionOp to be
preserved as much as possible and not being inlined in the parent block until
an explicit lowering step.

Example:

Expand All @@ -98,6 +101,14 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
}
}

// the same as above but with no_inline attribute
scf.for %i = 0 to 128 step %c1 {
%y = scf.execute_region -> i32 no_inline {
%x = load %A[%i] : memref<128xi32>
scf.yield %x : i32
}
}

affine.for %i = 0 to 100 {
"foo"() : () -> ()
%v = scf.execute_region -> i64 {
Expand All @@ -119,6 +130,10 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
```
}];

let arguments = (ins
UnitAttr:$no_inline
);

let results = (outs Variadic<AnyType>);

let regions = (region AnyRegion:$region);
Expand Down
8 changes: 6 additions & 2 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
if (parser.parseOptionalArrowTypeList(result.types))
return failure();

if (succeeded(parser.parseOptionalKeyword("no_inline")))
result.addAttribute("no_inline", parser.getBuilder().getUnitAttr());

// Introduce the body region and parse it.
Region *body = result.addRegion();
if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
Expand All @@ -148,8 +151,9 @@ ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,

void ExecuteRegionOp::print(OpAsmPrinter &p) {
p.printOptionalArrowTypeList(getResultTypes());

p << ' ';
if (getNoInline())
p << "no_inline ";
p.printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
Expand Down Expand Up @@ -184,7 +188,7 @@ struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {

LogicalResult matchAndRewrite(ExecuteRegionOp op,
PatternRewriter &rewriter) const override {
if (!op.getRegion().hasOneBlock())
if (!op.getRegion().hasOneBlock() || op.getNoInline())
return failure();
replaceOpWithRegion(rewriter, op, op.getRegion());
return success();
Expand Down
34 changes: 28 additions & 6 deletions mlir/test/Dialect/SCF/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1440,8 +1440,8 @@ func.func @propagate_into_execute_region() {

// -----

// CHECK-LABEL: func @execute_region_elim
func.func @execute_region_elim() {
// CHECK-LABEL: func @execute_region_inline
func.func @execute_region_inline() {
affine.for %i = 0 to 100 {
"test.foo"() : () -> ()
%v = scf.execute_region -> i64 {
Expand All @@ -1461,8 +1461,30 @@ func.func @execute_region_elim() {

// -----

// CHECK-LABEL: func @func_execute_region_elim
func.func @func_execute_region_elim() {
// CHECK-LABEL: func @execute_region_no_inline
func.func @execute_region_no_inline() {
affine.for %i = 0 to 100 {
"test.foo"() : () -> ()
%v = scf.execute_region -> i64 no_inline {
%x = "test.val"() : () -> i64
scf.yield %x : i64
}
"test.bar"(%v) : (i64) -> ()
}
return
}

// CHECK-NEXT: affine.for %arg0 = 0 to 100 {
// CHECK-NEXT: "test.foo"() : () -> ()
// CHECK-NEXT: scf.execute_region
// CHECK-NEXT: %[[VAL:.*]] = "test.val"() : () -> i64
// CHECK-NEXT: scf.yield %[[VAL]] : i64
// CHECK-NEXT: }

// -----

// CHECK-LABEL: func @func_execute_region_inline
func.func @func_execute_region_inline() {
"test.foo"() : () -> ()
%v = scf.execute_region -> i64 {
%c = "test.cmp"() : () -> i1
Expand Down Expand Up @@ -1496,8 +1518,8 @@ func.func @func_execute_region_elim() {

// -----

// CHECK-LABEL: func @func_execute_region_elim_multi_yield
func.func @func_execute_region_elim_multi_yield() {
// CHECK-LABEL: func @func_execute_region_inline_multi_yield
func.func @func_execute_region_inline_multi_yield() {
"test.foo"() : () -> ()
%v = scf.execute_region -> i64 {
%c = "test.cmp"() : () -> i1
Expand Down