11// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
22
3- // CHECK-LABEL: func.func @broadcast_scalar (
3+ // CHECK-LABEL: func.func @broadcast_scalar_with_bcast (
44// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
55// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
66// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
77// CHECK: return %[[BCAST]] : vector<1x4xindex>
8- // CHECK: }
98
10- func.func @broadcast_scalar ( %arg1: index , %arg2: index ) -> vector <1 x4 xindex > {
9+ func.func @broadcast_scalar_with_bcast ( %arg1: index , %arg2: index ) -> vector <1 x4 xindex > {
1110 %0 = vector.broadcast %arg1 : index to vector <1 x4 xindex >
1211 %1 = vector.broadcast %arg2 : index to vector <1 x4 xindex >
1312 %2 = arith.addi %0 , %1 : vector <1 x4 xindex >
@@ -16,20 +15,51 @@ func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> {
1615
1716// -----
1817
18+ // CHECK-LABEL: func.func @broadcast_scalar_with_bcast_and_splat(
19+ // CHECK-SAME: %[[ARG1:.*]]: index,
20+ // CHECK-SAME: %[[ARG2:.*]]: index) -> vector<1x4xindex> {
21+ // CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
22+ // CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
23+ // CHECK: return %[[BCAST]] : vector<1x4xindex>
24+ func.func @broadcast_scalar_with_bcast_and_splat ( %arg1: index , %arg2: index ) -> vector <1 x4 xindex > {
25+ %0 = vector.splat %arg1 : vector <1 x4 xindex >
26+ %1 = vector.broadcast %arg2 : index to vector <1 x4 xindex >
27+ %2 = arith.addi %0 , %1 : vector <1 x4 xindex >
28+ return %2 : vector <1 x4 xindex >
29+ }
30+
31+ // -----
32+
1933// CHECK-LABEL: func.func @broadcast_vector(
2034// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>,
2135// CHECK-SAME: %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> {
2236// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32>
2337// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
2438// CHECK: return %[[BCAST]] : vector<3x4xf32>
25- // CHECK: }
2639
2740func.func @broadcast_vector ( %arg1: vector <4 xf32 >, %arg2: vector <4 xf32 >) -> vector <3 x4 xf32 > {
2841 %arg1_bcast = vector.broadcast %arg1 : vector <4 xf32 > to vector <3 x4 xf32 >
2942 %arg2_bcast = vector.broadcast %arg2 : vector <4 xf32 > to vector <3 x4 xf32 >
3043 %2 = arith.addf %arg1_bcast , %arg2_bcast : vector <3 x4 xf32 >
3144 return %2 : vector <3 x4 xf32 >
3245}
46+
47+ // -----
48+
49+ // CHECK-LABEL: func.func @broadcast_scalar_and_vec(
50+ // CHECK-SAME: %[[ARG1:.*]]: index,
51+ // CHECK-SAME: %[[ARG2:.*]]: vector<4xindex>) -> vector<1x4xindex> {
52+ // CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x4xindex>
53+ // CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<4xindex> to vector<1x4xindex>
54+ // CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x4xindex>
55+ // CHECK: return %[[ADD]] : vector<1x4xindex>
56+ func.func @broadcast_scalar_and_vec ( %arg1: index , %arg2: vector <4 xindex >) -> vector <1 x4 xindex > {
57+ %0 = vector.splat %arg1 : vector <1 x4 xindex >
58+ %1 = vector.broadcast %arg2 : vector <4 xindex > to vector <1 x4 xindex >
59+ %2 = arith.addi %0 , %1 : vector <1 x4 xindex >
60+ return %2 : vector <1 x4 xindex >
61+ }
62+
3363// -----
3464
3565// CHECK-LABEL: func.func @broadcast_vector_and_scalar(
@@ -38,7 +68,6 @@ func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vect
3868// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<4xi32>
3969// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
4070// CHECK: return %[[ADD]] : vector<4xi32>
41- // CHECK: }
4271
4372func.func @broadcast_vector_and_scalar ( %arg1: i32 , %arg2: vector <4 xi32 >) -> vector <4 xi32 > {
4473 %arg1_bcast = vector.broadcast %arg1 : i32 to vector <4 xi32 >
0 commit comments