Skip to content
25 changes: 24 additions & 1 deletion mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2938,13 +2938,36 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};

// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
// with broadcast's result type and the shape_cast only adds or removes unit
// dimensions at the front of the shape.
//
// Type compatibility alone is not sufficient. For example, in
//   %0 = vector.shape_cast %x : vector<1x2xf32> to vector<2x1xf32>
//   %1 = vector.broadcast %0 : vector<2x1xf32> to vector<2x2xf32>
// `%x` is itself broadcastable to vector<2x2xf32>, but broadcasting it
// directly would stretch dim 0 instead of dim 1 and produce different
// elements.
struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;

  // Matches a broadcast whose source is produced by a shape_cast and, when the
  // rewrite is value-preserving, replaces it with a broadcast of the
  // shape_cast's own source. Returns failure() without modifying the IR
  // otherwise.
  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
    if (!srcShapeCast)
      return failure();

    // The new broadcast must be a valid op: x has to be broadcastable to the
    // result type on its own.
    VectorType srcType = srcShapeCast.getSourceVectorType();
    VectorType destType = broadcastOp.getResultVectorType();
    if (vector::isBroadcastableTo(srcType, destType) !=
        BroadcastableToResult::Success)
      return failure();

    // The shape_cast may only have changed leading dimensions: the trailing
    // dimensions shared by its source and result shapes must be identical,
    // otherwise the two broadcasts stretch different dimensions.
    ArrayRef<int64_t> srcShape = srcType.getShape();
    ArrayRef<int64_t> castShape =
        srcShapeCast.getResultVectorType().getShape();
    size_t numTrailingDims =
        srcShape.size() < castShape.size() ? srcShape.size() : castShape.size();
    if (srcShape.take_back(numTrailingDims) !=
        castShape.take_back(numTrailingDims))
      return failure();

    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
                                             srcShapeCast.getSource());
    return success();
  }
};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a folder, not a rewrite pattern.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have just rewritten it as a folder.

} // namespace

// Registers the default canonicalization patterns for vector.broadcast:
// BroadcastFolder and FoldBroadcastOfShapeCast. Each pattern is added exactly
// once.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
  // calling `populateCastAwayVectorLeadingOneDimPatterns`
  results.add<BroadcastFolder, FoldBroadcastOfShapeCast>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
22 changes: 22 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,28 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)

// -----

// Positive case: the shape_cast only prepends a unit dimension
// (vector<2xf32> to vector<1x2xf32>), so the broadcast is expected to consume
// the original vector<2xf32> value directly and the shape_cast to disappear.
// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast
// CHECK-NOT: vector.shape_cast
// CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
%0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
%1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
return %1 : vector<32x2xf32>
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about dim-1 broadcasting?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added more test cases for that.


// -----

// Negative case: vector<64xf32> is not broadcastable to vector<2x4x16xf32>
// (trailing dimension 64 vs. 16), so both the shape_cast and the broadcast
// are expected to remain unchanged.
// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape
// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32>
// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
  %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
  %1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
  return %1 : vector<2x4x16xf32>
}

// -----

// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
Expand Down