@@ -1168,6 +1168,106 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
11681168
11691169// -----
11701170
1171+ // CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim
1172+ // CHECK-NOT: vector.shape_cast
1173+ // CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
1174+ func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim (%arg0 : vector <2 xf32 >) -> vector <32 x2 xf32 > {
1175+ %0 = vector.shape_cast %arg0 : vector <2 xf32 > to vector <1 x2 xf32 >
1176+ %1 = vector.broadcast %0 : vector <1 x2 xf32 > to vector <32 x2 xf32 >
1177+ return %1 : vector <32 x2 xf32 >
1178+ }
1179+
1180+ // -----
1181+
1182+ // CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2(
1183+ // CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x1xf32> {
1184+ // CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x1xf32>
1185+ // CHECK: return %[[VAL_0]] : vector<32x2x1xf32>
1186+ // CHECK: }
1187+ func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2 (%arg0 : vector <2 x1 xf32 >) -> vector <32 x2 x1 xf32 > {
1188+ %0 = vector.shape_cast %arg0 : vector <2 x1 xf32 > to vector <1 x2 x1 xf32 >
1189+ %1 = vector.broadcast %0 : vector <1 x2 x1 xf32 > to vector <32 x2 x1 xf32 >
1190+ return %1 : vector <32 x2 x1 xf32 >
1191+ }
1192+
1193+ // -----
1194+
1195+ // CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3(
1196+ // CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x4xf32> {
1197+ // CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x4xf32>
1198+ // CHECK: return %[[VAL_0]] : vector<32x2x4xf32>
1199+ // CHECK: }
1200+ func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3 (%arg0 : vector <2 x1 xf32 >) -> vector <32 x2 x4 xf32 > {
1201+ %0 = vector.shape_cast %arg0 : vector <2 x1 xf32 > to vector <1 x2 x1 xf32 >
1202+ %1 = vector.broadcast %0 : vector <1 x2 x1 xf32 > to vector <32 x2 x4 xf32 >
1203+ return %1 : vector <32 x2 x4 xf32 >
1204+ }
1205+
1206+ // -----
1207+
1208+ // CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim(
1209+ // CHECK-SAME: %[[ARG0:.*]]: vector<1x2xf32>) -> vector<32x2xf32> {
1210+ // CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<1x2xf32> to vector<32x2xf32>
1211+ // CHECK: return %[[VAL_0]] : vector<32x2xf32>
1212+ // CHECK: }
1213+ func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim (%arg0 : vector <1 x2 xf32 >) -> vector <32 x2 xf32 > {
1214+ %0 = vector.shape_cast %arg0 : vector <1 x2 xf32 > to vector <2 xf32 >
1215+ %1 = vector.broadcast %0 : vector <2 xf32 > to vector <32 x2 xf32 >
1216+ return %1 : vector <32 x2 xf32 >
1217+ }
1218+
1219+ // -----
1220+
1221+ // CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape
1222+ // CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32>
1223+ // CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
1224+ func.func @negative_canonicalize_shapecast_broadcast_invalid_shape (%arg0 : vector <64 xf32 >) -> vector <2 x4 x16 xf32 > {
1225+ %0 = vector.shape_cast %arg0 : vector <64 xf32 > to vector <4 x16 xf32 >
1226+ %1 = vector.broadcast %0 : vector <4 x16 xf32 > to vector <2 x4 x16 xf32 >
1227+ return %1 : vector <2 x4 x16 xf32 >
1228+ }
1229+
1230+ // -----
1231+
1232+ // CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims
1233+ // CHECK: vector.shape_cast {{.+}} : vector<2x1xf32> to vector<1x2xf32>
1234+ // CHECK: vector.broadcast {{.+}} : vector<1x2xf32> to vector<2x2xf32>
1235+ func.func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims (%arg0 : vector <2 x1 xf32 >) -> vector <2 x2 xf32 > {
1236+ %0 = vector.shape_cast %arg0 : vector <2 x1 xf32 > to vector <1 x2 xf32 >
1237+ %1 = vector.broadcast %0 : vector <1 x2 xf32 > to vector <2 x2 xf32 >
1238+ return %1 : vector <2 x2 xf32 >
1239+ }
1240+
1241+ // -----
1242+
1243+ // CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim(
1244+ // CHECK-SAME: %[[ARG0:.*]]: vector<2xf32>) -> vector<2x4xf32> {
1245+ // CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2xf32> to vector<2x1xf32>
1246+ // CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1xf32> to vector<2x4xf32>
1247+ // CHECK: return %[[VAL_1]] : vector<2x4xf32>
1248+ // CHECK: }
1249+ func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim (%arg0 : vector <2 xf32 >) -> vector <2 x4 xf32 > {
1250+ %0 = vector.shape_cast %arg0 : vector <2 xf32 > to vector <2 x1 xf32 >
1251+ %1 = vector.broadcast %0 : vector <2 x1 xf32 > to vector <2 x4 xf32 >
1252+ return %1 : vector <2 x4 xf32 >
1253+ }
1254+
1255+ // -----
1256+
1257+ // CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim(
1258+ // CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2xf32> {
1259+ // CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x1xf32> to vector<2xf32>
1260+ // CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2xf32> to vector<32x2xf32>
1261+ // CHECK: return %[[VAL_1]] : vector<32x2xf32>
1262+ // CHECK: }
1263+ func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim (%arg0 : vector <2 x1 xf32 >) -> vector <32 x2 xf32 > {
1264+ %0 = vector.shape_cast %arg0 : vector <2 x1 xf32 > to vector <2 xf32 >
1265+ %1 = vector.broadcast %0 : vector <2 xf32 > to vector <32 x2 xf32 >
1266+ return %1 : vector <32 x2 xf32 >
1267+ }
1268+
1269+ // -----
1270+
11711271// CHECK-LABEL: fold_vector_transfer_masks
11721272func.func @fold_vector_transfer_masks (%A: memref <?x?xf32 >) -> (vector <4 x8 xf32 >, vector <4 x[4 ]xf32 >) {
11731273 // CHECK: %[[C0:.+]] = arith.constant 0 : index
0 commit comments