Skip to content

Commit 4d1bc8a

Browse files
committed
improve test of general shape_cast reshape
1 parent b528b97 commit 4d1bc8a

File tree

2 files changed

+26
-15
lines changed

2 files changed

+26
-15
lines changed

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ func.func @outerproduct_scalable(%arg0 : vector<[4]xf32>, %arg1 : vector<[8]xf32
335335

336336
// CHECK-LABEL: @insert_strided_slice
337337
func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
338-
// CHECK: vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
338+
// CHECK: vector.insert_strided_slice %{{.*}}, %{{.*}} { ffsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
339339
%1 = vector.insert_strided_slice %a, %b {offsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
340340
return
341341
}
@@ -543,20 +543,6 @@ func.func @vector_print_on_scalar(%arg0: i64) {
543543
return
544544
}
545545

546-
// CHECK-LABEL: @shape_cast_valid_rank_reduction
547-
func.func @shape_cast_valid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
548-
// CHECK: vector.shape_cast %{{.*}} : vector<5x1x3x2xf32> to vector<2x15xf32>
549-
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
550-
return
551-
}
552-
553-
// CHECK-LABEL: @shape_cast_valid_rank_expansion
554-
func.func @shape_cast_valid_rank_expansion(%arg0 : vector<15x2xf32>) {
555-
// CHECK: vector.shape_cast %{{.*}} : vector<15x2xf32> to vector<5x2x3x1xf32>
556-
%0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
557-
return
558-
}
559-
560546
// CHECK-LABEL: @shape_cast
561547
func.func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
562548
%arg1 : vector<8x1xf32>,
@@ -578,6 +564,17 @@ func.func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
578564
return %0, %1, %2, %3 : vector<15x2xf32>, vector<8xf32>, vector<16xf32>, vector<16x1xf32>
579565
}
580566

567+
// A vector.shape_cast can cast between any 2 shapes as long as the
568+
// number of elements is preserved. For those familiar with the tensor
569+
// dialect: this behaviour is like the tensor.reshape operation, i.e.
570+
// less restrictive than tensor.collapse_shape and tensor.expand_shape
571+
// CHECK-LABEL: @shape_cast_general_reshape
572+
func.func @shape_cast_general_reshape(%arg0 : vector<2x3xf32>) -> (vector<3x1x2xf32>) {
573+
// CHECK: vector.shape_cast %{{.*}} : vector<2x3xf32> to vector<3x1x2xf32>
574+
%0 = vector.shape_cast %arg0 : vector<2x3xf32> to vector<3x1x2xf32>
575+
return %0
576+
}
577+
581578
// CHECK-LABEL: @shape_cast_0d
582579
func.func @shape_cast_0d(%arg0 : vector<1x1x1x1xf32>) -> (vector<1x1x1x1xf32>) {
583580

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
func.func @test_extract_strided_slice_1(%arg0 : vector<4x4x4x4xf32>) -> vector<2x3x4x4xf32> {
2+
%0 = vector.extract_strided_slice %arg0 { sizes = [2, 3], strides = [1, 1], offsets = [1, 1]}
3+
: vector<4x4x4x4xf32> to vector<2x3x4x4xf32>
4+
return %0 : vector<2x3x4x4xf32>
5+
}
6+
7+
8+
func.func @test(%arg0 : vector<7xf32>, %arg1 : vector<3xf32>) -> vector<4xf32> {
9+
10+
%out = vector.shuffle %arg0, %arg1 [0, 4, 5, 1]
11+
: vector<7xf32>, vector<3xf32>
12+
return %out : vector<4xf32>
13+
14+
}

0 commit comments

Comments
 (0)