@@ -2841,9 +2841,59 @@ LogicalResult BroadcastOp::verify() {
28412841 llvm_unreachable (" unexpected vector.broadcast op error" );
28422842}
28432843
2844+ // Return the broadcasted dimensions. Including broadcasts in the leading
2845+ // dimensions and broadcasts through unit dimension (i.e. dim-1).
2846+ static BitVector getBroadcastedDims (ArrayRef<int64_t > srcShape,
2847+ ArrayRef<int64_t > destShape) {
2848+ assert (destShape.size () >= srcShape.size ());
2849+ BitVector broadcastedDims (destShape.size ());
2850+ broadcastedDims.set (0 , destShape.size () - srcShape.size ());
2851+ auto unitDims = computeBroadcastedUnitDims (srcShape, destShape);
2852+ for (int64_t dim : unitDims)
2853+ broadcastedDims.set (dim);
2854+ return broadcastedDims;
2855+ }
2856+
2857+ // Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
2858+ // with broadcast's result type and the broadcasted dimensions are the same.
2859+ static LogicalResult foldBroadcastOfShapeCast (BroadcastOp broadcastOp) {
2860+ auto srcShapeCast = broadcastOp.getSource ().getDefiningOp <ShapeCastOp>();
2861+ if (!srcShapeCast)
2862+ return failure ();
2863+
2864+ VectorType srcType = srcShapeCast.getSourceVectorType ();
2865+ VectorType destType = broadcastOp.getResultVectorType ();
2866+ // Check type compatibility.
2867+ if (vector::isBroadcastableTo (srcType, destType) !=
2868+ BroadcastableToResult::Success)
2869+ return failure ();
2870+
2871+ // Given
2872+ // ```
2873+ // %s = shape_cast(%x)
2874+ // %b = broadcast(%s)
2875+ // ```
2876+ // If we want to fold %x into %b, the broadcasted dimensions from %x to
2877+ // %b has to be the same as that of from %s to %b.
2878+ ArrayRef<int64_t > shapecastShape =
2879+ srcShapeCast.getResultVectorType ().getShape ();
2880+ ArrayRef<int64_t > srcShape = srcType.getShape ();
2881+ ArrayRef<int64_t > destShape = destType.getShape ();
2882+ BitVector origBroadcastedDims = getBroadcastedDims (shapecastShape, destShape);
2883+ BitVector newBroadcastedDims = getBroadcastedDims (srcShape, destShape);
2884+ if (newBroadcastedDims != origBroadcastedDims)
2885+ return failure ();
2886+
2887+ broadcastOp.getSourceMutable ().assign (srcShapeCast.getSource ());
2888+ return success ();
2889+ }
2890+
28442891OpFoldResult BroadcastOp::fold (FoldAdaptor adaptor) {
28452892 if (getSourceType () == getResultVectorType ())
28462893 return getSource ();
2894+ if (succeeded (foldBroadcastOfShapeCast (*this )))
2895+ return getResult ();
2896+
28472897 if (!adaptor.getSource ())
28482898 return {};
28492899 auto vectorType = getResultVectorType ();
@@ -2881,67 +2931,13 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
28812931 return success ();
28822932 }
28832933};
2884-
2885- // Return the broadcasted dimensions. Including broadcasts in the leading
2886- // dimensions and broadcasts through unit dimension (i.e. dim-1).
2887- static BitVector getBroadcastedDims (ArrayRef<int64_t > srcShape,
2888- ArrayRef<int64_t > destShape) {
2889- assert (destShape.size () >= srcShape.size ());
2890- BitVector broadcastedDims (destShape.size ());
2891- broadcastedDims.set (0 , destShape.size () - srcShape.size ());
2892- auto unitDims = computeBroadcastedUnitDims (srcShape, destShape);
2893- for (int64_t dim : unitDims)
2894- broadcastedDims.set (dim);
2895- return broadcastedDims;
2896- }
2897-
2898- // Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
2899- // with broadcast's result type and the broadcasted dimensions are the same.
2900- struct FoldBroadcastOfShapeCast : public OpRewritePattern <BroadcastOp> {
2901- using OpRewritePattern::OpRewritePattern;
2902-
2903- LogicalResult matchAndRewrite (BroadcastOp broadcastOp,
2904- PatternRewriter &rewriter) const override {
2905- auto srcShapeCast = broadcastOp.getSource ().getDefiningOp <ShapeCastOp>();
2906- if (!srcShapeCast)
2907- return failure ();
2908-
2909- VectorType srcType = srcShapeCast.getSourceVectorType ();
2910- VectorType destType = broadcastOp.getResultVectorType ();
2911- // Check type compatibility.
2912- if (vector::isBroadcastableTo (srcType, destType) !=
2913- BroadcastableToResult::Success)
2914- return failure ();
2915-
2916- // Given
2917- // ```
2918- // %s = shape_cast(%x)
2919- // %b = broadcast(%s)
2920- // ```
2921- // If we want to fold %x into %b, the broadcasted dimensions from %x to
2922- // %b has to be the same as that of from %s to %b.
2923- ArrayRef<int64_t > shapecastShape =
2924- srcShapeCast.getResultVectorType ().getShape ();
2925- ArrayRef<int64_t > srcShape = srcType.getShape ();
2926- ArrayRef<int64_t > destShape = destType.getShape ();
2927- BitVector origBroadcastedDims =
2928- getBroadcastedDims (shapecastShape, destShape);
2929- BitVector newBroadcastedDims = getBroadcastedDims (srcShape, destShape);
2930- if (newBroadcastedDims != origBroadcastedDims)
2931- return failure ();
2932-
2933- rewriter.replaceOpWithNewOp <BroadcastOp>(broadcastOp, destType,
2934- srcShapeCast.getSource ());
2935- return success ();
2936- }
2937- };
29382934} // namespace
29392935
29402936void BroadcastOp::getCanonicalizationPatterns (RewritePatternSet &results,
29412937 MLIRContext *context) {
29422938 // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
29432939 // calling `populateCastAwayVectorLeadingOneDimPatterns`
2944- results.add <BroadcastFolder, FoldBroadcastOfShapeCast >(context);
2940+ results.add <BroadcastFolder>(context);
29452941}
29462942
29472943// ===----------------------------------------------------------------------===//
0 commit comments