1- // DEFINE: %{compile} = mlir-opt %s \
2- // DEFINE: -transform-interpreter -test-transform-dialect-erase-schedule \
3- // DEFINE: --lower-vector-mask |\
4- // DEFINE: mlir-opt -arm-sve-legalize-vector-storage -convert-vector-to-llvm="enable-arm-sve"\
5- // DEFINE: -test-lower-to-llvm -o %t
1+ // DEFINE: %{td_entry_point} =
2+
3+ // DEFINE: %{compile} = mlir-opt %s \
4+ // DEFINE: -transform-preload-library='transform-library-paths=%p/td/pack-unpack.mlir' \
5+ // DEFINE: -transform-interpreter=entry-point=%{td_entry_point} \
6+ // DEFINE: -lower-vector-mask -convert-vector-to-scf="full-unroll target-rank=0" \
7+ // DEFINE: -arm-sve-legalize-vector-storage -convert-vector-to-llvm="enable-arm-sve"\
8+ // DEFINE: -test-lower-to-llvm -o %t
69// DEFINE: %{entry_point} = main
710// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve"\
811// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
912
13+ /// Run _without_ vectorization
14+ // REDEFINE: %{td_entry_point} = __transform_main_basic
1015// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
1116
12- /// End-to-end test for linalg.pack + linalg.unpack where one of the inner tile sizes is
13- /// scalable.
14- /// NOTE: Vectorization has not been enabled yet!
15-
17+ /// Run _with_ vectorization
18+ // REDEFINE: %{td_entry_point} = __transform_main_vectorized
19+ // RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
1620
17- /// The main entry point
21+ //===----------------------------------------------------------------------===//
22+ /// HIGH-LEVEL OVERVIEW
23+ ///
24+ /// End-to-end test for linalg.pack + linalg.unpack where one of the inner tile
25+ /// sizes is scalable.
26+ ///
27+ /// Two versions of the transform IR are tested:
28+ /// * without vectorization (see @__transform_main_basic in pack-unpack.mlir)
29+ /// * with vectorization (see @__transform_main_vectorized in pack-unpack.mlir)
30+ ///
31+ /// With the payload IR fixed, the runtime output is identical. Note - in both
32+ /// cases the tile sizes are scalable.
33+ ///
34+ /// TODO: ATM only linalg.unpack is vectorized. Add linalg.pack vectorization.
35+ //===----------------------------------------------------------------------===//
36+
37+ //===----------------------------------------------------------------------===//
38+ // @main
39+ //
40+ // Thin wrapper over the main test function to allow changing the runtime
41+ // vector length via @setArmVLBits (calling setArmVLBits() in a function that
42+ // uses SVE vectors is UB).
43+ //===----------------------------------------------------------------------===//
1844func.func @main () {
  // NOTE: @main is a thin wrapper on purpose — per the runtime-utils contract
  // noted above, calling setArmVLBits() from a function that itself uses SVE
  // vectors is UB, so the actual test body lives in a separate no_inline
  // function invoked below.
1945  // Set vscale to 2 (vector width = 256). This will have identical effect to:
2046  // * qemu-aarch64 -cpu max,sve-max-vq=2 (...)
2147  // (If your platform supports it, you can play with other values as well)
2248  %c256 = arith.constant 256 : i32
2349  func.call @setArmVLBits (%c256 ) : (i32 ) -> ()
  // Run the actual pack + unpack test (checked via FileCheck on its output).
24-   func.call @test_pack_unpack_scalable_inner_tile () : () -> ()
50+  func.call @pack_unpack_with_scalable_inner_tile () : () -> ()
2551
2652  return
2753}
2854
29- func.func @test_pack_unpack_scalable_inner_tile () attributes {no_inline } {
55+ //===----------------------------------------------------------------------===//
56+ // @pack_unpack_with_scalable_inner_tile
57+ //
58+ // The main test function that initialises the matrices and calls the
59+ // hooks.
60+ //===----------------------------------------------------------------------===//
61+ func.func @pack_unpack_with_scalable_inner_tile () attributes {no_inline } {
3062 // Dynamic/scalable tile size (vscale x 4)
3163 %c4 = arith.constant 4 : index
3264 %vs = vector.vscale
@@ -95,7 +127,11 @@ func.func @test_pack_unpack_scalable_inner_tile() attributes {no_inline} {
95127 return
96128}
97129
98- /// Takes the unpacked matrix + inner tile size to use and return the packed matrix.
130+ //===----------------------------------------------------------------------===//
131+ // @pack_main
132+ //
133+ // Takes the unpacked matrix + inner tile size to use and returns the packed matrix.
134+ //===----------------------------------------------------------------------===//
99135func.func private @pack_main (%A: tensor <7 x12 xi32 >, %inner_tile_size: index ) -> (tensor <2 x?x4 x?xi32 >) {
100136 // Get the size of dim (we could skip tensor.dim, but this way we can keep it generic)
101137 %c1 = arith.constant 1 : index
@@ -122,7 +158,11 @@ func.func private @pack_main(%A: tensor<7x12xi32>, %inner_tile_size: index) -> (
122158 return %A_pack : tensor <2 x?x4 x?xi32 >
123159}
124160
161+ //===----------------------------------------------------------------------===//
162+ // @unpack_main
163+ //
125164// Takes the packed matrix, unpacks it and returns the result.
165+ //===----------------------------------------------------------------------===//
126166func.func private @unpack_main (%A_pack : tensor <2 x?x4 x?xi32 >, %inner_tile_size: index ) -> tensor <7 x12 xi32 > {
127167 %A_unpack_empty = tensor.empty () : tensor <7 x12 xi32 >
128168
@@ -134,57 +174,5 @@ func.func private @unpack_main(%A_pack : tensor<2x?x4x?xi32>, %inner_tile_size:
134174 return %A_unpack : tensor <7 x12 xi32 >
135175}
136176
137- module @transforms attributes { transform.with_named_sequence } {
138- transform.named_sequence @__transform_main (%module: !transform.any_op {transform.consume }) {
139- %pack = transform.structured.match ops {[" linalg.pack" ]} in %module : (!transform.any_op ) -> !transform.any_op
140- %unpack = transform.structured.match ops {[" linalg.unpack" ]} in %module : (!transform.any_op ) -> !transform.any_op
141-
142- // 1.1 Tile the linalg.pack Op so that we can decompose it into e.g. tensor.pad
143- // and other lower-level Ops (see step 2.1)
144- %tiled_pack_op_p , %loops_pack:2 = transform.structured.tile_using_for %pack tile_sizes [1 , 1 ]
145- : (!transform.any_op ) -> (!transform.any_op , !transform.any_op , !transform.any_op )
146-
147- // 1.2 Tile the linalg.unpack Op so that we can decompose it into e.g. tensor.pad
148- // and other lower-level Ops (see step 2)
149- %tiled_unpack_op_p , %loops_unpack:2 = transform.structured.tile_using_for %unpack tile_sizes [4 , 1 ]
150- : (!transform.any_op ) -> (!transform.any_op , !transform.any_op , !transform.any_op )
151-
152- // 2.1. Decompose tiled PackOp into lower-level Ops
153- %func_op_pack = transform.get_parent_op %tiled_pack_op_p {isolated_from_above } : (!transform.any_op ) -> !transform.op <" func.func" >
154- transform.apply_patterns to %func_op_pack {
155- transform.apply_patterns.linalg.decompose_pack_unpack
156- transform.apply_patterns.linalg.decompose_pad
157- } : !transform.op <" func.func" >
158-
159- transform.apply_patterns to %func_op_pack {
160- transform.apply_patterns.tensor.fold_tensor_subset_ops
161- transform.apply_patterns.canonicalization
162- } : !transform.op <" func.func" >
163-
164- // 2.1. Decompose tiled UnpackOp into lower-level Ops
165- %func_op_unpack = transform.get_parent_op %tiled_unpack_op_p {isolated_from_above } : (!transform.any_op ) -> !transform.op <" func.func" >
166- transform.apply_patterns to %func_op_unpack {
167- transform.apply_patterns.linalg.decompose_pack_unpack
168- } : !transform.op <" func.func" >
169-
170- transform.apply_patterns to %func_op_unpack {
171- transform.apply_patterns.tensor.fold_tensor_subset_ops
172- transform.apply_patterns.canonicalization
173- } : !transform.op <" func.func" >
174-
175- // 3. Bufferize before lowering to LLVM
176- %bufferize = transform.bufferization.one_shot_bufferize %module
177- {bufferize_function_boundaries =true } : (!transform.any_op ) -> !transform.any_op
178-
179- // 4. Canonicalize
180- %func_op_bufferized = transform.structured.match ops {[" func.func" ]} in %bufferize : (!transform.any_op ) -> !transform.op <" func.func" >
181- transform.apply_patterns to %func_op_bufferized {
182- transform.apply_patterns.canonicalization
183- } : !transform.op <" func.func" >
184-
185- transform.yield
186- }
187- }
188-
// External runtime hooks, resolved at run time from the shared libraries
// passed in %{run} (-shared-libs=...):
//  * @printMemrefI32 — prints an unranked i32 tensor (from mlir_runner_utils)
//  * @setArmVLBits   — sets the runtime SVE vector length in bits
//                      (from native_mlir_arm_runner_utils)
189177func.func private @printMemrefI32 (%ptr : tensor <*xi32 >)
190178func.func private @setArmVLBits (%bits : i32 )
0 commit comments