Skip to content

Commit 6755a75

Browse files
committed
fixup! Add more test cases
1 parent e370b81 commit 6755a75

File tree

1 file changed

+69
-2
lines changed

1 file changed

+69
-2
lines changed

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,17 +1168,56 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
11681168

11691169
// -----
11701170

1171-
// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast
1171+
// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim
11721172
// CHECK-NOT: vector.shape_cast
11731173
// CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
1174-
func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
1174+
func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
11751175
%0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
11761176
%1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
11771177
return %1 : vector<32x2xf32>
11781178
}
11791179

11801180
// -----
11811181

1182+
// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2(
1183+
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x1xf32> {
1184+
// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x1xf32>
1185+
// CHECK: return %[[VAL_0]] : vector<32x2x1xf32>
1186+
// CHECK: }
1187+
func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2(%arg0 : vector<2x1xf32>) -> vector<32x2x1xf32> {
1188+
%0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2x1xf32>
1189+
%1 = vector.broadcast %0 : vector<1x2x1xf32> to vector<32x2x1xf32>
1190+
return %1 : vector<32x2x1xf32>
1191+
}
1192+
1193+
// -----
1194+
1195+
// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3(
1196+
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x4xf32> {
1197+
// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x4xf32>
1198+
// CHECK: return %[[VAL_0]] : vector<32x2x4xf32>
1199+
// CHECK: }
1200+
func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3(%arg0 : vector<2x1xf32>) -> vector<32x2x4xf32> {
1201+
%0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2x1xf32>
1202+
%1 = vector.broadcast %0 : vector<1x2x1xf32> to vector<32x2x4xf32>
1203+
return %1 : vector<32x2x4xf32>
1204+
}
1205+
1206+
// -----
1207+
1208+
// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim(
1209+
// CHECK-SAME: %[[ARG0:.*]]: vector<1x2xf32>) -> vector<32x2xf32> {
1210+
// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<1x2xf32> to vector<32x2xf32>
1211+
// CHECK: return %[[VAL_0]] : vector<32x2xf32>
1212+
// CHECK: }
1213+
func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim(%arg0 : vector<1x2xf32>) -> vector<32x2xf32> {
1214+
%0 = vector.shape_cast %arg0 : vector<1x2xf32> to vector<2xf32>
1215+
%1 = vector.broadcast %0 : vector<2xf32> to vector<32x2xf32>
1216+
return %1 : vector<32x2xf32>
1217+
}
1218+
1219+
// -----
1220+
11821221
// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape
11831222
// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32>
11841223
// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
@@ -1201,6 +1240,34 @@ func.func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims(%a
12011240

12021241
// -----
12031242

1243+
// CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim(
1244+
// CHECK-SAME: %[[ARG0:.*]]: vector<2xf32>) -> vector<2x4xf32> {
1245+
// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2xf32> to vector<2x1xf32>
1246+
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1xf32> to vector<2x4xf32>
1247+
// CHECK: return %[[VAL_1]] : vector<2x4xf32>
1248+
// CHECK: }
1249+
func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim(%arg0 : vector<2xf32>) -> vector<2x4xf32> {
1250+
%0 = vector.shape_cast %arg0 : vector<2xf32> to vector<2x1xf32>
1251+
%1 = vector.broadcast %0 : vector<2x1xf32> to vector<2x4xf32>
1252+
return %1 : vector<2x4xf32>
1253+
}
1254+
1255+
// -----
1256+
1257+
// CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim(
1258+
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2xf32> {
1259+
// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x1xf32> to vector<2xf32>
1260+
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2xf32> to vector<32x2xf32>
1261+
// CHECK: return %[[VAL_1]] : vector<32x2xf32>
1262+
// CHECK: }
1263+
func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim(%arg0 : vector<2x1xf32>) -> vector<32x2xf32> {
1264+
%0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<2xf32>
1265+
%1 = vector.broadcast %0 : vector<2xf32> to vector<32x2xf32>
1266+
return %1 : vector<32x2xf32>
1267+
}
1268+
1269+
// -----
1270+
12041271
// CHECK-LABEL: fold_vector_transfer_masks
12051272
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
12061273
// CHECK: %[[C0:.+]] = arith.constant 0 : index

0 commit comments

Comments
 (0)