Skip to content

Commit 0cf5cc1

Browse files
committed
fixup! Fix invalid folding on mismatching broadcast dimensions
1 parent 5517462 commit 0cf5cc1

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2882,8 +2882,21 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
28822882
}
28832883
};
28842884

2885+
// Return the broadcasted dimensions. Including broadcasts in the leading
2886+
// dimensions and broadcasts through unit dimension (i.e. dim-1).
2887+
static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
2888+
ArrayRef<int64_t> destShape) {
2889+
assert(destShape.size() >= srcShape.size());
2890+
BitVector broadcastedDims(destShape.size());
2891+
broadcastedDims.set(0, destShape.size() - srcShape.size());
2892+
auto unitDims = computeBroadcastedUnitDims(srcShape, destShape);
2893+
for (int64_t dim : unitDims)
2894+
broadcastedDims.set(dim);
2895+
return broadcastedDims;
2896+
}
2897+
28852898
// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
2886-
// with broadcast's result type.
2899+
// with broadcast's result type and the broadcasted dimensions are the same.
28872900
struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
28882901
using OpRewritePattern::OpRewritePattern;
28892902

@@ -2895,10 +2908,28 @@ struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
28952908

28962909
VectorType srcType = srcShapeCast.getSourceVectorType();
28972910
VectorType destType = broadcastOp.getResultVectorType();
2911+
// Check type compatibility.
28982912
if (vector::isBroadcastableTo(srcType, destType) !=
28992913
BroadcastableToResult::Success)
29002914
return failure();
29012915

2916+
// Given
2917+
// ```
2918+
// %s = shape_cast(%x)
2919+
// %b = broadcast(%s)
2920+
// ```
2921+
// If we want to fold %x into %b, the broadcasted dimensions from %x to
2922+
// %b has to be the same as that of from %s to %b.
2923+
ArrayRef<int64_t> shapecastShape =
2924+
srcShapeCast.getResultVectorType().getShape();
2925+
ArrayRef<int64_t> srcShape = srcType.getShape();
2926+
ArrayRef<int64_t> destShape = destType.getShape();
2927+
BitVector origBroadcastedDims =
2928+
getBroadcastedDims(shapecastShape, destShape);
2929+
BitVector newBroadcastedDims = getBroadcastedDims(srcShape, destShape);
2930+
if (newBroadcastedDims != origBroadcastedDims)
2931+
return failure();
2932+
29022933
rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
29032934
srcShapeCast.getSource());
29042935
return success();

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1180,7 +1180,7 @@ func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>)
11801180
// -----
11811181

11821182
// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape
1183-
// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32
1183+
// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32>
11841184
// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
11851185
func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
11861186
%0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
@@ -1190,6 +1190,17 @@ func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vecto
11901190

11911191
// -----
11921192

1193+
// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims
1194+
// CHECK: vector.shape_cast {{.+}} : vector<2x1xf32> to vector<1x2xf32>
1195+
// CHECK: vector.broadcast {{.+}} : vector<1x2xf32> to vector<2x2xf32>
1196+
func.func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims(%arg0 : vector<2x1xf32>) -> vector<2x2xf32> {
1197+
%0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2xf32>
1198+
%1 = vector.broadcast %0 : vector<1x2xf32> to vector<2x2xf32>
1199+
return %1 : vector<2x2xf32>
1200+
}
1201+
1202+
// -----
1203+
11931204
// CHECK-LABEL: fold_vector_transfer_masks
11941205
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
11951206
// CHECK: %[[C0:.+]] = arith.constant 0 : index

0 commit comments

Comments
 (0)