1- // RUN: mlir-opt --mlir-disable-threading %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
1+ // RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
22
33// This file contains some tests of folding/canonicalizing vector.from_elements
44
77///===----------------------------------------------===//
88
99// CHECK-LABEL: func @extract_scalar_from_from_elements(
10- // CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
10+ // CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
1111func.func @extract_scalar_from_from_elements (%a: f32 , %b: f32 ) -> (f32 , f32 , f32 , f32 , f32 , f32 , f32 ) {
1212 // Extract from 0D.
1313 %0 = vector.from_elements %a : vector <f32 >
@@ -33,7 +33,7 @@ func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32
3333// -----
3434
3535// CHECK-LABEL: func @extract_1d_from_from_elements(
36- // CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
36+ // CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
3737func.func @extract_1d_from_from_elements (%a: f32 , %b: f32 ) -> (vector <3 xf32 >, vector <3 xf32 >) {
3838 %0 = vector.from_elements %a , %a , %a , %b , %b , %b : vector <2 x3 xf32 >
3939 // CHECK: %[[SPLAT1:.*]] = vector.splat %[[A]] : vector<3xf32>
@@ -47,7 +47,7 @@ func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, ve
4747// -----
4848
4949// CHECK-LABEL: func @extract_2d_from_from_elements(
50- // CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
50+ // CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
5151func.func @extract_2d_from_from_elements (%a: f32 , %b: f32 ) -> (vector <2 x2 xf32 >, vector <2 x2 xf32 >) {
5252 %0 = vector.from_elements %a , %a , %a , %b , %b , %b , %b , %a , %b , %a , %a , %b : vector <3 x2 x2 xf32 >
5353 // CHECK: %[[SPLAT1:.*]] = vector.from_elements %[[A]], %[[A]], %[[A]], %[[B]] : vector<2x2xf32>
@@ -61,7 +61,7 @@ func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>,
6161// -----
6262
6363// CHECK-LABEL: func @from_elements_to_splat(
64- // CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
64+ // CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
6565func.func @from_elements_to_splat (%a: f32 , %b: f32 ) -> (vector <2 x3 xf32 >, vector <2 x3 xf32 >, vector <f32 >) {
6666 // CHECK: %[[SPLAT:.*]] = vector.splat %[[A]] : vector<2x3xf32>
6767 %0 = vector.from_elements %a , %a , %a , %a , %a , %a : vector <2 x3 xf32 >
@@ -81,8 +81,8 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
8181
8282// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
8383// CHECK-SAME: %[[A:.*]]: vector<1x2xi8>)
84- // CHECK: %[[SHAPE_CAST :.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8 >
85- // CHECK: return %[[SHAPE_CAST ]] : vector<2xi8>
84+ // CHECK: %[[EXTRACT :.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8 >
85+ // CHECK: return %[[EXTRACT ]] : vector<2xi8>
8686func.func @to_shape_cast_rank2_to_rank1 (%arg0: vector <1 x2 xi8 >) -> vector <2 xi8 > {
8787 %0 = vector.extract %arg0 [0 , 0 ] : i8 from vector <1 x2 xi8 >
8888 %1 = vector.extract %arg0 [0 , 1 ] : i8 from vector <1 x2 xi8 >
@@ -109,20 +109,13 @@ func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8>
109109 return %8 : vector <2 x2 x2 xi8 >
110110}
111111
112-
113112// -----
114113
115- // func.func @bar(%arg0: vector<2x3x4xi8>) -> vector<12xi8> {
116- // %0 = vector.extract %arg0[1] : vector<3x4xi8> from vector<2x3x4xi8>
117- // %1 = vector.shape_cast %0 : vector<3x4xi8> to vector<12xi8>
118- // return %1 : vector<12xi8>
119-
120114// CHECK-LABEL: func @source_larger_than_out(
121- // CHECK-SAME: %[[A:.*]]: vector<2x3x4xi8>)
122- // CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]] [1] : vector<3x4xi8> from vector<2x3x4xi8>
123- // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<3x4xi8> to vector<12xi8>
124- // CHECK: return %[[SHAPE_CAST]] : vector<12xi8>
125-
115+ // CHECK-SAME: %[[A:.*]]: vector<2x3x4xi8>)
116+ // CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][1] : vector<3x4xi8> from vector<2x3x4xi8>
117+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<3x4xi8> to vector<12xi8>
118+ // CHECK: return %[[SHAPE_CAST]] : vector<12xi8>
126119func.func @source_larger_than_out (%arg0: vector <2 x3 x4 xi8 >) -> vector <12 xi8 > {
127120 %0 = vector.extract %arg0 [1 , 0 , 0 ] : i8 from vector <2 x3 x4 xi8 >
128121 %1 = vector.extract %arg0 [1 , 0 , 1 ] : i8 from vector <2 x3 x4 xi8 >
@@ -140,13 +133,70 @@ func.func @source_larger_than_out(%arg0: vector<2x3x4xi8>) -> vector<12xi8> {
140133 return %12 : vector <12 xi8 >
141134}
142135
143- // TODO(newling) add more tests where the source is not the same size as out.
136+ // -----
137+
138+ // This test is similar to `source_larger_than_out` except here the number of elements
139+ // extracted contigously starting from the first position [0,0] could be 6 instead of 3
140+ // and the pattern would still match.
141+ // CHECK-LABEL: func @suffix_with_excess_zeros(
142+ // CHECK: %[[EXT:.*]] = vector.extract {{.*}}[0] : vector<3xi8> from vector<2x3xi8>
143+ // CHECK: return %[[EXT]] : vector<3xi8>
144+ func.func @suffix_with_excess_zeros (%arg0: vector <2 x3 xi8 >) -> vector <3 xi8 > {
145+ %0 = vector.extract %arg0 [0 , 0 ] : i8 from vector <2 x3 xi8 >
146+ %1 = vector.extract %arg0 [0 , 1 ] : i8 from vector <2 x3 xi8 >
147+ %2 = vector.extract %arg0 [0 , 2 ] : i8 from vector <2 x3 xi8 >
148+ %3 = vector.from_elements %0 , %1 , %2 : vector <3 xi8 >
149+ return %3 : vector <3 xi8 >
150+ }
151+
152+ // -----
153+
154+ // CHECK-LABEL: func @large_source_with_shape_cast_required(
155+ // CHECK-SAME: %[[A:.*]]: vector<2x2x2x2xi8>)
156+ // CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][0, 1] : vector<2x2xi8> from vector<2x2x2x2xi8>
157+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<2x2xi8> to vector<1x4x1xi8>
158+ // CHECK: return %[[SHAPE_CAST]] : vector<1x4x1xi8>
159+ func.func @large_source_with_shape_cast_required (%arg0: vector <2 x2 x2 x2 xi8 >) -> vector <1 x4 x1 xi8 > {
160+ %0 = vector.extract %arg0 [0 , 1 , 0 , 0 ] : i8 from vector <2 x2 x2 x2 xi8 >
161+ %1 = vector.extract %arg0 [0 , 1 , 0 , 1 ] : i8 from vector <2 x2 x2 x2 xi8 >
162+ %2 = vector.extract %arg0 [0 , 1 , 1 , 0 ] : i8 from vector <2 x2 x2 x2 xi8 >
163+ %3 = vector.extract %arg0 [0 , 1 , 1 , 1 ] : i8 from vector <2 x2 x2 x2 xi8 >
164+ %4 = vector.from_elements %0 , %1 , %2 , %3 : vector <1 x4 x1 xi8 >
165+ return %4 : vector <1 x4 x1 xi8 >
166+ }
167+
168+ // -----
169+
170+ // Could match, but handled by `rewriteFromElementsAsSplat`.
171+ // CHECK-LABEL: func @extract_single_elm(
172+ // CHECK-NEXT: vector.extract
173+ // CHECK-NEXT: vector.splat
174+ // CHECK-NEXT: return
175+ func.func @extract_single_elm (%arg0 : vector <2 x3 xi8 >) -> vector <1 xi8 > {
176+ %0 = vector.extract %arg0 [0 , 0 ] : i8 from vector <2 x3 xi8 >
177+ %1 = vector.from_elements %0 : vector <1 xi8 >
178+ return %1 : vector <1 xi8 >
179+ }
180+
181+ // -----
182+
183+ // CHECK-LABEL: func @negative_source_contiguous_but_not_suffix(
184+ // CHECK-NOT: shape_cast
185+ // CHECK: from_elements
186+ func.func @negative_source_contiguous_but_not_suffix (%arg0: vector <2 x3 xi8 >) -> vector <3 xi8 > {
187+ %0 = vector.extract %arg0 [0 , 1 ] : i8 from vector <2 x3 xi8 >
188+ %1 = vector.extract %arg0 [0 , 2 ] : i8 from vector <2 x3 xi8 >
189+ %2 = vector.extract %arg0 [1 , 0 ] : i8 from vector <2 x3 xi8 >
190+ %3 = vector.from_elements %0 , %1 , %2 : vector <3 xi8 >
191+ return %3 : vector <3 xi8 >
192+ }
144193
145194// -----
146195
147196// The extracted elements are recombined into a single vector, but in a new order.
148197// CHECK-LABEL: func @negative_nonascending_order(
149- // CHECK-NOT: shape_cast
198+ // CHECK-NOT: shape_cast
199+ // CHECK: from_elements
150200func.func @negative_nonascending_order (%arg0: vector <1 x2 xi8 >) -> vector <2 xi8 > {
151201 %0 = vector.extract %arg0 [0 , 1 ] : i8 from vector <1 x2 xi8 >
152202 %1 = vector.extract %arg0 [0 , 0 ] : i8 from vector <1 x2 xi8 >
@@ -157,7 +207,8 @@ func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
157207// -----
158208
159209// CHECK-LABEL: func @negative_nonstatic_extract(
160- // CHECK-NOT: shape_cast
210+ // CHECK-NOT: shape_cast
211+ // CHECK: from_elements
161212func.func @negative_nonstatic_extract (%arg0: vector <1 x2 xi8 >, %i0 : index , %i1 : index ) -> vector <2 xi8 > {
162213 %0 = vector.extract %arg0 [0 , %i0 ] : i8 from vector <1 x2 xi8 >
163214 %1 = vector.extract %arg0 [0 , %i1 ] : i8 from vector <1 x2 xi8 >
@@ -168,7 +219,8 @@ func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 :
168219// -----
169220
170221// CHECK-LABEL: func @negative_different_sources(
171- // CHECK-NOT: shape_cast
222+ // CHECK-NOT: shape_cast
223+ // CHECK: from_elements
172224func.func @negative_different_sources (%arg0: vector <1 x2 xi8 >, %arg1: vector <1 x2 xi8 >) -> vector <2 xi8 > {
173225 %0 = vector.extract %arg0 [0 , 0 ] : i8 from vector <1 x2 xi8 >
174226 %1 = vector.extract %arg1 [0 , 1 ] : i8 from vector <1 x2 xi8 >
@@ -178,9 +230,10 @@ func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi
178230
179231// -----
180232
181- // CHECK-LABEL: func @negative_source_too_large(
182- // CHECK-NOT: shape_cast
183- func.func @negative_source_too_large (%arg0: vector <1 x3 xi8 >) -> vector <2 xi8 > {
233+ // CHECK-LABEL: func @negative_source_not_suffix(
234+ // CHECK-NOT: shape_cast
235+ // CHECK: from_elements
236+ func.func @negative_source_not_suffix (%arg0: vector <1 x3 xi8 >) -> vector <2 xi8 > {
184237 %0 = vector.extract %arg0 [0 , 0 ] : i8 from vector <1 x3 xi8 >
185238 %1 = vector.extract %arg0 [0 , 1 ] : i8 from vector <1 x3 xi8 >
186239 %2 = vector.from_elements %0 , %1 : vector <2 xi8 >
@@ -189,13 +242,27 @@ func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> {
189242
190243// -----
191244
192- // The inserted elements are are a subset of the extracted elements.
245+ // The inserted elements are a subset of the extracted elements.
193246// [0, 1, 2] -> [1, 1, 2]
194247// CHECK-LABEL: func @negative_nobijection_order(
195- // CHECK-NOT: shape_cast
248+ // CHECK-NOT: shape_cast
249+ // CHECK: from_elements
196250func.func @negative_nobijection_order (%arg0: vector <1 x3 xi8 >) -> vector <3 xi8 > {
197251 %0 = vector.extract %arg0 [0 , 1 ] : i8 from vector <1 x3 xi8 >
198252 %1 = vector.extract %arg0 [0 , 2 ] : i8 from vector <1 x3 xi8 >
199253 %2 = vector.from_elements %0 , %0 , %1 : vector <3 xi8 >
200254 return %2 : vector <3 xi8 >
201255}
256+
257+ // -----
258+
259+ // CHECK-LABEL: func @negative_source_too_small(
260+ // CHECK-NOT: shape_cast
261+ // CHECK: from_elements
262+ func.func @negative_source_too_small (%arg0: vector <2 xi8 >) -> vector <4 xi8 > {
263+ %0 = vector.extract %arg0 [0 ] : i8 from vector <2 xi8 >
264+ %1 = vector.extract %arg0 [1 ] : i8 from vector <2 xi8 >
265+ %2 = vector.from_elements %0 , %1 , %1 , %1 : vector <4 xi8 >
266+ return %2 : vector <4 xi8 >
267+ }
268+
0 commit comments