Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 48 additions & 32 deletions mlir/docs/Tutorials/transform/Ch2.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ This will generate two files, `MyExtension.h.inc` and `MyExtension.cpp.inc`, tha
```c++
// In MyExtension.cpp.

#include "MyExtension.h"

#define GET_OP_CLASSES
#include "MyExtension.cpp.inc"

Expand Down Expand Up @@ -245,7 +247,8 @@ must be modified with the provided rewriter.
return diag;
}

updateCallee(call, getNewTarget());
// Use rewriter to modify the callee in place.
rewriter.modifyOpInPlace(call, [&]() { call.setCallee(getNewTarget()); });
}

// If everything went well, return success.
Expand All @@ -263,7 +266,7 @@ void ChangeCallTargetOp::getEffects(
// Indicate that the `call` handle is only read by this operation because the
// associated operation is not erased but rather modified in-place, so the
// reference to it remains valid.
onlyReadsHandle(getCall(), effects);
onlyReadsHandle(this->getOperation()->getOpOperands().front(), effects);

// Indicate that the payload is modified by this operation.
modifiesPayload(effects);
Expand All @@ -288,67 +291,80 @@ After registering the extension, it becomes possible to use our new operation in
```mlir
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(
%arg0: !transform.any_op,
%arg1: !transform.op<"linalg.matmul">,
%arg2: !transform.op<"linalg.elementwise">) {
%arg0: !transform.any_op,
%arg1: !transform.op<"linalg.matmul">,
%arg2: !transform.op<"linalg.elementwise">) {
// Since the %arg2 handle is associated with both elementwise operations,
// we need to split it into two handles so we can target only the second
// elementwise operation.
%add, %max = transform.split_handle %arg2
: (!transform.op<"linalg.elementwise">)
-> (!transform.any_op, !transform.any_op)
-> (!transform.any_op, !transform.any_op)

// The actual tiling transformation takes tile sizes as attributes. It
// produces a handle to the loop generated during tiling.
%loop, %tiled = transform.structured.tile_using_forall %max
%tiled, %loop = transform.structured.tile_using_forall %max
tile_sizes [8, 32]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)

// We can now fuse the other operations into the loop. Here, we fuse
// operations one-by-one. This requires the operation that is being fused
// to define the value used within the loop, so the order of such fusions
// is important. We could also use "transform.merge_handles" to obtain
// a single handle to all operations and give it to
// `fuse_into_containing_op` that would take care of the ordering in this
// case.
%add_fused = transform.structured.fuse_into_containing_op %add into %loop
: (!transform.any_op, !transform.any_op) -> !transform.any_op
%matmul_fused = transform.structured.fuse_into_containing_op %arg1
into %loop
: (!transform.op<"linalg.matmul">, !transform.any_op)
-> !transform.any_op
// operations one by one. This requires the operation that is being fused to
// define the value used within the loop, so the order of such fusions is
// important. We could also use "transform.merge_handles" to obtain a single
// handle to all operations and give it to `fuse_into_containing_op` that
// would take care of the ordering in this case.
%add_fused, %loop_0 =
transform.structured.fuse_into_containing_op %add into %loop
: (!transform.any_op, !transform.any_op)
-> (!transform.any_op, !transform.any_op)
%matmul_fused, %loop_1 =
transform.structured.fuse_into_containing_op %arg1 into %loop_0
: (!transform.op<"linalg.matmul">, !transform.any_op)
-> (!transform.any_op, !transform.any_op)

// Tile again to get the desired size. Note that this time this tiles the
// "add" operation and fuses matmul into the loop, but doesn't affect the
// "max" operation. This illustrates the precise targeting with the
// transform dialect. Otherwise, it is difficult to differentiate "add" and
// "max", both of which have the same kind.
%loop_2, %tiled_2 = transform.structured.tile_using_forall %add_fused
tile_sizes [4, 4]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%matmul_fused_2 = transform.structured.fuse_into_containing_op %matmul_fused
into %loop_2
: (!transform.any_op, !transform.any_op) -> !transform.any_op
%tiled_2, %loop_2 =
transform.structured.tile_using_forall %add_fused tile_sizes [4, 4]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%matmul_fused_2, %loop_3 =
transform.structured.fuse_into_containing_op %matmul_fused into %loop_2
: (!transform.any_op, !transform.any_op)
-> (!transform.any_op, !transform.any_op)

// Since outlining is currently only implemented for region-holding
// operations such as loops, use tiling to size 1 to materialize the outer
// loop that is going to be outlined.
%outline_target, %_ = transform.structured.tile_using_forall %tiled_2 tile_sizes [1]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.structured.fuse_into_containing_op %matmul_fused_2 into %outline_target
: (!transform.any_op, !transform.any_op) -> !transform.any_op
%_, %outline_target =
transform.structured.tile_using_forall %tiled_2 tile_sizes [1]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.structured.fuse_into_containing_op %matmul_fused_2
into %outline_target
: (!transform.any_op, !transform.any_op)
-> (!transform.any_op, !transform.any_op)
%func, %call = transform.loop.outline %outline_target
{func_name = "outlined"}
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
: (!transform.any_op) -> (!transform.any_op, !transform.op<"func.call">)

// Rewrite the call target.
transform.my.change_call_target %call, "microkernel" : !transform.any_op

transform.my.change_call_target %call, "microkernel" : !transform.op<"func.call">
transform.yield
}
}
```

When you run it with the interpreter, it produces the following error.

```
sequence.mlir:7:8: error: 'func.call' op 'microkernel' does not reference a valid function
%1 = linalg.elementwise kind=#linalg.elementwise_kind<add> ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
^
sequence.mlir:7:8: note: see current operation: %39 = "func.call"(%32, %33, %34, %36, %37) <{callee = @microkernel}> : (tensor<4x512xf32>, tensor<512x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
```

## Appendix: Autogenerated Documentation

[include "Tutorials/transform/MyExtensionCh2.md"]
Loading