Skip to content

Commit 2d6ccc8

Browse files
committed
Move MLIR transform examples to linalg.elementwise
1 parent a64cf5c commit 2d6ccc8

File tree

11 files changed

+97
-97
lines changed

11 files changed

+97
-97
lines changed

mlir/docs/Tutorials/transform/Ch1.md

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
1919
outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
2020
2121
// Elementwise addition.
22-
%biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
22+
%biased = linalg.elementwise kind=#linalg.elementwise_kind<add>
2323
ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
2424
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
2525
2626
// Elementwise max with 0 (ReLU).
2727
%c0f = arith.constant 0.0 : f32
28-
%relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
28+
%relued = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
2929
ins(%biased, %c0f : tensor<512x512xf32>, f32)
3030
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
3131
func.return %relued : tensor<512x512xf32>
@@ -41,7 +41,7 @@ module attributes {transform.with_named_sequence} {
4141
transform.named_sequence @__transform_main(
4242
%arg0: !transform.any_op,
4343
%arg1: !transform.op<"linalg.matmul">,
44-
%arg2: !transform.op<"linalg.elemwise_binary">):
44+
%arg2: !transform.op<"linalg.elementwise">):
4545
transform.yield
4646
}
4747
}
@@ -72,11 +72,11 @@ To check or debug a transform sequence, it is possible to print various entities
7272
transform.sequence failures(propagate) {
7373
^bb0(%arg0: !transform.any_op,
7474
%arg1: !transform.op<"linalg.matmul">,
75-
%arg2: !transform.op<"linalg.elemwise_binary">):
75+
%arg2: !transform.op<"linalg.elementwise">):
7676
transform.debug.emit_remark_at %arg1, "matmul"
7777
: !transform.op<"linalg.matmul">
7878
transform.debug.emit_remark_at %arg2, "elemwise_binaries"
79-
: !transform.op<"linalg.elemwise_binary">
79+
: !transform.op<"linalg.elementwise">
8080
transform.yield
8181
}
8282
```
@@ -89,24 +89,24 @@ Since we don’t want to recompile the compiler every time we change a transform
8989
```sh
9090
$ mlir-opt sequence.mlir --pass-pipeline="
9191
builtin.module(transform-interpreter{
92-
debug-bind-trailing-args=linalg.matmul,linalg.elemwise_binary})"
92+
debug-bind-trailing-args=linalg.matmul,linalg.elementwise})"
9393
```
9494

95-
The `sequence.mlir` file contains _both_ the payload IR function _and_ the transform IR sequence nested in the same module. The transform interpreter pass will apply the `@__transform_main` named sequence to the anchor operation of the pass. In our case, we also asked the interpreter pass to associate the two extra arguments of the top-level sequence with all `linalg.matmul` and `linalg.elemwise_binary` payload operations through the respective pass options. Running this pass results in the expected remarks:
95+
The `sequence.mlir` file contains _both_ the payload IR function _and_ the transform IR sequence nested in the same module. The transform interpreter pass will apply the `@__transform_main` named sequence to the anchor operation of the pass. In our case, we also asked the interpreter pass to associate the two extra arguments of the top-level sequence with all `linalg.matmul` and `linalg.elementwise` payload operations through the respective pass options. Running this pass results in the expected remarks:
9696

9797
```sh
9898
sequence.mlir:7:13: remark: matmul
9999
%matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
100100
^
101101
sequence.mlir:7:13: note: see current operation: %0 = linalg.matmul ins(%arg0, %arg1 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
102102
sequence.mlir:10:13: remark: elemwise_binaries
103-
%biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
103+
%biased = linalg.elementwise kind=#linalg.elementwise_kind<add>
104104
^
105-
sequence.mlir:10:13: note: see current operation: %1 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
105+
sequence.mlir:10:13: note: see current operation: %1 = linalg.elementwise kind=#linalg.elementwise_kind<add>> ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
106106
sequence.mlir:14:13: remark: elemwise_binaries
107-
%relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
107+
%relued = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
108108
^
109-
sequence.mlir:14:13: note: see current operation: %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<max_signed>} ins(%1, %cst : tensor<512x512xf32>, f32) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
109+
sequence.mlir:14:13: note: see current operation: %2 = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>> ins(%1, %cst : tensor<512x512xf32>, f32) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
110110
```
111111

112112
Note that `%arg2` is associated with both elementwise payload operations. Any handle is associated with a list of entities. Individual transformations may or may not care about the order of elements in that list.
@@ -121,7 +121,7 @@ module attributes {transform.with_named_sequence} {
121121
transform.named_sequence @__transform_main(
122122
%arg0: !transform.any_op,
123123
%arg1: !transform.op<"linalg.matmul">,
124-
%arg2: !transform.op<"linalg.elemwise_binary">) {
124+
%arg2: !transform.op<"linalg.elementwise">) {
125125
// The actual tiling transformation takes tile sizes as attributes.
126126
%loop, %tiled = transform.structured.tile_using_forall %arg1
127127
tile_sizes [4, 32]
@@ -163,10 +163,10 @@ func.func @fc_relu(%arg0: tensor<512x512xf32>,
163163
: tensor<4x32xf32> into tensor<512x512xf32>
164164
}
165165
}
166-
%1 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
166+
%1 = linalg.elementwise kind=#linalg.elementwise_kind<add>>
167167
ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>)
168168
outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
169-
%2 = linalg.elemwise_binary {fun = #linalg.binary_fn<max_signed>}
169+
%2 = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>>
170170
ins(%1, %cst : tensor<512x512xf32>, f32)
171171
outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
172172
return %2 : tensor<512x512xf32>
@@ -185,7 +185,7 @@ module attributes {transform.with_named_sequence} {
185185
transform.named_sequence @__transform_main(
186186
%arg0: !transform.any_op,
187187
%arg1: !transform.op<"linalg.matmul">,
188-
%arg2: !transform.op<"linalg.elemwise_binary">) {
188+
%arg2: !transform.op<"linalg.elementwise">) {
189189
// The actual tiling transformation takes tile sizes as attributes.
190190
%loop, %tiled = transform.structured.tile_using_forall %arg1 tile_sizes [4, 32]
191191
: (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op)
@@ -219,7 +219,7 @@ module attributes {transform.with_named_sequence} {
219219
transform.named_sequence @__transform_main
220220
%arg0: !transform.any_op,
221221
%arg1: !transform.op<"linalg.matmul">,
222-
%arg2: !transform.op<"linalg.elemwise_binary">) {
222+
%arg2: !transform.op<"linalg.elementwise">) {
223223
// We can cast one type to another as long as operations are compatible
224224
// with both types. This creates "aliasing" handles.
225225
%casted = transform.cast %arg1 : !transform.op<"linalg.matmul">
@@ -248,7 +248,7 @@ sequence.mlir:28:3: error: op uses a handle invalidated by a previously executed
248248
transform.debug.emit_remark_at %matmul, "elemwise_binaries" : !transform.op<"linalg.matmul">
249249
^
250250
sequence.mlir:21:29: note: handle to invalidated ops
251-
^bb0(%root: !transform.any_op, %matmul: !transform.op<"linalg.matmul">, %elemwise: !transform.op<"linalg.elemwise_binary">):
251+
^bb0(%root: !transform.any_op, %matmul: !transform.op<"linalg.matmul">, %elemwise: !transform.op<"linalg.elementwise">):
252252
^
253253
sequence.mlir:27:19: note: invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them
254254
%loop, %tiled = transform.structured.tile_using_forall %mm tile_sizes [4, 32]
@@ -263,12 +263,12 @@ module attributes {transform.with_named_sequence} {
263263
transform.named_sequence @__transform_main(
264264
%arg0: !transform.any_op,
265265
%arg1: !transform.op<"linalg.matmul">,
266-
%arg2: !transform.op<"linalg.elemwise_binary">) {
266+
%arg2: !transform.op<"linalg.elementwise">) {
267267
// Since the %arg2 handle is associated with both elementwise operations,
268268
// we need to split it into two handles so we can target only the second
269269
// elementwise operation.
270270
%add, %max = transform.split_handle %arg2
271-
: (!transform.op<"linalg.elemwise_binary">)
271+
: (!transform.op<"linalg.elementwise">)
272272
-> (!transform.any_op, !transform.any_op)
273273
274274
// The actual tiling transformation takes tile sizes as attributes. It
@@ -308,12 +308,12 @@ module attributes {transform.with_named_sequence} {
308308
transform.named_sequence @__transform_main(
309309
%arg0: !transform.any_op,
310310
%arg1: !transform.op<"linalg.matmul">,
311-
%arg2: !transform.op<"linalg.elemwise_binary">) {
311+
%arg2: !transform.op<"linalg.elementwise">) {
312312
// Since the %arg2 handle is associated with both elementwise operations,
313313
// we need to split it into two handles so we can target only the second
314314
// elementwise operation.
315315
%add, %max = transform.split_handle %arg2
316-
: (!transform.op<"linalg.elemwise_binary">)
316+
: (!transform.op<"linalg.elementwise">)
317317
-> (!transform.any_op, !transform.any_op)
318318
319319
// The actual tiling transformation takes tile sizes as attributes. It
@@ -384,7 +384,7 @@ test/Examples/transform/Ch1/invalidation-2.mlir:106:18: note: invalidated by thi
384384
%func, %call = transform.loop.outline %outline_target {func_name = "outlined"}
385385
^
386386
test/Examples/transform/Ch1/invalidation-2.mlir:24:13: note: ancestor payload op
387-
%biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
387+
%biased = linalg.elementwise kind=#linalg.elementwise_kind<add>
388388
^
389389
test/Examples/transform/Ch1/invalidation-2.mlir:24:13: note: nested payload op
390390
%matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)

mlir/docs/Tutorials/transform/Ch2.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,12 +290,12 @@ module attributes {transform.with_named_sequence} {
290290
transform.named_sequence @__transform_main(
291291
%arg0: !transform.any_op,
292292
%arg1: !transform.op<"linalg.matmul">,
293-
%arg2: !transform.op<"linalg.elemwise_binary">) {
293+
%arg2: !transform.op<"linalg.elementwise">) {
294294
// Since the %arg2 handle is associated with both elementwise operations,
295295
// we need to split it into two handles so we can target only the second
296296
// elementwise operation.
297297
%add, %max = transform.split_handle %arg2
298-
: (!transform.op<"linalg.elemwise_binary">)
298+
: (!transform.op<"linalg.elementwise">)
299299
-> (!transform.any_op, !transform.any_op)
300300
301301
// The actual tiling transformation takes tile sizes as attributes. It

mlir/docs/Tutorials/transform/Ch4.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,13 @@ func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
4242
outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
4343
4444
// Elementwise addition.
45-
%biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
45+
%biased = linalg.elementwise kind=#linalg.elementwise_kind<add>
4646
ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
4747
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
4848
4949
// Elementwise max with 0 (ReLU).
5050
%c0f = arith.constant 0.0 : f32
51-
%relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
51+
%relued = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
5252
ins(%biased, %c0f : tensor<512x512xf32>, f32)
5353
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
5454
func.return %relued : tensor<512x512xf32>
@@ -59,7 +59,7 @@ func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
5959

6060
In Chapter 1, we were calling the test transform interpreter pass with
6161
additional arguments, `bind-first-extra-to-ops=linalg.matmul
62-
bind-second-extra-to-ops=linalg.elemwise_binary`, to provide initial
62+
bind-second-extra-to-ops=linalg.elementwise`, to provide initial
6363
associations for operation handles. Instead, we can use match operations to
6464
discover relevant operations in the payload IR. Match operations can be combined
6565
with “regular” transform operations using, e.g., the
@@ -97,7 +97,7 @@ module @transforms attributes { transform.with_named_sequence } {
9797
// rewriter sequence on success.
9898
transform.named_sequence @match_elemwise(
9999
%entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
100-
transform.match.operation_name %entry ["linalg.elemwise_binary"]
100+
transform.match.operation_name %entry ["linalg.elementwise"]
101101
: !transform.any_op
102102
transform.yield %entry : !transform.any_op
103103
}
@@ -127,7 +127,7 @@ module @transforms attributes { transform.with_named_sequence } {
127127
This script can be executed using the non-test interpreter pass running on the
128128
root operation of the translation unit without additional flags: `mlir-opt
129129
--transform-interpreter`. It will emit corresponding remarks at
130-
`linalg.elemwise_binary` and `linalg.matmul` operations. In debug builds, the
130+
`linalg.elementwise` and `linalg.matmul` operations. In debug builds, the
131131
infrastructure provides a convenient method to understand the matching process
132132
by passing `-debug-only=transform-matcher` to `mlir-opt` or a derived tool. It
133133
will print the silenceable failure messages produced by the match operations
@@ -169,7 +169,7 @@ transform.named_sequence @match_matmul_elemwise(
169169
%last: !transform.any_op {transform.readonly})
170170
-> (!transform.any_op, !transform.any_op, !transform.any_op) {
171171
// The last operation must be an elementwise binary.
172-
transform.match.operation_name %last ["linalg.elemwise_binary"]
172+
transform.match.operation_name %last ["linalg.elementwise"]
173173
: !transform.any_op
174174
// Its first operand must be defined by another operation, to which we
175175
// will get a handle here. We are guaranteed that the first operand exists
@@ -179,7 +179,7 @@ transform.named_sequence @match_matmul_elemwise(
179179
%middle = transform.get_producer_of_operand %last[0]
180180
: (!transform.any_op) -> !transform.any_op
181181
// The defining operation must itself be an elementwise binary.
182-
transform.match.operation_name %middle ["linalg.elemwise_binary"]
182+
transform.match.operation_name %middle ["linalg.elementwise"]
183183
: !transform.any_op
184184
// And the first operand of that operation must be defined by yet another
185185
// operation.
@@ -399,7 +399,7 @@ transform.named_sequence @match_matmul_elemwise(
399399
-> (!transform.any_op, !transform.any_op, !transform.any_op,
400400
!transform.param<i32>) {
401401
// The last operation must be an elementwise binary.
402-
transform.match.operation_name %last ["linalg.elemwise_binary"]
402+
transform.match.operation_name %last ["linalg.elementwise"]
403403
: !transform.any_op
404404
405405
// One of its operands must be defined by another operation, to which we
@@ -413,7 +413,7 @@ transform.named_sequence @match_matmul_elemwise(
413413
%def = transform.get_defining_op %operand
414414
: (!transform.any_value) -> !transform.any_op
415415
// The defining operation must itself be an elementwise binary.
416-
transform.match.operation_name %def ["linalg.elemwise_binary"]
416+
transform.match.operation_name %def ["linalg.elementwise"]
417417
: !transform.any_op
418418
transform.yield %def : !transform.any_op
419419
}

mlir/docs/Tutorials/transform/ChH.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ scf.forall (%co) in (2) {
290290
scf.forall (%n, %y, %xo) in (5, 80, 20) {
291291
tensor.extract_slice
292292
// Implicit dimensions [ni=0:1, y=0:1, xi=0:5, ci=0:64]
293-
%relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } // ...
293+
%relued = linalg.elementwise kind=#linalg.elementwise_kind<max_signed> // ...
294294
scf.forall.in_parallel {
295295
tensor.parallel_insert_slice // ...
296296
}

mlir/test/Examples/transform/Ch1/invalidation-1.mlir

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: mlir-opt %s \
22
// RUN: --pass-pipeline="builtin.module(transform-interpreter{ \
3-
// RUN: debug-bind-trailing-args=linalg.matmul,linalg.elemwise_binary},\
3+
// RUN: debug-bind-trailing-args=linalg.matmul,linalg.elementwise},\
44
// RUN: canonicalize,cse,symbol-dce)" \
55
// RUN: --split-input-file --verify-diagnostics
66

@@ -16,7 +16,7 @@ module attributes {transform.with_named_sequence} {
1616
%arg0: !transform.any_op,
1717
// expected-note @below {{handle to invalidated ops}}
1818
%arg1: !transform.op<"linalg.matmul">,
19-
%arg2: !transform.op<"linalg.elemwise_binary">) {
19+
%arg2: !transform.op<"linalg.elementwise">) {
2020
// The actual tiling transformation takes tile sizes as attributes.
2121
// expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}}
2222
%tiled, %loop = transform.structured.tile_using_forall %arg1 tile_sizes [4, 32]
@@ -39,14 +39,14 @@ func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
3939
outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
4040

4141
// Elementwise addition.
42-
%biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
42+
%biased = linalg.elementwise kind=#linalg.elementwise_kind<add>
4343
ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
4444
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
4545

4646
// Elementwise max with 0 (ReLU).
47-
%c0f = arith.constant 0.0 : f32
48-
%relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
49-
ins(%biased, %c0f : tensor<512x512xf32>, f32)
47+
%c0f = arith.constant dense<0.0> : tensor<512x512xf32>
48+
%relued = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
49+
ins(%biased, %c0f : tensor<512x512xf32>, tensor<512x512xf32>)
5050
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
5151
func.return %relued : tensor<512x512xf32>
5252
}
@@ -57,7 +57,7 @@ module attributes {transform.with_named_sequence} {
5757
transform.named_sequence @__transform_main(
5858
%arg0: !transform.any_op,
5959
%arg1: !transform.op<"linalg.matmul">,
60-
%arg2: !transform.op<"linalg.elemwise_binary">) {
60+
%arg2: !transform.op<"linalg.elementwise">) {
6161
// We can cast one type to another as long as operations are compatible
6262
// with both types. This creates "aliasing" handles.
6363
// expected-note @below {{handle to invalidated ops}}
@@ -88,14 +88,14 @@ func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
8888
outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
8989

9090
// Elementwise addition.
91-
%biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
91+
%biased = linalg.elementwise kind=#linalg.elementwise_kind<add>
9292
ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
9393
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
9494

9595
// Elementwise max with 0 (ReLU).
96-
%c0f = arith.constant 0.0 : f32
97-
%relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
98-
ins(%biased, %c0f : tensor<512x512xf32>, f32)
96+
%c0f = arith.constant dense<0.0> : tensor<512x512xf32>
97+
%relued = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
98+
ins(%biased, %c0f : tensor<512x512xf32>, tensor<512x512xf32>)
9999
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
100100
func.return %relued : tensor<512x512xf32>
101101
}

0 commit comments

Comments
 (0)