Skip to content

Commit d906426

Browse files
committed
[mlir] make transform.loop.outline also return the call handle
Outlining is particularly interesting when the outlined function is replaced with something else, e.g., a microkernel. It is good to have a handle to the call in this case. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D149849
1 parent 3b6bc87 commit d906426

File tree

6 files changed

+36
-24
lines changed

6 files changed

+36
-24
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,24 +46,31 @@ def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
4646
DeclareOpInterfaceMethods<TransformOpInterface>]> {
4747
let summary = "Outlines a loop into a named function";
4848
let description = [{
49-
Moves the loop into a separate function with the specified name and
50-
replaces the loop in the Payload IR with a call to that function. Takes
51-
care of forwarding values that are used in the loop as function arguments.
52-
If the operand is associated with more than one loop, each loop will be
53-
outlined into a separate function. The provided name is used as a _base_
54-
for forming actual function names following SymbolTable auto-renaming
55-
scheme to avoid duplicate symbols. Expects that all ops in the Payload IR
56-
have a SymbolTable ancestor (typically true because of the top-level
57-
module). Returns the handle to the list of outlined functions in the same
58-
order as the operand handle.
49+
Moves the loop into a separate function with the specified name and replaces
50+
the loop in the Payload IR with a call to that function. Takes care of
51+
forwarding values that are used in the loop as function arguments. If the
52+
operand is associated with more than one loop, each loop will be outlined
53+
into a separate function. The provided name is used as a _base_ for forming
54+
actual function names following `SymbolTable` auto-renaming scheme to avoid
55+
duplicate symbols. Expects that all ops in the Payload IR have a
56+
`SymbolTable` ancestor (typically true because of the top-level module).
57+
58+
#### Return Modes
59+
60+
Returns a handle to the list of outlined functions and a handle to the
61+
corresponding function call operations in the same order as the operand
62+
handle.
63+
64+
Produces a definite failure if outlining failed for any of the targets.
5965
}];
6066

6167
// Note that despite the name of the transform operation and related utility
6268
// functions, the actual implementation does not require the operation to be
6369
// a loop.
6470
let arguments = (ins TransformHandleTypeInterface:$target,
6571
StrAttr:$func_name);
66-
let results = (outs TransformHandleTypeInterface:$transformed);
72+
let results = (outs TransformHandleTypeInterface:$function,
73+
TransformHandleTypeInterface:$call);
6774

6875
let assemblyFormat =
6976
"$target attr-dict `:` functional-type(operands, results)";

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
8585
DiagnosedSilenceableFailure
8686
transform::LoopOutlineOp::apply(transform::TransformResults &results,
8787
transform::TransformState &state) {
88-
SmallVector<Operation *> transformed;
88+
SmallVector<Operation *> functions;
89+
SmallVector<Operation *> calls;
8990
DenseMap<Operation *, SymbolTable> symbolTables;
9091
for (Operation *target : state.getPayloadOps(getTarget())) {
9192
Location location = target->getLoc();
@@ -112,9 +113,11 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
112113
symbolTable.insert(*outlined);
113114
call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
114115
}
115-
transformed.push_back(*outlined);
116+
functions.push_back(*outlined);
117+
calls.push_back(call);
116118
}
117-
results.set(getTransformed().cast<OpResult>(), transformed);
119+
results.set(getFunction().cast<OpResult>(), functions);
120+
results.set(getCall().cast<OpResult>(), calls);
118121
return DiagnosedSilenceableFailure::success();
119122
}
120123

mlir/python/mlir/dialects/_loop_transform_ops_ext.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,17 @@ class LoopOutlineOp:
3939

4040
def __init__(
4141
self,
42-
result_type: Type,
42+
function_type: Type,
43+
call_type: Type,
4344
target: Union[Operation, Value],
4445
*,
4546
func_name: Union[str, StringAttr],
4647
ip=None,
4748
loc=None,
4849
):
4950
super().__init__(
50-
result_type,
51+
function_type,
52+
call_type,
5153
_get_op_result_or_value(target),
5254
func_name=(func_name if isinstance(func_name, StringAttr) else
5355
StringAttr.get(func_name)),

mlir/test/Dialect/SCF/transform-ops-invalid.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ func.func @loop_outline_op_multi_region() {
5454
}
5555

5656
transform.sequence failures(propagate) {
57-
^bb1(%arg1: !pdl.operation):
58-
%0 = transform.structured.match ops{["scf.while"]} in %arg1 : (!pdl.operation) -> !pdl.operation
57+
^bb1(%arg1: !transform.any_op):
58+
%0 = transform.structured.match ops{["scf.while"]} in %arg1 : (!transform.any_op) -> !transform.any_op
5959
// expected-error @below {{failed to outline}}
60-
transform.loop.outline %0 {func_name = "foo"} : (!pdl.operation) -> !pdl.operation
60+
transform.loop.outline %0 {func_name = "foo"} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
6161
}

mlir/test/Dialect/SCF/transform-ops.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,11 @@ func.func @loop_outline_op(%arg0: index, %arg1: index, %arg2: index) {
7575
}
7676

7777
transform.sequence failures(propagate) {
78-
^bb1(%arg1: !pdl.operation):
79-
%0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!pdl.operation) -> !pdl.operation
80-
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
78+
^bb1(%arg1: !transform.any_op):
79+
%0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
80+
%1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for">
8181
// CHECK: = transform.loop.outline %{{.*}}
82-
transform.loop.outline %1 {func_name = "foo"} : (!transform.op<"scf.for">) -> !pdl.operation
82+
transform.loop.outline %1 {func_name = "foo"} : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
8383
}
8484

8585
// -----

mlir/test/python/dialects/transform_loop_ext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def loopOutline():
3333
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
3434
[], transform.OperationType.get("scf.for"))
3535
with InsertionPoint(sequence.body):
36-
loop.LoopOutlineOp(pdl.OperationType.get(), sequence.bodyTarget, func_name="foo")
36+
loop.LoopOutlineOp(transform.AnyOpType.get(), transform.AnyOpType.get(), sequence.bodyTarget, func_name="foo")
3737
transform.YieldOp()
3838
# CHECK-LABEL: TEST: loopOutline
3939
# CHECK: = transform.loop.outline %

0 commit comments

Comments
 (0)