@@ -26,6 +26,7 @@ func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
2626 // Elementwise max with 0 (ReLU).
2727 %c0f = arith.constant 0.0 : f32
2828 %relued = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
29+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>]
2930 ins(%biased, %c0f : tensor<512x512xf32>, f32)
3031 outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
3132 func.return %relued : tensor<512x512xf32>
@@ -95,18 +96,18 @@ $ mlir-opt sequence.mlir --pass-pipeline="
9596The ` sequence.mlir ` file contains _ both_ the payload IR function _ and_ the transform IR sequence nested in the same module. The transform interpreter pass will apply the ` @__transform_main ` named sequence to the anchor operation of the pass. In our case, we also asked the interpreter pass to associate the two extra arguments of the top-level sequence with all ` linalg.matmul ` and ` linalg.elementwise ` payload operations through the respective pass options. Running this pass results in the expected remarks:
9697
9798``` sh
98- sequence.mlir:7 :13: remark: matmul
99+ sequence.mlir:5 :13: remark: matmul
99100 %matmul = linalg.matmul ins(%lhs, %rhs: tensor< 512x512xf32> , tensor< 512x512xf32> )
100101 ^
101- sequence.mlir:7 :13: note: see current operation: %0 = linalg.matmul ins(%arg0, %arg1 : tensor< 512x512xf32> , tensor< 512x512xf32> ) outs(%arg3 : tensor< 512x512xf32> ) -> tensor< 512x512xf32>
102- sequence.mlir:10 :13: remark: elemwise_binaries
102+ sequence.mlir:5 :13: note: see current operation: %0 = linalg.matmul ins(%arg0, %arg1 : tensor< 512x512xf32> , tensor< 512x512xf32> ) outs(%arg3 : tensor< 512x512xf32> ) -> tensor< 512x512xf32>
103+ sequence.mlir:9 :13: remark: elemwise_binaries
103104 %biased = linalg.elementwise kind=# linalg.elementwise_kind<add>
104105 ^
105- sequence.mlir:10 :13: note: see current operation: %1 = linalg.elementwise kind=# linalg.elementwise_kind<add> > ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
106- sequence.mlir:14 :13: remark: elemwise_binaries
106+ sequence.mlir:9 :13: note: see current operation: %1 = linalg.elementwise kind=# linalg.elementwise_kind<add> ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
107+ sequence.mlir:15 :13: remark: elemwise_binaries
107108 %relued = linalg.elementwise kind=# linalg.elementwise_kind<max_signed>
108109 ^
109- sequence.mlir:14 :13: note: see current operation: %2 = linalg.elementwise kind=# linalg.elementwise_kind<max_signed>> ins(%1, %cst : tensor<512x512xf32>, f32) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
110+ sequence.mlir:15 :13: note: see current operation: %2 = linalg.elementwise kind=# linalg.elementwise_kind<max_signed> indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>] ins(%1, %cst : tensor<512x512xf32>, f32) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
110111```
111112
112113Note that ` %arg2 ` is associated with both elementwise payload operations. Any handle is associated with a list of entities. Individual transformations may or may not care about the order of elements in that list.
@@ -140,33 +141,39 @@ The transformation returns two handles, as indicated in its [documentation](http
140141Running this transformation with the same command as above expectedly produces the tiled code.
141142
142143``` mlir
144+ #map = affine_map<(d0) -> (d0 * 4)>
145+ #map1 = affine_map<(d0) -> (d0 * 32)>
146+ #map2 = affine_map<(d0, d1) -> (d0, d1)>
147+ #map3 = affine_map<(d0, d1) -> ()>
148+
143149func.func @fc_relu(%arg0: tensor<512x512xf32>,
144150 %arg1: tensor<512x512xf32>,
145151 %arg2: tensor<512x512xf32>,
146152 %arg3: tensor<512x512xf32>) -> tensor<512x512xf32> {
147- %cst = arith.constant 0.000000e+00 : f32
148153 %0 = scf.forall (%arg4, %arg5) in (128, 16) shared_outs(%arg6 = %arg3) -> (tensor<512x512xf32>) {
149- %3 = affine.apply affine_map<(d0) -> (d0 * 4)> (%arg4)
150- %4 = affine.apply affine_map<(d0) -> (d0 * 32)> (%arg5)
154+ %3 = affine.apply #map (%arg4)
155+ %4 = affine.apply #map1 (%arg5)
151156 %extracted_slice = tensor.extract_slice %arg0[%3, 0] [4, 512] [1, 1]
152157 : tensor<512x512xf32> to tensor<4x512xf32>
153158 %extracted_slice_0 = tensor.extract_slice %arg1[0, %4] [512, 32] [1, 1]
154- : tensor<512x512xf32> to tensor<512x32xf32>
159+ : tensor<512x512xf32> to tensor<512x32xf32>
155160 %extracted_slice_1 = tensor.extract_slice %arg6[%3, %4] [4, 32] [1, 1]
156- : tensor<512x512xf32> to tensor<4x32xf32>
161+ : tensor<512x512xf32> to tensor<4x32xf32>
157162 %5 = linalg.matmul
158163 ins(%extracted_slice, %extracted_slice_0
159- : tensor<4x512xf32>, tensor<512x32xf32>)
164+ : tensor<4x512xf32>, tensor<512x32xf32>)
160165 outs(%extracted_slice_1 : tensor<4x32xf32>) -> tensor<4x32xf32>
161166 scf.forall.in_parallel {
162167 tensor.parallel_insert_slice %5 into %arg6[%3, %4] [4, 32] [1, 1]
163- : tensor<4x32xf32> into tensor<512x512xf32>
168+ : tensor<4x32xf32> into tensor<512x512xf32>
164169 }
165170 }
166- %1 = linalg.elementwise kind=#linalg.elementwise_kind<add>>
167- ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>)
168- outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
169- %2 = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>>
171+ %1 = linalg.elementwise kind=#linalg.elementwise_kind<add>
172+ ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>)
173+ outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
174+ %cst = arith.constant 0.000000e+00 : f32
175+ %2 = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
176+ indexing_maps = [#map2, #map3, #map2]
170177 ins(%1, %cst : tensor<512x512xf32>, f32)
171178 outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
172179 return %2 : tensor<512x512xf32>
@@ -216,7 +223,7 @@ One may observe that some operations such as `transform.cast` do not consume the
216223
217224``` mlir
218225module attributes {transform.with_named_sequence} {
219- transform.named_sequence @__transform_main
226+ transform.named_sequence @__transform_main(
220227 %arg0: !transform.any_op,
221228 %arg1: !transform.op<"linalg.matmul">,
222229 %arg2: !transform.op<"linalg.elementwise">) {
0 commit comments