@@ -2882,8 +2882,21 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2882
2882
}
2883
2883
};
2884
2884
2885
+ // Return the broadcasted dimensions. Including broadcasts in the leading
2886
+ // dimensions and broadcasts through unit dimension (i.e. dim-1).
2887
+ static BitVector getBroadcastedDims (ArrayRef<int64_t > srcShape,
2888
+ ArrayRef<int64_t > destShape) {
2889
+ assert (destShape.size () >= srcShape.size ());
2890
+ BitVector broadcastedDims (destShape.size ());
2891
+ broadcastedDims.set (0 , destShape.size () - srcShape.size ());
2892
+ auto unitDims = computeBroadcastedUnitDims (srcShape, destShape);
2893
+ for (int64_t dim : unitDims)
2894
+ broadcastedDims.set (dim);
2895
+ return broadcastedDims;
2896
+ }
2897
+
2885
2898
// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
2886
- // with broadcast's result type.
2899
+ // with broadcast's result type and the broadcasted dimensions are the same.
2887
2900
struct FoldBroadcastOfShapeCast : public OpRewritePattern <BroadcastOp> {
2888
2901
using OpRewritePattern::OpRewritePattern;
2889
2902
@@ -2895,10 +2908,28 @@ struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
2895
2908
2896
2909
VectorType srcType = srcShapeCast.getSourceVectorType ();
2897
2910
VectorType destType = broadcastOp.getResultVectorType ();
2911
+ // Check type compatibility.
2898
2912
if (vector::isBroadcastableTo (srcType, destType) !=
2899
2913
BroadcastableToResult::Success)
2900
2914
return failure ();
2901
2915
2916
+ // Given
2917
+ // ```
2918
+ // %s = shape_cast(%x)
2919
+ // %b = broadcast(%s)
2920
+ // ```
2921
+ // If we want to fold %x into %b, the broadcasted dimensions from %x to
2922
+ // %b have to be the same as those from %s to %b.
2923
+ ArrayRef<int64_t > shapecastShape =
2924
+ srcShapeCast.getResultVectorType ().getShape ();
2925
+ ArrayRef<int64_t > srcShape = srcType.getShape ();
2926
+ ArrayRef<int64_t > destShape = destType.getShape ();
2927
+ BitVector origBroadcastedDims =
2928
+ getBroadcastedDims (shapecastShape, destShape);
2929
+ BitVector newBroadcastedDims = getBroadcastedDims (srcShape, destShape);
2930
+ if (newBroadcastedDims != origBroadcastedDims)
2931
+ return failure ();
2932
+
2902
2933
rewriter.replaceOpWithNewOp <BroadcastOp>(broadcastOp, destType,
2903
2934
srcShapeCast.getSource ());
2904
2935
return success ();
0 commit comments