Skip to content

[mlir][vector] Canonicalize broadcast of shape_cast #150523

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Aug 8, 2025
25 changes: 24 additions & 1 deletion mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2938,13 +2938,36 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};

// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
// with broadcast's result type.
struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
PatternRewriter &rewriter) const override {
auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
if (!srcShapeCast)
return failure();

VectorType srcType = srcShapeCast.getSourceVectorType();
VectorType destType = broadcastOp.getResultVectorType();
if (vector::isBroadcastableTo(srcType, destType) !=
BroadcastableToResult::Success)
return failure();

rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
srcShapeCast.getSource());
return success();
}
};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a folder, not a rewrite pattern.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just rewrote it into a folder

} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by
// calling `populateCastAwayVectorLeadingOneDimPatterns`
results.add<BroadcastFolder>(context);
results.add<BroadcastFolder, FoldBroadcastOfShapeCast>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
22 changes: 22 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,28 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)

// -----

// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast
// CHECK-NOT: vector.shape_cast
// CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
%0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
%1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
return %1 : vector<32x2xf32>
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about dim-1 broadcasting?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added more test cases for that.


// -----

// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape
// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32
// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
%0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
%1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
return %1 : vector<2x4x16xf32>
}

// -----

// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
Expand Down