Skip to content

Commit b58b837

Browse files
committed
address review comments
1 parent 98daa18 commit b58b837

File tree

2 files changed

+59
-40
lines changed

2 files changed

+59
-40
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5776,10 +5776,12 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
57765776
}
57775777
};
57785778

5779-
/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
5780-
/// This only applies when the shape of the broadcast source
5781-
/// 1. can be broadcast directly to the final shape, or
5782-
/// 2. has the same element count as the shape cast result.
5779+
/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
5780+
///
5781+
/// 1) Y = ShapeCast(X), or
5782+
/// 2) Y = Broadcast(X)
5783+
///
5784+
/// If both (1) and (2) are possible, (1) is chosen.
57835785
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
57845786
public:
57855787
using OpRewritePattern::OpRewritePattern;
@@ -5791,22 +5793,17 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
57915793
if (!broadcastOp)
57925794
return failure();
57935795

5794-
{
5795-
VectorType dstType = shapeCastOp.getResultVectorType();
5796-
auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
5797-
bool isScalar = !srcType;
5798-
if (isScalar || isBroadcastableTo(srcType, dstType) ==
5799-
BroadcastableToResult::Success) {
5800-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5801-
shapeCastOp, dstType, broadcastOp.getSource());
5802-
return success();
5803-
}
5804-
}
5805-
5806-
// If the final result has the same element count, we can replace with a
5807-
// shape cast.
5808-
if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5809-
if (srcType.getNumElements() ==
5796+
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
5797+
bool srcIsScalar = !srcVectorType;
5798+
5799+
// Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
5800+
// Example:
5801+
// %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
5802+
// %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
5803+
// to
5804+
// %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
5805+
if (srcVectorType) {
5806+
if (srcVectorType.getNumElements() ==
58105807
shapeCastOp.getResultVectorType().getNumElements()) {
58115808
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
58125809
shapeCastOp, shapeCastOp.getResultVectorType(),
@@ -5815,6 +5812,19 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
58155812
}
58165813
}
58175814

5815+
// Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
5816+
// Example
5817+
// %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
5818+
// %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
5819+
// to
5820+
// %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
5821+
VectorType dstVectorType = shapeCastOp.getResultVectorType();
5822+
if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
5823+
BroadcastableToResult::Success) {
5824+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5825+
shapeCastOp, dstVectorType, broadcastOp.getSource());
5826+
return success();
5827+
}
58185828
return failure();
58195829
}
58205830
};

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,38 +1006,47 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast(%arg0: vector<3xf32>) -
10061006

10071007
// -----
10081008

1009-
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapecast
1010-
// CHECK-NOT: vector.broadcast
1011-
// CHECK: vector.shape_cast {{.+}} : vector<3x4xf32> to vector<1x12xf32>
1012-
func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>) -> vector<1x12xf32> {
1013-
%0 = vector.broadcast %arg0 : vector<3x4xf32> to vector<1x1x3x4xf32>
1014-
%1 = vector.shape_cast %0 : vector<1x1x3x4xf32> to vector<1x12xf32>
1015-
return %1 : vector<1x12xf32>
1009+
// CHECK-LABEL: func @canonicalize_broadcast_ones_shapecast_to_broadcast_ones
1010+
// CHECK: vector.broadcast {{.*}} vector<1x1xi8> to vector<1x1x6x1x4xi8>
1011+
// CHECK-NOT: vector.shape_cast
1012+
func.func @canonicalize_broadcast_ones_shapecast_to_broadcast_ones(%arg0: vector<1x1xi8>) -> vector<1x1x6x1x4xi8> {
1013+
%0 = vector.broadcast %arg0 : vector<1x1xi8> to vector<6x4xi8>
1014+
%1 = vector.shape_cast %0 : vector<6x4xi8> to vector<1x1x6x1x4xi8>
1015+
return %1 : vector<1x1x6x1x4xi8>
10161016
}
10171017

10181018
// -----
10191019

1020-
1021-
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar
1022-
// CHECK: vector.broadcast
1023-
// CHECK-SAME: f32 to vector<3x4x1xf32>
1020+
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_broadcast_scalar
1021+
// CHECK: vector.broadcast {{.*}} f32 to vector<3x4x1xf32>
10241022
// CHECK-NOT: vector.shape_cast
1025-
func.func @canonicalize_broadcast_shapecast_scalar(%arg0: f32) -> vector<3x4x1xf32> {
1023+
func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> vector<3x4x1xf32> {
10261024
%0 = vector.broadcast %arg0 : f32 to vector<12xf32>
10271025
%1 = vector.shape_cast %0 : vector<12xf32> to vector<3x4x1xf32>
10281026
return %1 : vector<3x4x1xf32>
10291027
}
10301028

10311029
// -----
10321030

1033-
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_ones
1034-
// CHECK: vector.broadcast
1035-
// CHECK-SAME: vector<1x1xi8> to vector<1x1x6x1x4xi8>
1036-
// CHECK-NOT: vector.shape_cast
1037-
func.func @canonicalize_broadcast_shapecast_ones(%arg0: vector<1x1xi8>) -> vector<1x1x6x1x4xi8> {
1038-
%0 = vector.broadcast %arg0 : vector<1x1xi8> to vector<6x4xi8>
1039-
%1 = vector.shape_cast %0 : vector<6x4xi8> to vector<1x1x6x1x4xi8>
1040-
return %1 : vector<1x1x6x1x4xi8>
1031+
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapecast
1032+
// CHECK-NOT: vector.broadcast
1033+
// CHECK: vector.shape_cast {{.+}} : vector<3x4xf32> to vector<1x12xf32>
1034+
func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>) -> vector<1x12xf32> {
1035+
%0 = vector.broadcast %arg0 : vector<3x4xf32> to vector<1x1x3x4xf32>
1036+
%1 = vector.shape_cast %0 : vector<1x1x3x4xf32> to vector<1x12xf32>
1037+
return %1 : vector<1x12xf32>
1038+
}
1039+
1040+
// -----
1041+
1042+
// In this test, it could be folded to broadcast or shape_cast, shape_cast is chosen.
1043+
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast_priority
1044+
// CHECK-NOT: vector.broadcast
1045+
// CHECK: vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
1046+
func.func @canonicalize_broadcast_shapecast_to_shapcast_priority(%arg0 : vector<1xf32>) -> vector<1x1xf32> {
1047+
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
1048+
%1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
1049+
return %1 : vector<1x1xf32>
10411050
}
10421051

10431052
// -----

0 commit comments

Comments
 (0)