Skip to content

Commit b4e8b8e

Browse files
mshockwavebanach-spacenewling
authored
[mlir][vector] Canonicalize broadcast of shape_cast (#150523)
Fold `broadcast(shape_cast(x))` into `broadcast(x)` if the type of x is compatible with broadcast's result type and the shape_cast only adds or removes ones in the leading dimensions. --------- Co-authored-by: Andrzej Warzyński <[email protected]> Co-authored-by: James Newling <[email protected]>
1 parent 0419b45 commit b4e8b8e

File tree

2 files changed

+138
-0
lines changed

2 files changed

+138
-0
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2841,9 +2841,47 @@ LogicalResult BroadcastOp::verify() {
28412841
llvm_unreachable("unexpected vector.broadcast op error");
28422842
}
28432843

2844+
// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
2845+
// with broadcast's result type and shape_cast only adds or removes ones in the
2846+
// leading dimensions.
2847+
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
2848+
auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
2849+
if (!srcShapeCast)
2850+
return failure();
2851+
2852+
VectorType srcType = srcShapeCast.getSourceVectorType();
2853+
VectorType destType = broadcastOp.getResultVectorType();
2854+
// Check type compatibility.
2855+
if (vector::isBroadcastableTo(srcType, destType) !=
2856+
BroadcastableToResult::Success)
2857+
return failure();
2858+
2859+
ArrayRef<int64_t> srcShape = srcType.getShape();
2860+
ArrayRef<int64_t> shapecastShape =
2861+
srcShapeCast.getResultVectorType().getShape();
2862+
// Trailing dimensions should be the same if shape_cast only alters the
2863+
// leading dimensions.
2864+
unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
2865+
if (!llvm::equal(srcShape.take_back(numTrailingDims),
2866+
shapecastShape.take_back(numTrailingDims)))
2867+
return failure();
2868+
2869+
assert(all_of(srcShape.drop_back(numTrailingDims),
2870+
[](int64_t E) { return E == 1; }) &&
2871+
all_of(shapecastShape.drop_back(numTrailingDims),
2872+
[](int64_t E) { return E == 1; }) &&
2873+
"ill-formed shape_cast");
2874+
2875+
broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
2876+
return success();
2877+
}
2878+
28442879
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
28452880
if (getSourceType() == getResultVectorType())
28462881
return getSource();
2882+
if (succeeded(foldBroadcastOfShapeCast(*this)))
2883+
return getResult();
2884+
28472885
if (!adaptor.getSource())
28482886
return {};
28492887
auto vectorType = getResultVectorType();

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,6 +1168,106 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
11681168

11691169
// -----
11701170

1171+
// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim
1172+
// CHECK-NOT: vector.shape_cast
1173+
// CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
1174+
func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
1175+
%0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
1176+
%1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
1177+
return %1 : vector<32x2xf32>
1178+
}
1179+
1180+
// -----
1181+
1182+
// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2(
1183+
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x1xf32> {
1184+
// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x1xf32>
1185+
// CHECK: return %[[VAL_0]] : vector<32x2x1xf32>
1186+
// CHECK: }
1187+
func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2(%arg0 : vector<2x1xf32>) -> vector<32x2x1xf32> {
1188+
%0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2x1xf32>
1189+
%1 = vector.broadcast %0 : vector<1x2x1xf32> to vector<32x2x1xf32>
1190+
return %1 : vector<32x2x1xf32>
1191+
}
1192+
1193+
// -----
1194+
1195+
// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3(
1196+
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x4xf32> {
1197+
// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x4xf32>
1198+
// CHECK: return %[[VAL_0]] : vector<32x2x4xf32>
1199+
// CHECK: }
1200+
func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3(%arg0 : vector<2x1xf32>) -> vector<32x2x4xf32> {
1201+
%0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2x1xf32>
1202+
%1 = vector.broadcast %0 : vector<1x2x1xf32> to vector<32x2x4xf32>
1203+
return %1 : vector<32x2x4xf32>
1204+
}
1205+
1206+
// -----
1207+
1208+
// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim(
1209+
// CHECK-SAME: %[[ARG0:.*]]: vector<1x2xf32>) -> vector<32x2xf32> {
1210+
// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<1x2xf32> to vector<32x2xf32>
1211+
// CHECK: return %[[VAL_0]] : vector<32x2xf32>
1212+
// CHECK: }
1213+
func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim(%arg0 : vector<1x2xf32>) -> vector<32x2xf32> {
1214+
%0 = vector.shape_cast %arg0 : vector<1x2xf32> to vector<2xf32>
1215+
%1 = vector.broadcast %0 : vector<2xf32> to vector<32x2xf32>
1216+
return %1 : vector<32x2xf32>
1217+
}
1218+
1219+
// -----
1220+
1221+
// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape
1222+
// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32>
1223+
// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
1224+
func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
1225+
%0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
1226+
%1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
1227+
return %1 : vector<2x4x16xf32>
1228+
}
1229+
1230+
// -----
1231+
1232+
// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims
1233+
// CHECK: vector.shape_cast {{.+}} : vector<2x1xf32> to vector<1x2xf32>
1234+
// CHECK: vector.broadcast {{.+}} : vector<1x2xf32> to vector<2x2xf32>
1235+
func.func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims(%arg0 : vector<2x1xf32>) -> vector<2x2xf32> {
1236+
%0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2xf32>
1237+
%1 = vector.broadcast %0 : vector<1x2xf32> to vector<2x2xf32>
1238+
return %1 : vector<2x2xf32>
1239+
}
1240+
1241+
// -----
1242+
1243+
// CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim(
1244+
// CHECK-SAME: %[[ARG0:.*]]: vector<2xf32>) -> vector<2x4xf32> {
1245+
// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2xf32> to vector<2x1xf32>
1246+
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1xf32> to vector<2x4xf32>
1247+
// CHECK: return %[[VAL_1]] : vector<2x4xf32>
1248+
// CHECK: }
1249+
func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim(%arg0 : vector<2xf32>) -> vector<2x4xf32> {
1250+
%0 = vector.shape_cast %arg0 : vector<2xf32> to vector<2x1xf32>
1251+
%1 = vector.broadcast %0 : vector<2x1xf32> to vector<2x4xf32>
1252+
return %1 : vector<2x4xf32>
1253+
}
1254+
1255+
// -----
1256+
1257+
// CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim(
1258+
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2xf32> {
1259+
// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x1xf32> to vector<2xf32>
1260+
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2xf32> to vector<32x2xf32>
1261+
// CHECK: return %[[VAL_1]] : vector<32x2xf32>
1262+
// CHECK: }
1263+
func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim(%arg0 : vector<2x1xf32>) -> vector<32x2xf32> {
1264+
%0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<2xf32>
1265+
%1 = vector.broadcast %0 : vector<2xf32> to vector<32x2xf32>
1266+
return %1 : vector<32x2xf32>
1267+
}
1268+
1269+
// -----
1270+
11711271
// CHECK-LABEL: fold_vector_transfer_masks
11721272
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
11731273
// CHECK: %[[C0:.+]] = arith.constant 0 : index

0 commit comments

Comments
 (0)