@@ -2882,8 +2882,21 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
28822882 }
28832883};
28842884
2885+ // Return the broadcasted dimensions. Including broadcasts in the leading
2886+ // dimensions and broadcasts through unit dimension (i.e. dim-1).
2887+ static BitVector getBroadcastedDims (ArrayRef<int64_t > srcShape,
2888+ ArrayRef<int64_t > destShape) {
2889+ assert (destShape.size () >= srcShape.size ());
2890+ BitVector broadcastedDims (destShape.size ());
2891+ broadcastedDims.set (0 , destShape.size () - srcShape.size ());
2892+ auto unitDims = computeBroadcastedUnitDims (srcShape, destShape);
2893+ for (int64_t dim : unitDims)
2894+ broadcastedDims.set (dim);
2895+ return broadcastedDims;
2896+ }
2897+
28852898// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
2886- // with broadcast's result type.
2899+ // with broadcast's result type and the broadcasted dimensions are the same .
28872900struct FoldBroadcastOfShapeCast : public OpRewritePattern <BroadcastOp> {
28882901 using OpRewritePattern::OpRewritePattern;
28892902
@@ -2895,10 +2908,28 @@ struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
28952908
28962909 VectorType srcType = srcShapeCast.getSourceVectorType ();
28972910 VectorType destType = broadcastOp.getResultVectorType ();
2911+ // Check type compatibility.
28982912 if (vector::isBroadcastableTo (srcType, destType) !=
28992913 BroadcastableToResult::Success)
29002914 return failure ();
29012915
2916+ // Given
2917+ // ```
2918+ // %s = shape_cast(%x)
2919+ // %b = broadcast(%s)
2920+ // ```
2921+ // If we want to fold %x into %b, the broadcasted dimensions from %x to
2922+ // %b has to be the same as that of from %s to %b.
2923+ ArrayRef<int64_t > shapecastShape =
2924+ srcShapeCast.getResultVectorType ().getShape ();
2925+ ArrayRef<int64_t > srcShape = srcType.getShape ();
2926+ ArrayRef<int64_t > destShape = destType.getShape ();
2927+ BitVector origBroadcastedDims =
2928+ getBroadcastedDims (shapecastShape, destShape);
2929+ BitVector newBroadcastedDims = getBroadcastedDims (srcShape, destShape);
2930+ if (newBroadcastedDims != origBroadcastedDims)
2931+ return failure ();
2932+
29022933 rewriter.replaceOpWithNewOp <BroadcastOp>(broadcastOp, destType,
29032934 srcShapeCast.getSource ());
29042935 return success ();
0 commit comments