@@ -335,7 +335,7 @@ func.func @outerproduct_scalable(%arg0 : vector<[4]xf32>, %arg1 : vector<[8]xf32
335335
336336// CHECK-LABEL: @insert_strided_slice
337337func.func @insert_strided_slice (%a: vector <4 x4 xf32 >, %b: vector <4 x8 x16 xf32 >) {
338- // CHECK: vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
338+ // CHECK: vector.insert_strided_slice %{{.*}}, %{{.*}} { ffsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
339339 %1 = vector.insert_strided_slice %a , %b {offsets = [2 , 2 , 2 ], strides = [1 , 1 ]} : vector <4 x4 xf32 > into vector <4 x8 x16 xf32 >
340340 return
341341}
@@ -543,20 +543,6 @@ func.func @vector_print_on_scalar(%arg0: i64) {
543543 return
544544}
545545
546- // CHECK-LABEL: @shape_cast_valid_rank_reduction
547- func.func @shape_cast_valid_rank_reduction (%arg0 : vector <5 x1 x3 x2 xf32 >) {
548- // CHECK: vector.shape_cast %{{.*}} : vector<5x1x3x2xf32> to vector<2x15xf32>
549- %0 = vector.shape_cast %arg0 : vector <5 x1 x3 x2 xf32 > to vector <2 x15 xf32 >
550- return
551- }
552-
553- // CHECK-LABEL: @shape_cast_valid_rank_expansion
554- func.func @shape_cast_valid_rank_expansion (%arg0 : vector <15 x2 xf32 >) {
555- // CHECK: vector.shape_cast %{{.*}} : vector<15x2xf32> to vector<5x2x3x1xf32>
556- %0 = vector.shape_cast %arg0 : vector <15 x2 xf32 > to vector <5 x2 x3 x1 xf32 >
557- return
558- }
559-
560546// CHECK-LABEL: @shape_cast
561547func.func @shape_cast (%arg0 : vector <5 x1 x3 x2 xf32 >,
562548 %arg1 : vector <8 x1 xf32 >,
@@ -578,6 +564,17 @@ func.func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
578564 return %0 , %1 , %2 , %3 : vector <15 x2 xf32 >, vector <8 xf32 >, vector <16 xf32 >, vector <16 x1 xf32 >
579565}
580566
567+ // A vector.shape_cast can cast between any 2 shapes as long as the
568+ // number of elements is preserved. For those familiar with the tensor
569+ // dialect: this behaviour is like the tensor.reshape operation, i.e.
570+ // less restrictive than tensor.collapse_shape and tensor.expand_shape
571+ // CHECK-LABEL: @shape_cast_general_reshape
572+ func.func @shape_cast_general_reshape (%arg0 : vector <2 x3 xf32 >) -> (vector <3 x1 x2 xf32 >) {
573+ // CHECK: vector.shape_cast %{{.*}} : vector<2x3xf32> to vector<3x1x2xf32>
574+ %0 = vector.shape_cast %arg0 : vector <2 x3 xf32 > to vector <3 x1 x2 xf32 >
575+ return %0
576+ }
577+
581578// CHECK-LABEL: @shape_cast_0d
582579func.func @shape_cast_0d (%arg0 : vector <1 x1 x1 x1 xf32 >) -> (vector <1 x1 x1 x1 xf32 >) {
583580
0 commit comments