@@ -2841,9 +2841,59 @@ LogicalResult BroadcastOp::verify() {
2841
2841
llvm_unreachable (" unexpected vector.broadcast op error" );
2842
2842
}
2843
2843
2844
+ // Return the broadcasted dimensions. Including broadcasts in the leading
2845
+ // dimensions and broadcasts through unit dimension (i.e. dim-1).
2846
+ static BitVector getBroadcastedDims (ArrayRef<int64_t > srcShape,
2847
+ ArrayRef<int64_t > destShape) {
2848
+ assert (destShape.size () >= srcShape.size ());
2849
+ BitVector broadcastedDims (destShape.size ());
2850
+ broadcastedDims.set (0 , destShape.size () - srcShape.size ());
2851
+ auto unitDims = computeBroadcastedUnitDims (srcShape, destShape);
2852
+ for (int64_t dim : unitDims)
2853
+ broadcastedDims.set (dim);
2854
+ return broadcastedDims;
2855
+ }
2856
+
2857
+ // Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
2858
+ // with broadcast's result type and the broadcasted dimensions are the same.
2859
+ static LogicalResult foldBroadcastOfShapeCast (BroadcastOp broadcastOp) {
2860
+ auto srcShapeCast = broadcastOp.getSource ().getDefiningOp <ShapeCastOp>();
2861
+ if (!srcShapeCast)
2862
+ return failure ();
2863
+
2864
+ VectorType srcType = srcShapeCast.getSourceVectorType ();
2865
+ VectorType destType = broadcastOp.getResultVectorType ();
2866
+ // Check type compatibility.
2867
+ if (vector::isBroadcastableTo (srcType, destType) !=
2868
+ BroadcastableToResult::Success)
2869
+ return failure ();
2870
+
2871
+ // Given
2872
+ // ```
2873
+ // %s = shape_cast(%x)
2874
+ // %b = broadcast(%s)
2875
+ // ```
2876
+ // If we want to fold %x into %b, the broadcasted dimensions from %x to
2877
+ // %b has to be the same as that of from %s to %b.
2878
+ ArrayRef<int64_t > shapecastShape =
2879
+ srcShapeCast.getResultVectorType ().getShape ();
2880
+ ArrayRef<int64_t > srcShape = srcType.getShape ();
2881
+ ArrayRef<int64_t > destShape = destType.getShape ();
2882
+ BitVector origBroadcastedDims = getBroadcastedDims (shapecastShape, destShape);
2883
+ BitVector newBroadcastedDims = getBroadcastedDims (srcShape, destShape);
2884
+ if (newBroadcastedDims != origBroadcastedDims)
2885
+ return failure ();
2886
+
2887
+ broadcastOp.getSourceMutable ().assign (srcShapeCast.getSource ());
2888
+ return success ();
2889
+ }
2890
+
2844
2891
OpFoldResult BroadcastOp::fold (FoldAdaptor adaptor) {
2845
2892
if (getSourceType () == getResultVectorType ())
2846
2893
return getSource ();
2894
+ if (succeeded (foldBroadcastOfShapeCast (*this )))
2895
+ return getResult ();
2896
+
2847
2897
if (!adaptor.getSource ())
2848
2898
return {};
2849
2899
auto vectorType = getResultVectorType ();
@@ -2881,67 +2931,13 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2881
2931
return success ();
2882
2932
}
2883
2933
};
2884
-
2885
- // Return the broadcasted dimensions. Including broadcasts in the leading
2886
- // dimensions and broadcasts through unit dimension (i.e. dim-1).
2887
- static BitVector getBroadcastedDims (ArrayRef<int64_t > srcShape,
2888
- ArrayRef<int64_t > destShape) {
2889
- assert (destShape.size () >= srcShape.size ());
2890
- BitVector broadcastedDims (destShape.size ());
2891
- broadcastedDims.set (0 , destShape.size () - srcShape.size ());
2892
- auto unitDims = computeBroadcastedUnitDims (srcShape, destShape);
2893
- for (int64_t dim : unitDims)
2894
- broadcastedDims.set (dim);
2895
- return broadcastedDims;
2896
- }
2897
-
2898
- // Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
2899
- // with broadcast's result type and the broadcasted dimensions are the same.
2900
- struct FoldBroadcastOfShapeCast : public OpRewritePattern <BroadcastOp> {
2901
- using OpRewritePattern::OpRewritePattern;
2902
-
2903
- LogicalResult matchAndRewrite (BroadcastOp broadcastOp,
2904
- PatternRewriter &rewriter) const override {
2905
- auto srcShapeCast = broadcastOp.getSource ().getDefiningOp <ShapeCastOp>();
2906
- if (!srcShapeCast)
2907
- return failure ();
2908
-
2909
- VectorType srcType = srcShapeCast.getSourceVectorType ();
2910
- VectorType destType = broadcastOp.getResultVectorType ();
2911
- // Check type compatibility.
2912
- if (vector::isBroadcastableTo (srcType, destType) !=
2913
- BroadcastableToResult::Success)
2914
- return failure ();
2915
-
2916
- // Given
2917
- // ```
2918
- // %s = shape_cast(%x)
2919
- // %b = broadcast(%s)
2920
- // ```
2921
- // If we want to fold %x into %b, the broadcasted dimensions from %x to
2922
- // %b has to be the same as that of from %s to %b.
2923
- ArrayRef<int64_t > shapecastShape =
2924
- srcShapeCast.getResultVectorType ().getShape ();
2925
- ArrayRef<int64_t > srcShape = srcType.getShape ();
2926
- ArrayRef<int64_t > destShape = destType.getShape ();
2927
- BitVector origBroadcastedDims =
2928
- getBroadcastedDims (shapecastShape, destShape);
2929
- BitVector newBroadcastedDims = getBroadcastedDims (srcShape, destShape);
2930
- if (newBroadcastedDims != origBroadcastedDims)
2931
- return failure ();
2932
-
2933
- rewriter.replaceOpWithNewOp <BroadcastOp>(broadcastOp, destType,
2934
- srcShapeCast.getSource ());
2935
- return success ();
2936
- }
2937
- };
2938
2934
} // namespace
2939
2935
2940
2936
void BroadcastOp::getCanonicalizationPatterns (RewritePatternSet &results,
2941
2937
MLIRContext *context) {
2942
2938
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by
2943
2939
// calling `populateCastAwayVectorLeadingOneDimPatterns`
2944
- results.add <BroadcastFolder, FoldBroadcastOfShapeCast >(context);
2940
+ results.add <BroadcastFolder>(context);
2945
2941
}
2946
2942
2947
2943
// ===----------------------------------------------------------------------===//
0 commit comments