Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 64 additions & 64 deletions mlir/test/Dialect/Vector/linearize.mlir
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s

// CHECK-LABEL: test_linearize
// CHECK-LABEL: elementwise_constant
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
func.func @elementwise_constant(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {

// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
// CHECK: %[[CST:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
Expand All @@ -21,8 +21,8 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {

// -----

// CHECK-LABEL: test_linearize_poison
func.func @test_linearize_poison() -> vector<2x2xf32> {
// CHECK-LABEL: poison
func.func @poison() -> vector<2x2xf32> {

// CHECK: %[[POISON:.*]] = ub.poison : vector<4xf32>
// CHECK: %[[RES:.*]] = vector.shape_cast %[[POISON]] : vector<4xf32> to vector<2x2xf32>
Expand All @@ -34,9 +34,9 @@ func.func @test_linearize_poison() -> vector<2x2xf32> {

// -----

// CHECK-LABEL: test_partial_linearize
// CHECK-LABEL: partial_linearize
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>)
func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>) -> vector<2x2xf32> {
func.func @partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>) -> vector<2x2xf32> {

// CHECK-DAG: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
// CHECK-DAG: %[[ARG2:.*]] = vector.shape_cast %[[ORIG_ARG2]] : vector<4x4xf32> to vector<16xf32>
Expand Down Expand Up @@ -68,8 +68,8 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>

// vectorizable operation (arith.mulf) with tensor result types.

// CHECK-LABEL: test_tensor_no_linearize
func.func @test_tensor_no_linearize(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>) {
// CHECK-LABEL: tensor_no_linearize
func.func @tensor_no_linearize(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>) {

// CHECK: %[[MULF:.*]] = arith.mulf %arg0, %arg1 : tensor<2x2xf32>
%0 = arith.mulf %arg0, %arg1 : tensor<2x2xf32>
Expand All @@ -79,9 +79,9 @@ func.func @test_tensor_no_linearize(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf3

// -----

// CHECK-LABEL: func.func @test_scalable_linearize(
// CHECK-LABEL: func.func @scalable_linearize(
// CHECK-SAME: %[[ARG_0:.*]]: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
func.func @scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {

// CHECK: %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[2]xf32> to vector<[4]xf32>
// CHECK: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[4]xf32>
Expand All @@ -100,9 +100,9 @@ func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32

// -----

// CHECK-LABEL: func.func @test_scalable_no_linearize(
// CHECK-LABEL: func.func @scalable_no_linearize(
// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
func.func @scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {

// CHECK: %[[CST:.*]] = arith.constant dense<2.000000e+00> : vector<[2]x[2]xf32>
%0 = arith.constant dense<[[2., 2.], [2., 2.]]> : vector<[2]x[2]xf32>
Expand All @@ -119,8 +119,8 @@ func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x

// -----

// CHECK-LABEL: func.func @test_0d_vector
func.func @test_0d_vector() -> vector<f32> {
// CHECK-LABEL: func.func @0d_vector
func.func @0d_vector() -> vector<f32> {

// CHECK: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<f32>
%0 = arith.constant dense<0.0> : vector<f32>
Expand All @@ -131,9 +131,9 @@ func.func @test_0d_vector() -> vector<f32> {

// -----

// CHECK-LABEL: test_extract_strided_slice_2D
// CHECK-LABEL: extract_strided_slice_2D
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> {
func.func @test_extract_strided_slice_2D(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
func.func @extract_strided_slice_2D(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {

// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
Expand All @@ -147,9 +147,9 @@ func.func @test_extract_strided_slice_2D(%arg0 : vector<4x8xf32>) -> vector<2x2x

// -----

// CHECK-LABEL: func.func @test_extract_strided_slice_2D_scalable(
// CHECK-LABEL: func.func @extract_strided_slice_2D_scalable(
// CHECK-SAME: %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
func.func @test_extract_strided_slice_2D_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
func.func @extract_strided_slice_2D_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {

// CHECK-NOT: vector.shuffle
// CHECK-NOT: vector.shape_cast
Expand All @@ -162,9 +162,9 @@ func.func @test_extract_strided_slice_2D_scalable(%arg0: vector<4x[8]xf32>) -> v

// -----

// CHECK-LABEL: test_extract_strided_slice_3D
// CHECK-LABEL: extract_strided_slice_3D
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
func.func @test_extract_strided_slice_3D(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
func.func @extract_strided_slice_3D(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {

// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
Expand Down Expand Up @@ -248,9 +248,9 @@ func.func @negative_insert_strided_slice_scalable(%arg0 : vector<1x[2]xi8>, %arg

// -----

// CHECK-LABEL: test_vector_shuffle
// CHECK-LABEL: shuffle
// CHECK-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> {
func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> {
func.func @shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> {

// CHECK-DAG: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32>
// CHECK-DAG: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32>
Expand All @@ -264,9 +264,9 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -

// -----

// CHECK-LABEL: test_vector_extract_scalar
// CHECK-LABEL: extract_scalar
// CHECK-SAME: (%[[ARG:.*]]: vector<2x4xi32>) -> i32 {
func.func @test_vector_extract_scalar(%arg0 : vector<2x4xi32>) -> i32 {
func.func @extract_scalar(%arg0 : vector<2x4xi32>) -> i32 {

// CHECK: %[[SRC_1D:.*]] = vector.shape_cast %[[ARG]] : vector<2x4xi32> to vector<8xi32>
// CHECK: %[[EXTRACT_1D:.*]] = vector.extract %[[SRC_1D]][6] : i32 from vector<8xi32>
Expand All @@ -277,9 +277,9 @@ func.func @test_vector_extract_scalar(%arg0 : vector<2x4xi32>) -> i32 {

// -----

// CHECK-LABEL: test_vector_extract
// CHECK-LABEL: extract
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
func.func @extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {

// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
Expand All @@ -292,9 +292,9 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {

// -----

// CHECK-LABEL: func.func @test_vector_extract_scalable(
// CHECK-LABEL: func.func @extract_scalable(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> {
func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> {
func.func @extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> {

// CHECK-NOT: vector.shuffle
// CHECK-NOT: vector.shape_cast
Expand All @@ -307,9 +307,9 @@ func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x

// -----

// CHECK-LABEL: test_vector_insert_scalar
// CHECK-LABEL: insert_scalar
// CHECK-SAME: (%[[DEST:.*]]: vector<2x4xf32>, %[[SRC:.*]]: f32) -> vector<2x4xf32> {
func.func @test_vector_insert_scalar(%arg0: vector<2x4xf32>, %arg1: f32) -> vector<2x4xf32> {
func.func @insert_scalar(%arg0: vector<2x4xf32>, %arg1: f32) -> vector<2x4xf32> {

// CHECK: %[[DEST_1D:.*]] = vector.shape_cast %[[DEST]] : vector<2x4xf32> to vector<8xf32>
// CHECK: %[[INSERT_1D:.*]] = vector.insert %[[SRC]], %[[DEST_1D]] [6] : f32 into vector<8xf32>
Expand All @@ -321,9 +321,9 @@ func.func @test_vector_insert_scalar(%arg0: vector<2x4xf32>, %arg1: f32) -> vect

// -----

// CHECK-LABEL: test_vector_insert
// CHECK-LABEL: insert
// CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
func.func @insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {

// CHECK-DAG: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
// CHECK-DAG: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
Expand All @@ -339,9 +339,9 @@ func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>)

// -----

// CHECK-LABEL: func.func @test_vector_insert_scalable(
// CHECK-LABEL: func.func @insert_scalable(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2x8x[4]xf32>, %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> {
func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> {
func.func @insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> {

// CHECK-NOT: vector.shuffle
// CHECK-NOT: vector.shape_cast
Expand All @@ -354,9 +354,9 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector

// -----

// CHECK-LABEL: test_vector_bitcast
// CHECK-LABEL: bitcast
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x4xf32>
func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {
func.func @bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {

// CHECK: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x4xf32> to vector<16xf32>
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<16xf32> to vector<32xf16>
Expand All @@ -367,9 +367,9 @@ func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {

// -----

// CHECK-LABEL: test_vector_bitcast
// CHECK-LABEL: bitcast
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x2xf32>
func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> {
func.func @bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> {

// CHECK: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x2xf32> to vector<8xf32>
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<8xf32> to vector<16xf16>
Expand All @@ -380,9 +380,9 @@ func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> {

// -----

// CHECK-LABEL: test_vector_bitcast
// CHECK-LABEL: bitcast
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[2]xf32>
func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> {
func.func @bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> {

// CHECK: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x[2]xf32> to vector<[8]xf32>
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
Expand All @@ -393,9 +393,9 @@ func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> {

// -----

// CHECK-LABEL: test_vector_bitcast
// CHECK-LABEL: bitcast
// CHECK-SAME: %[[ARG_0:.*]]: vector<[4]x2xf32>
func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
func.func @bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {

// CHECK: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<[4]x2xf32> to vector<[8]xf32>
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
Expand All @@ -406,8 +406,8 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {

// -----

// CHECK-LABEL: test_linearize_across_for
func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
// CHECK-LABEL: across_for
func.func @across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
%0 = vector.shape_cast %arg0 : vector<4xi8> to vector<2x2xi8>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand All @@ -428,9 +428,9 @@ func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {

// -----

// CHECK-LABEL: linearize_vector_broadcast_scalar_source
// CHECK-LABEL: broadcast_scalar_source
// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
func.func @linearize_vector_broadcast_scalar_source(%arg0: i32) -> vector<4x2xi32> {
func.func @broadcast_scalar_source(%arg0: i32) -> vector<4x2xi32> {

// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<8xi32>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32>
Expand All @@ -441,9 +441,9 @@ func.func @linearize_vector_broadcast_scalar_source(%arg0: i32) -> vector<4x2xi3

// -----

// CHECK-LABEL: linearize_vector_broadcast_rank_two_source
// CHECK-LABEL: broadcast_rank_two_source
// CHECK-SAME: (%[[ARG:.*]]: vector<1x1xi32>) -> vector<4x2xi32>
func.func @linearize_vector_broadcast_rank_two_source(%arg0: vector<1x1xi32>) -> vector<4x2xi32> {
func.func @broadcast_rank_two_source(%arg0: vector<1x1xi32>) -> vector<4x2xi32> {

// CHECK: %[[CAST0:.*]] = vector.shape_cast %[[ARG]] : vector<1x1xi32> to vector<1xi32>
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[CAST0]] : vector<1xi32> to vector<8xi32>
Expand All @@ -455,9 +455,9 @@ func.func @linearize_vector_broadcast_rank_two_source(%arg0: vector<1x1xi32>) ->

// -----

// CHECK-LABEL: linearize_scalable_vector_broadcast
// CHECK-LABEL: scalable_broadcast
// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32>
func.func @linearize_scalable_vector_broadcast(%arg0: i32) -> vector<4x[2]xi32> {
func.func @scalable_broadcast(%arg0: i32) -> vector<4x[2]xi32> {

// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<[8]xi32>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<[8]xi32> to vector<4x[2]xi32>
Expand All @@ -469,9 +469,9 @@ func.func @linearize_scalable_vector_broadcast(%arg0: i32) -> vector<4x[2]xi32>

// -----

// CHECK-LABEL: linearize_create_mask
// CHECK-LABEL: create_mask
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {
func.func @create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {

// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index
Expand All @@ -485,17 +485,17 @@ func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1
}

// -----
// CHECK-LABEL: linearize_scalable_create_mask
func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vector<1x[16]xi1> {
// CHECK-LABEL: scalable_create_mask
func.func @scalable_create_mask(%arg0 : index, %arg1 : index) -> vector<1x[16]xi1> {

// CHECK: %[[MASK_1D:.*]] = vector.create_mask {{%.*}} : vector<[16]xi1>
%0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1>
return %0 : vector<1x[16]xi1>
}

// CHECK-LABEL: linearize_vector_load
// CHECK-LABEL: load
// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>) -> vector<1x4xf32>
func.func @linearize_vector_load(%arg0: memref<2x8xf32>) -> vector<1x4xf32> {
func.func @load(%arg0: memref<2x8xf32>) -> vector<1x4xf32> {
// CHECK: %[[CST0:.*]] = arith.constant 0 : index
// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<4xf32>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32>
Expand All @@ -505,9 +505,9 @@ func.func @linearize_vector_load(%arg0: memref<2x8xf32>) -> vector<1x4xf32> {
return %0 : vector<1x4xf32>
}

// CHECK-LABEL: linearize_vector_store
// CHECK-LABEL: store
// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>, %[[ARG1:.*]]: vector<1x4xf32>)
func.func @linearize_vector_store(%arg0: memref<2x8xf32>, %arg1: vector<1x4xf32>) {
func.func @store(%arg0: memref<2x8xf32>, %arg1: vector<1x4xf32>) {
// CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x4xf32> to vector<4xf32>
// CHECK: %[[CST0:.*]] = arith.constant 0 : index
// CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<4xf32>
Expand All @@ -516,9 +516,9 @@ func.func @linearize_vector_store(%arg0: memref<2x8xf32>, %arg1: vector<1x4xf32>
return
}

// CHECK-LABEL: linearize_vector_load_scalable
// CHECK-LABEL: load_scalable
// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>) -> vector<1x[4]xf32>
func.func @linearize_vector_load_scalable(%arg0: memref<2x8xf32>) -> vector<1x[4]xf32> {
func.func @load_scalable(%arg0: memref<2x8xf32>) -> vector<1x[4]xf32> {
// CHECK: %[[CST0:.*]] = arith.constant 0 : index
// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<[4]xf32>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<[4]xf32> to vector<1x[4]xf32>
Expand All @@ -528,9 +528,9 @@ func.func @linearize_vector_load_scalable(%arg0: memref<2x8xf32>) -> vector<1x[4
return %0 : vector<1x[4]xf32>
}

// CHECK-LABEL: linearize_vector_store_scalable
// CHECK-LABEL: store_scalable
// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>, %[[ARG1:.*]]: vector<1x[4]xf32>)
func.func @linearize_vector_store_scalable(%arg0: memref<2x8xf32>, %arg1: vector<1x[4]xf32>) {
func.func @store_scalable(%arg0: memref<2x8xf32>, %arg1: vector<1x[4]xf32>) {
// CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x[4]xf32> to vector<[4]xf32>
// CHECK: %[[CST0:.*]] = arith.constant 0 : index
// CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<[4]xf32>
Expand All @@ -543,9 +543,9 @@ func.func @linearize_vector_store_scalable(%arg0: memref<2x8xf32>, %arg1: vector

// Test pattern LinearizeVectorFromElements.

// CHECK-LABEL: test_vector_from_elements
// CHECK-LABEL: from_elements
// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32
func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x2xf32> {
func.func @from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x2xf32> {
// CHECK: %[[FROM_ELEMENTS:.*]] = vector.from_elements %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]] : vector<4xf32>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[FROM_ELEMENTS]] : vector<4xf32> to vector<2x2xf32>
// CHECK: return %[[CAST]] : vector<2x2xf32>
Expand Down
Loading