@@ -1168,17 +1168,56 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
11681168
11691169// -----
11701170
1171- // CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast
1171+ // CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim
11721172// CHECK-NOT: vector.shape_cast
11731173// CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
1174- func.func @canonicalize_shapecast_broadcast_to_broadcast (%arg0 : vector <2 xf32 >) -> vector <32 x2 xf32 > {
1174+ func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim (%arg0 : vector <2 xf32 >) -> vector <32 x2 xf32 > {
11751175 %0 = vector.shape_cast %arg0 : vector <2 xf32 > to vector <1 x2 xf32 >
11761176 %1 = vector.broadcast %0 : vector <1 x2 xf32 > to vector <32 x2 xf32 >
11771177 return %1 : vector <32 x2 xf32 >
11781178}
11791179
11801180// -----
11811181
1182+ // CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2(
1183+ // CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x1xf32> {
1184+ // CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x1xf32>
1185+ // CHECK: return %[[VAL_0]] : vector<32x2x1xf32>
1186+ // CHECK: }
1187+ func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2 (%arg0 : vector <2 x1 xf32 >) -> vector <32 x2 x1 xf32 > {
1188+ %0 = vector.shape_cast %arg0 : vector <2 x1 xf32 > to vector <1 x2 x1 xf32 >
1189+ %1 = vector.broadcast %0 : vector <1 x2 x1 xf32 > to vector <32 x2 x1 xf32 >
1190+ return %1 : vector <32 x2 x1 xf32 >
1191+ }
1192+
1193+ // -----
1194+
1195+ // CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3(
1196+ // CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x4xf32> {
1197+ // CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x4xf32>
1198+ // CHECK: return %[[VAL_0]] : vector<32x2x4xf32>
1199+ // CHECK: }
1200+ func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3 (%arg0 : vector <2 x1 xf32 >) -> vector <32 x2 x4 xf32 > {
1201+ %0 = vector.shape_cast %arg0 : vector <2 x1 xf32 > to vector <1 x2 x1 xf32 >
1202+ %1 = vector.broadcast %0 : vector <1 x2 x1 xf32 > to vector <32 x2 x4 xf32 >
1203+ return %1 : vector <32 x2 x4 xf32 >
1204+ }
1205+
1206+ // -----
1207+
1208+ // CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim(
1209+ // CHECK-SAME: %[[ARG0:.*]]: vector<1x2xf32>) -> vector<32x2xf32> {
1210+ // CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<1x2xf32> to vector<32x2xf32>
1211+ // CHECK: return %[[VAL_0]] : vector<32x2xf32>
1212+ // CHECK: }
1213+ func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim (%arg0 : vector <1 x2 xf32 >) -> vector <32 x2 xf32 > {
1214+ %0 = vector.shape_cast %arg0 : vector <1 x2 xf32 > to vector <2 xf32 >
1215+ %1 = vector.broadcast %0 : vector <2 xf32 > to vector <32 x2 xf32 >
1216+ return %1 : vector <32 x2 xf32 >
1217+ }
1218+
1219+ // -----
1220+
11821221// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape
11831222// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32>
11841223// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
@@ -1201,6 +1240,34 @@ func.func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims(%a
12011240
12021241// -----
12031242
1243+ // CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim(
1244+ // CHECK-SAME: %[[ARG0:.*]]: vector<2xf32>) -> vector<2x4xf32> {
1245+ // CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2xf32> to vector<2x1xf32>
1246+ // CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1xf32> to vector<2x4xf32>
1247+ // CHECK: return %[[VAL_1]] : vector<2x4xf32>
1248+ // CHECK: }
1249+ func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim (%arg0 : vector <2 xf32 >) -> vector <2 x4 xf32 > {
1250+ %0 = vector.shape_cast %arg0 : vector <2 xf32 > to vector <2 x1 xf32 >
1251+ %1 = vector.broadcast %0 : vector <2 x1 xf32 > to vector <2 x4 xf32 >
1252+ return %1 : vector <2 x4 xf32 >
1253+ }
1254+
1255+ // -----
1256+
1257+ // CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim(
1258+ // CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2xf32> {
1259+ // CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x1xf32> to vector<2xf32>
1260+ // CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2xf32> to vector<32x2xf32>
1261+ // CHECK: return %[[VAL_1]] : vector<32x2xf32>
1262+ // CHECK: }
1263+ func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim (%arg0 : vector <2 x1 xf32 >) -> vector <32 x2 xf32 > {
1264+ %0 = vector.shape_cast %arg0 : vector <2 x1 xf32 > to vector <2 xf32 >
1265+ %1 = vector.broadcast %0 : vector <2 xf32 > to vector <32 x2 xf32 >
1266+ return %1 : vector <32 x2 xf32 >
1267+ }
1268+
1269+ // -----
1270+
12041271// CHECK-LABEL: fold_vector_transfer_masks
12051272func.func @fold_vector_transfer_masks (%A: memref <?x?xf32 >) -> (vector <4 x8 xf32 >, vector <4 x[4 ]xf32 >) {
12061273 // CHECK: %[[C0:.+]] = arith.constant 0 : index
0 commit comments