@@ -161,6 +161,25 @@ func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8
161161
162162// -----
163163
164+ // In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
165+ // 1 -> 0
166+ // 2 -> 4
167+ // Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
168+ // (same as the example above, but one of the dims is scalable)
169+ // CHECK-LABEL: @shape_cast_of_transpose_scalable
170+ // CHECK-SAME: %[[ARG:.*]]: vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
171+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
172+ // CHECK-SAME: vector<1x[4]x4x1x1xi8> to vector<[4]x4xi8>
173+ // CHECK: return %[[SHAPE_CAST]] : vector<[4]x4xi8>
174+ func.func @shape_cast_of_transpose_scalable (%arg : vector <1 x[4 ]x4 x1 x1 xi8 >) -> vector <[4 ]x4 xi8 > {
175+ %0 = vector.transpose %arg , [1 , 0 , 3 , 4 , 2 ]
176+ : vector <1 x[4 ]x4 x1 x1 xi8 > to vector <[4 ]x1 x1 x1 x4 xi8 >
177+ %1 = vector.shape_cast %0 : vector <[4 ]x1 x1 x1 x4 xi8 > to vector <[4 ]x4 xi8 >
178+ return %1 : vector <[4 ]x4 xi8 >
179+ }
180+
181+ // -----
182+
164183// In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
165184// 1 -> 2
166185// 2 -> 1
@@ -180,36 +199,10 @@ func.func @negative_shape_cast_of_transpose(%arg : vector<1x4x4x1xi8>) -> vector
180199
181200// -----
182201
183- // Currently the conversion shape_cast(transpose) -> shape_cast is disabled for
184- // scalable vectors because of bad interaction with ConvertIllegalShapeCastOpsToTransposes
185- // CHECK-LABEL: @negative_shape_cast_of_transpose_scalable
186- // CHECK: vector.transpose
187- // CHECK: vector.shape_cast
188- func.func @negative_shape_cast_of_transpose_scalable (%arg : vector <[4 ]x1 xi8 >) -> vector <[4 ]xi8 > {
189- %0 = vector.transpose %arg , [1 , 0 ] : vector <[4 ]x1 xi8 > to vector <1 x[4 ]xi8 >
190- %1 = vector.shape_cast %0 : vector <1 x[4 ]xi8 > to vector <[4 ]xi8 >
191- return %1 : vector <[4 ]xi8 >
192- }
193-
194- // -----
195-
196202///===----------------------------------------------------------------------===//
197203/// FoldTransposeShapeCast: transpose(shape_cast) -> shape_cast
198204///===----------------------------------------------------------------------===//
199205
200- // The conversion transpose(shape_cast) -> shape_cast is not disabled for scalable
201- // vectors.
202- // CHECK-LABEL: @transpose_of_shape_cast_scalable
203- // CHECK: vector.shape_cast
204- // CHECK-SAME: vector<[4]xi8> to vector<[4]x1xi8>
205- func.func @transpose_of_shape_cast_scalable (%arg : vector <[4 ]xi8 >) -> vector <[4 ]x1 xi8 > {
206- %0 = vector.shape_cast %arg : vector <[4 ]xi8 > to vector <1 x[4 ]xi8 >
207- %1 = vector.transpose %0 , [1 , 0 ] : vector <1 x[4 ]xi8 > to vector <[4 ]x1 xi8 >
208- return %1 : vector <[4 ]x1 xi8 >
209- }
210-
211- // -----
212-
213206// A transpose that is 'order preserving' can be treated like a shape_cast.
214207// CHECK-LABEL: @transpose_of_shape_cast
215208// CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
@@ -225,11 +218,26 @@ func.func @transpose_of_shape_cast(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi
225218
226219// -----
227220
228- // Scalable dimensions should be treated as non-unit dimensions.
229- // CHECK-LABEL: @transpose_of_shape_cast_scalable
221+ // CHECK-LABEL: @transpose_shape_cast_scalable
222+ // CHECK-SAME: %[[ARG:.*]]: vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> {
223+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
224+ // CHECK-SAME: vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
225+ // CHECK: return %[[SHAPE_CAST]] : vector<[6]x1x1xi8>
226+ func.func @shape_cast_transpose_scalable (%arg : vector <[2 ]x3 x1 x1 xi8 >) -> vector <[6 ]x1 x1 xi8 > {
227+ %0 = vector.shape_cast %arg : vector <[2 ]x3 x1 x1 xi8 > to vector <[6 ]x1 x1 xi8 >
228+ %1 = vector.transpose %0 , [0 , 2 , 1 ]
229+ : vector <[6 ]x1 x1 xi8 > to vector <[6 ]x1 x1 xi8 >
230+ return %1 : vector <[6 ]x1 x1 xi8 >
231+ }
232+
233+ // -----
234+
235+ // Scalable 1 dimensions (i.e. [1]) should be treated as non-unit dimensions
236+ // (hence no folding).
237+ // CHECK-LABEL: @negative_shape_cast_transpose_scalable_unit
230238// CHECK: vector.shape_cast
231239// CHECK: vector.transpose
232- func.func @transpose_of_shape_cast_scalable_unit (%arg : vector <[1 ]x4 x1 xi8 >) -> vector <4 x[1 ]xi8 > {
240+ func.func @negative_shape_cast_transpose_scalable_unit (%arg : vector <[1 ]x4 x1 xi8 >) -> vector <4 x[1 ]xi8 > {
233241 %0 = vector.shape_cast %arg : vector <[1 ]x4 x1 xi8 > to vector <[1 ]x4 xi8 >
234242 %1 = vector.transpose %0 , [1 , 0 ] : vector <[1 ]x4 xi8 > to vector <4 x[1 ]xi8 >
235243 return %1 : vector <4 x[1 ]xi8 >
0 commit comments