@@ -228,6 +228,16 @@ func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
228228
229229// -----
230230
231+ func.func @broadcast_vector_extsi_scalable (%a : vector <[4 ]xi8 >) -> vector <2 x[4 ]xi32 > {
232+ // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<[4]xi8> to vector<[4]xi32>
233+ // CHECK: vector.broadcast %[[EXT:.+]] : vector<[4]xi32> to vector<2x[4]xi32>
234+ %b = vector.broadcast %a : vector <[4 ]xi8 > to vector <2 x[4 ]xi8 >
235+ %r = arith.extsi %b : vector <2 x[4 ]xi8 > to vector <2 x[4 ]xi32 >
236+ return %r : vector <2 x[4 ]xi32 >
237+ }
238+
239+ // -----
240+
231241func.func @broadcast_scalar_extsi (%a : i8 ) -> vector <2 x4 xi32 > {
232242 // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
233243 // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
@@ -236,6 +246,16 @@ func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
236246 return %r : vector <2 x4 xi32 >
237247}
238248
249+ // -----
250+
251+ func.func @broadcast_scalar_extsi_scalable (%a : i8 ) -> vector <2 x[4 ]xi32 > {
252+ // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
253+ // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x[4]xi32>
254+ %b = vector.broadcast %a : i8 to vector <2 x[4 ]xi8 >
255+ %r = arith.extsi %b : vector <2 x[4 ]xi8 > to vector <2 x[4 ]xi32 >
256+ return %r : vector <2 x[4 ]xi32 >
257+ }
258+
239259//===----------------------------------------------------------------------===//
240260// [Pattern: ReorderElementwiseOpsOnTranspose]
241261//===----------------------------------------------------------------------===//
@@ -250,6 +270,16 @@ func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
250270
251271// -----
252272
273+ func.func @transpose_extsi_scalable (%a : vector <[4 ]x2 xi8 >) -> vector <2 x[4 ]xi32 > {
274+ // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<[4]x2xi8> to vector<[4]x2xi32>
275+ // CHECK: vector.transpose %[[EXT]], [1, 0] : vector<[4]x2xi32> to vector<2x[4]xi32>
276+ %b = vector.transpose %a , [1 , 0 ]: vector <[4 ]x2 xi8 > to vector <2 x[4 ]xi8 >
277+ %r = arith.extsi %b : vector <2 x[4 ]xi8 > to vector <2 x[4 ]xi32 >
278+ return %r : vector <2 x[4 ]xi32 >
279+ }
280+
281+ // -----
282+
253283// CHECK-LABEL: func @transpose_elementwise_same_type
254284// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
255285// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
@@ -265,6 +295,21 @@ func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2
265295
266296// -----
267297
298+ // CHECK-LABEL: func @transpose_elementwise_same_type_scalable
299+ // CHECK-SAME: (%[[A:.+]]: vector<[4]x2xf32>, %[[B:.+]]: vector<[4]x2xf32>)
300+ // CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<[4]x2xf32>
301+ // CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
302+ // CHECK: return %[[T]]
303+
304+ func.func @transpose_elementwise_same_type_scalable (%a : vector <[4 ]x2 xf32 >, %b : vector <[4 ]x2 xf32 >) -> vector <2 x[4 ]xf32 > {
305+ %at = vector.transpose %a , [1 , 0 ]: vector <[4 ]x2 xf32 > to vector <2 x[4 ]xf32 >
306+ %bt = vector.transpose %b , [1 , 0 ]: vector <[4 ]x2 xf32 > to vector <2 x[4 ]xf32 >
307+ %r = arith.addf %at , %bt : vector <2 x[4 ]xf32 >
308+ return %r : vector <2 x[4 ]xf32 >
309+ }
310+
311+ // -----
312+
268313// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
269314// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
270315// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
@@ -280,6 +325,21 @@ func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a :
280325
281326// -----
282327
328+ // CHECK-LABEL: func @transpose_elementwise_diff_operand_types_scalable
329+ // CHECK-SAME: (%[[COND:.+]]: vector<[4]x2xi1>, %[[A:.+]]: vector<[4]x2xf32>, %[[B:.+]]: vector<[4]x2xf32>)
330+ // CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<[4]x2xi1>, vector<[4]x2xf32>
331+ // CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<[4]x2xf32> to vector<2x[4]xf32>
332+ // CHECK: return %[[T]]
333+ func.func @transpose_elementwise_diff_operand_types_scalable (%cond: vector <[4 ]x2 xi1 >, %a : vector <[4 ]x2 xf32 >, %b : vector <[4 ]x2 xf32 >) -> vector <2 x[4 ]xf32 > {
334+ %condt = vector.transpose %cond , [1 , 0 ]: vector <[4 ]x2 xi1 > to vector <2 x[4 ]xi1 >
335+ %at = vector.transpose %a , [1 , 0 ]: vector <[4 ]x2 xf32 > to vector <2 x[4 ]xf32 >
336+ %bt = vector.transpose %b , [1 , 0 ]: vector <[4 ]x2 xf32 > to vector <2 x[4 ]xf32 >
337+ %r = arith.select %condt , %at , %bt : vector <2 x[4 ]xi1 >, vector <2 x[4 ]xf32 >
338+ return %r : vector <2 x[4 ]xf32 >
339+ }
340+
341+ // -----
342+
283343// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
284344// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
285345// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
@@ -294,6 +354,20 @@ func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>,
294354
295355// -----
296356
357+ // CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type_scalable
358+ // CHECK-SAME: (%[[A:.+]]: vector<[4]x2xf32>, %[[B:.+]]: vector<[4]x2xf32>)
359+ // CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<[4]x2xf32>
360+ // CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<[4]x2xi1> to vector<2x[4]xi1>
361+ // CHECK: return %[[T]]
362+ func.func @transpose_elementwise_diff_operand_result_type_scalable (%a : vector <[4 ]x2 xf32 >, %b : vector <[4 ]x2 xf32 >) -> vector <2 x[4 ]xi1 > {
363+ %at = vector.transpose %a , [1 , 0 ]: vector <[4 ]x2 xf32 > to vector <2 x[4 ]xf32 >
364+ %bt = vector.transpose %b , [1 , 0 ]: vector <[4 ]x2 xf32 > to vector <2 x[4 ]xf32 >
365+ %r = arith.cmpf olt , %at , %bt : vector <2 x[4 ]xf32 >
366+ return %r : vector <2 x[4 ]xi1 >
367+ }
368+
369+ // -----
370+
297371// CHECK-LABEL: func @transpose_elementwise_splat_constant
298372// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
299373// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
@@ -310,6 +384,22 @@ func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vec
310384
311385// -----
312386
387+ // CHECK-LABEL: func @transpose_elementwise_splat_constant_scalable
388+ // CHECK-SAME: (%[[A:.+]]: vector<[4]x6x3x2xf32>)
389+ // CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<[4]x6x3x2xf32>
390+ // CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<[4]x6x3x2xf32>
391+ // CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<[4]x6x3x2xf32> to vector<6x[4]x2x3xf32>
392+ // CHECK: return %[[T:.+]] : vector<6x[4]x2x3xf32>
393+
394+ func.func @transpose_elementwise_splat_constant_scalable (%a : vector <[4 ]x6 x3 x2 xf32 >) -> vector <6 x[4 ]x2 x3 xf32 > {
395+ %b = arith.constant dense <5.0 > : vector <6 x[4 ]x2 x3 xf32 >
396+ %at = vector.transpose %a , [1 , 0 , 3 , 2 ]: vector <[4 ]x6 x3 x2 xf32 > to vector <6 x[4 ]x2 x3 xf32 >
397+ %r = arith.addf %at , %b : vector <6 x[4 ]x2 x3 xf32 >
398+ return %r : vector <6 x[4 ]x2 x3 xf32 >
399+ }
400+
401+ // -----
402+
313403// CHECK-LABEL: func @transpose_elementwise_diff_map
314404// CHECK: vector.transpose
315405// CHECK: vector.transpose
@@ -320,3 +410,16 @@ func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6
320410 %r = arith.addf %at , %bt : vector <6 x4 x2 x3 xf32 >
321411 return %r : vector <6 x4 x2 x3 xf32 >
322412}
413+
414+ // -----
415+
416+ // CHECK-LABEL: func @transpose_elementwise_diff_map_scalable
417+ // CHECK: vector.transpose
418+ // CHECK: vector.transpose
419+ // CHECK: arith.addf
420+ func.func @transpose_elementwise_diff_map_scalable (%a : vector <[4 ]x6 x3 x2 xf32 >, %b: vector <6 x2 x[4 ]x3 xf32 >) -> vector <6 x[4 ]x2 x3 xf32 > {
421+ %at = vector.transpose %a , [1 , 0 , 3 , 2 ]: vector <[4 ]x6 x3 x2 xf32 > to vector <6 x[4 ]x2 x3 xf32 >
422+ %bt = vector.transpose %b , [0 , 2 , 1 , 3 ]: vector <6 x2 x[4 ]x3 xf32 > to vector <6 x[4 ]x2 x3 xf32 >
423+ %r = arith.addf %at , %bt : vector <6 x[4 ]x2 x3 xf32 >
424+ return %r : vector <6 x[4 ]x2 x3 xf32 >
425+ }
0 commit comments