11// RUN: mlir-opt --arm-sve-legalize-vector-storage --split-input-file %s | FileCheck %s
22
3+
4+ // Test the `LegalizeTransferRead` pattern
5+ // (mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp)
6+
37// -----
48
9+ // This is the base case, unremarkable in any way, except that it's our main
10+ // motivating example and use case.
11+
512// CHECK-LABEL: @test_base_case
613// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]:
714// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
@@ -23,82 +30,35 @@ func.func @test_base_case(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> ve
2330
2431// -----
2532
26- // CHECK-LABEL: @test_using_strided_layout
27- // CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
28- // CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
29- // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
30- // CHECK-SAME: : memref<?x?x?x8xi8, strided<[?, ?, 8, 1]>> into
31- // CHECK-SAME: memref<?x?x?xi8, strided<[?, ?, 1]>>
32- // CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]}
33- // CHECK-SAME: : memref<?x?x?xi8, strided<[?, ?, 1]>>, vector<[32]xi8>
34- // CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
35- // CHECK-NEXT: return %[[T1]] : vector<[4]x8xi8>
36-
37- #s0 = strided <[?, ?, 8 , 1 ]>
38-
39- func.func @test_using_strided_layout (%i : index , %j : index , %M : memref <?x?x?x8 xi8 , #s0 >) -> vector <[4 ]x8 xi8 > {
40- %c0 = arith.constant 0 : index
41- %c0_i8 = arith.constant 0 : i8
42-
43- %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <?x?x?x8 xi8 , #s0 >, vector <[4 ]x8 xi8 >
44-
45- return %A : vector <[4 ]x8 xi8 >
46- }
47-
48- // -----
33+ // Test the case where the scalable dimension is not the second-to-last.
4934
5035// CHECK-LABEL: @test_3d_vector
5136// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
5237// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
5338// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
54- // CHECK-SAME: : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
55- // CHECK-SAME: memref<?x?xi8, strided<[?, 1]>>
39+ // CHECK-SAME: : memref<?x?x2x8xi8> into memref<?x?xi8>
5640// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [true]}
57- // CHECK-SAME: : memref<?x?xi8, strided<[?, 1]> >, vector<[64]xi8>
41+ // CHECK-SAME: : memref<?x?xi8>, vector<[64]xi8>
5842// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[64]xi8> to vector<[4]x2x8xi8>
5943// CHECK-NEXT: return %[[T1]] : vector<[4]x2x8xi8>
6044
61- #s1 = strided <[?, 16 , 8 , 1 ]>
62-
63- func.func @test_3d_vector (%i : index , %j : index , %M : memref <?x?x2 x8 xi8 , #s1 >) -> vector <[4 ]x2 x8 xi8 > {
45+ func.func @test_3d_vector (%i : index , %j : index , %M : memref <?x?x2 x8 xi8 >) -> vector <[4 ]x2 x8 xi8 > {
6446 %c0 = arith.constant 0 : index
6547 %c0_i8 = arith.constant 0 : i8
6648
67- %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true , true ]} : memref <?x?x2 x8 xi8 , #s1 >, vector <[4 ]x2 x8 xi8 >
49+ %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true , true ]} : memref <?x?x2 x8 xi8 >, vector <[4 ]x2 x8 xi8 >
6850
6951 return %A : vector <[4 ]x2 x8 xi8 >
7052}
7153
7254// -----
7355
74- // CHECK-LABEL: @test_4d_vector
75- // CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
76- // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
77- // CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
78- // CHECK-SAME: : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
79- // CHECK-SAME: memref<?x?xi8, strided<[?, 1]>>
80- // CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [false, true]}
81- // CHECK-SAME: : memref<?x?xi8, strided<[?, 1]>>, vector<2x[64]xi8>
82- // CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
83- // CHECK-NEXT: return %[[T1]] : vector<2x[4]x2x8xi8>
84-
85- #s2 = strided <[?, 16 , 8 , 1 ]>
86-
87- func.func @test_4d_vector (%i : index , %j : index , %M : memref <?x?x2 x8 xi8 , #s2 >) -> vector <2 x[4 ]x2 x8 xi8 > {
88- %c0 = arith.constant 0 : index
89- %c0_i8 = arith.constant 0 : i8
90-
91- %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [false , true , true , true ]} : memref <?x?x2 x8 xi8 , #s2 >, vector <2 x[4 ]x2 x8 xi8 >
92-
93- return %A : vector <2 x[4 ]x2 x8 xi8 >
94- }
95-
96- // -----
56+ // Test the case when the vector is already LLVM-legal (fixed).
9757
98- // CHECK-LABEL: @negative_test_vector_legal_non_scalable
58+ // CHECK-LABEL: @negative_test_vector_legal_fixed
9959// CHECK-NOT: memref.collapse
10060
101- func.func @negative_test_vector_legal_non_scalable (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <8 x8 xi8 > {
61+ func.func @negative_test_vector_legal_fixed (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <8 x8 xi8 > {
10262 %c0 = arith.constant 0 : index
10363 %c0_i8 = arith.constant 0 : i8
10464
@@ -109,10 +69,12 @@ func.func @negative_test_vector_legal_non_scalable(%i : index, %j : index, %M :
10969
11070// -----
11171
112- // CHECK-LABEL: @negative_test_vector_legal_scalable_0
72+ // Test the case when the vector is already LLVM-legal (single-dimension scalable).
73+
74+ // CHECK-LABEL: @negative_test_vector_legal_1d_scalable
11375// CHECK-NOT: memref.collapse
11476
115- func.func @negative_test_vector_legal_scalable_0 (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <[8 ]xi8 > {
77+ func.func @negative_test_vector_legal_1d_scalable (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <[8 ]xi8 > {
11678 %c0 = arith.constant 0 : index
11779 %c0_i8 = arith.constant 0 : i8
11880
@@ -123,10 +85,13 @@ func.func @negative_test_vector_legal_scalable_0(%i : index, %j : index, %M : me
12385
12486// -----
12587
126- // CHECK-LABEL: @negative_test_vector_legal_scalable_1
88+ // Test the case when the vector is already LLVM-legal (single trailing
89+ // scalable dimension).
90+
91+ // CHECK-LABEL: @negative_test_vector_legal_trailing_scalable_dim
12792// CHECK-NOT: memref.collapse
12893
129- func.func @negative_test_vector_legal_scalable_1 (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <8 x[8 ]xi8 > {
94+ func.func @negative_test_vector_legal_trailing_scalable_dim (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <8 x[8 ]xi8 > {
13095 %c0 = arith.constant 0 : index
13196 %c0_i8 = arith.constant 0 : i8
13297
@@ -137,6 +102,8 @@ func.func @negative_test_vector_legal_scalable_1(%i : index, %j : index, %M : me
137102
138103// -----
139104
105+ // Test the case of unsupported vector type (more than one scalable dimension)
106+
140107// CHECK-LABEL: @negative_test_vector_type_not_supported
141108// CHECK-NOT: memref.collapse
142109
@@ -151,10 +118,14 @@ func.func @negative_test_vector_type_not_supported(%i : index, %j : index, %M :
151118
152119// -----
153120
154- // CHECK-LABEL: @negative_test_non_mem
121+ // Test the case of reading from a tensor - not supported, since the
122+ // transform reasons about memory layouts.
123+
124+ // CHECK-LABEL: @negative_test_tensor_transfer
125+
155126// CHECK-NOT: memref.collapse
156127
157- func.func @negative_test_non_mem (%i : index , %j : index , %M : tensor <?x?x?x8 xi8 >) -> vector <[4 ]x8 xi8 > {
128+ func.func @negative_test_tensor_transfer (%i : index , %j : index , %M : tensor <?x?x?x8 xi8 >) -> vector <[4 ]x8 xi8 > {
158129 %c0 = arith.constant 0 : index
159130 %c0_i8 = arith.constant 0 : i8
160131
@@ -165,98 +136,120 @@ func.func @negative_test_non_mem(%i : index, %j : index, %M : tensor<?x?x?x8xi8>
165136
166137// -----
167138
168- // CHECK-LABEL: @negative_test_discontig_mem_0
139+ // Test the case when the transfer is discontiguous because the memref
140+ // is discontiguous.
141+ // There are other ways to make a memref discontiguous. The transformation
142+ // is not concerned with the particular reason a memref is discontiguous, but
143+ // only with the fact. Therefore there are no variations with the memref made
144+ // discontiguous by some other mechanism.
145+
146+ // CHECK-LABEL: @negative_test_discontig_mem
169147// CHECK-NOT: memref.collapse
170148
171- #s3 = strided <[?, ?, 16 , 1 ]>
149+ #strides = strided <[?, ?, 16 , 1 ]>
172150
173- func.func @negative_test_discontig_mem_0 (%i : index , %j : index , %M : memref <?x?x?x8 xi8 , #s3 >) -> vector <[4 ]x8 xi8 > {
151+ func.func @negative_test_discontig_mem (%i : index , %j : index , %M : memref <?x?x?x8 xi8 , #strides >) -> vector <[4 ]x8 xi8 > {
174152 %c0 = arith.constant 0 : index
175153 %c0_i8 = arith.constant 0 : i8
176154
177- %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <?x?x?x8 xi8 , #s3 >, vector <[4 ]x8 xi8 >
155+ %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <?x?x?x8 xi8 , #strides >, vector <[4 ]x8 xi8 >
178156
179157 return %A : vector <[4 ]x8 xi8 >
180158}
181159
182160// -----
183161
184- // CHECK-LABEL: @negative_test_discontig_mem_1
162+ // Test the case when the transformation is not applied because of
163+ // a non-trivial permutation map (broadcast).
164+
165+ // CHECK-LABEL: @negative_test_broadcast
185166// CHECK-NOT: memref.collapse
186167
187- #layout = affine_map <(i , j , k , p ) -> (j , i , k , p )>
168+ #perm = affine_map <(i , j , k , p ) -> (k , 0 )>
188169
189- func.func @negative_test_discontig_mem_1 (%i : index , %j : index , %M : memref <?x?x?x8 xi8 , #layout >) -> vector <[4 ]x8 xi8 > {
170+ func.func @negative_test_broadcast (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <[4 ]x8 xi8 > {
190171 %c0 = arith.constant 0 : index
191172 %c0_i8 = arith.constant 0 : i8
192173
193- %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <?x?x?x8 xi8 , #layout >, vector <[4 ]x8 xi8 >
174+ %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {permutation_map = #perm , in_bounds = [true , true ] } : memref <?x?x?x8 xi8 >, vector <[4 ]x8 xi8 >
194175
195176 return %A : vector <[4 ]x8 xi8 >
196177}
197178
198179// -----
199180
200- // CHECK-LABEL: @negative_test_discontig_read_strided_vec
181+ // Test the case of a masked read - not supported right now.
182+ // (see mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp)
183+
184+ // CHECK-LABEL: @negative_test_masked
201185// CHECK-NOT: memref.collapse
202186
203- func.func @negative_test_discontig_read_strided_vec (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <[4 ]x4 xi8 > {
187+ func.func @negative_test_masked (
188+ %i : index , %j : index ,
189+ %M : memref <?x?x?x8 xi8 >, %mask : vector <[4 ]x8 xi1 >) -> vector <[4 ]x8 xi8 > {
190+
204191 %c0 = arith.constant 0 : index
205192 %c0_i8 = arith.constant 0 : i8
206193
207- %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <?x?x?x8 xi8 >, vector <[4 ]x4 xi8 >
194+ %A = vector.mask %mask {
195+ vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ] } : memref <?x?x?x8 xi8 >, vector <[4 ]x8 xi8 >
196+ } : vector <[4 ]x8 xi1 > -> vector <[4 ]x8 xi8 >
208197
209- return %A : vector <[4 ]x 4 x i8 >
198+ return %A : vector <[4 ]x 8 x i8 >
210199}
211200
212201// -----
213202
214- // CHECK-LABEL: @negative_test_bcast_transp
215- // CHECK-NOT: memref.collapse
203+ // Test case with a mask operand - not supported right now.
204+ // (see mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp)
216205
217- #perm = affine_map <(i , j , k , p ) -> (k , 0 )>
206+ // CHECK-LABEL: @negative_test_with_mask
207+ // CHECK-NOT: memref.collapse
218208
219- func.func @negative_test_bcast_transp (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <[4 ]x8 xi8 > {
209+ func.func @negative_test_with_mask (
210+ %i : index , %j : index ,
211+ %M : memref <?x?x?x8 xi8 >, %mask : vector <[4 ]x8 xi1 >) -> vector <[4 ]x8 xi8 > {
212+
220213 %c0 = arith.constant 0 : index
221214 %c0_i8 = arith.constant 0 : i8
222215
223- %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 { permutation_map = #perm , in_bounds = [true , true ] } : memref <?x?x?x8 xi8 >, vector <[4 ]x8 xi8 >
216+ %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 , %mask { in_bounds = [true , true ] } : memref <?x?x?x8 xi8 >, vector <[4 ]x8 xi8 >
224217
225218 return %A : vector <[4 ]x8 xi8 >
226219}
227220
228221// -----
229222
230- // CHECK-LABEL: @negative_test_vector_mask
223+ // Test the case when the dimensions to collapse (excluding the scalable one)
224+ // of the vector and the memref do not match (static non matching dimension).
225+
226+ // CHECK-LABEL: @negative_test_non_matching_dim_static
231227// CHECK-NOT: memref.collapse
232228
233- func.func @negative_test_vector_mask (
234- %i : index , %j : index ,
235- %M : memref <?x?x?x8 xi8 >, %mask : vector <[4 ]x8 xi1 >) -> vector <[4 ]x8 xi8 > {
229+ func.func @negative_test_non_matching_dim_static (%i : index , %j : index , %M : memref <?x?x?x8 xi8 >) -> vector <[4 ]x4 xi8 > {
236230
237231 %c0 = arith.constant 0 : index
238232 %c0_i8 = arith.constant 0 : i8
239233
240- %A = vector.mask %mask {
241- vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ] } : memref <?x?x?x8 xi8 >, vector <[4 ]x8 xi8 >
242- } : vector <[4 ]x8 xi1 > -> vector <[4 ]x8 xi8 >
234+ %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ] } : memref <?x?x?x8 xi8 >, vector <[4 ]x4 xi8 >
243235
244- return %A : vector <[4 ]x 8 x i8 >
236+ return %A : vector <[4 ]x 4 x i8 >
245237}
246238
247239// -----
248240
249- // CHECK-LABEL: @negative_test_mask_operand
241+ // Test the case when the dimensions to collapse (excluding the scalable one)
242+ // of the vector and the memref do not match (dynamic non matching dimension).
243+
244+ // CHECK-LABEL: @negative_test_non_matching_dim_dynamic
250245// CHECK-NOT: memref.collapse
251246
252- func.func @negative_test_mask_operand (
253- %i : index , %j : index ,
254- %M : memref <?x?x?x8 xi8 >, %mask : vector <[4 ]x8 xi1 >) -> vector <[4 ]x8 xi8 > {
247+ func.func @negative_test_non_matching_dim_dynamic (%i : index , %j : index , %M : memref <?x?x?x?xi8 >) -> vector <[4 ]x4 xi8 > {
255248
256249 %c0 = arith.constant 0 : index
257250 %c0_i8 = arith.constant 0 : i8
258251
259- %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 , %mask {in_bounds = [true , true ] } : memref <?x?x?x 8 x i8 >, vector <[4 ]x 8 x i8 >
252+ %A = vector.transfer_read %M [%i , %j , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ] } : memref <?x?x?x?x i8 >, vector <[4 ]x 4 x i8 >
260253
261- return %A : vector <[4 ]x 8 x i8 >
254+ return %A : vector <[4 ]x 4 x i8 >
262255}
0 commit comments