Skip to content

Commit 409def2

Browse files
authored
[mlir][vector] shape_cast(broadcast) -> broadcast canonicalization (llvm#134939)
Add additional cases of this canonicalization by using the 'source of truth' function `isBroadcastableTo` to determine when it is possible to broadcast directly to the shape resulting from the shape_cast. --------- Signed-off-by: James Newling <[email protected]>
1 parent f8f01b5 commit 409def2

File tree

2 files changed

+70
-32
lines changed

2 files changed

+70
-32
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5768,11 +5768,10 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
57685768
}
57695769
};
57705770

5771-
/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
5772-
/// This only applies when the shape of the broadcast source
5773-
/// 1. is a suffix of the shape of the result (i.e. when broadcast without
5774-
/// reshape is expressive enough to capture the result in a single op), or
5775-
/// 2. has the same element count as the shape cast result.
5771+
/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
5772+
/// i) Y = ShapeCast(X), or
5773+
/// ii) Y = Broadcast(X)
5774+
/// If both (i) and (ii) are possible, (i) is chosen.
57765775
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
57775776
public:
57785777
using OpRewritePattern::OpRewritePattern;
@@ -5784,26 +5783,17 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
57845783
if (!broadcastOp)
57855784
return failure();
57865785

5787-
ArrayRef<int64_t> broadcastSourceShape;
5788-
if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5789-
broadcastSourceShape = srcType.getShape();
5790-
ArrayRef<int64_t> shapeCastTargetShape =
5791-
shapeCastOp.getResultVectorType().getShape();
5792-
5793-
// If `broadcastSourceShape` is a suffix of the result, we can just replace
5794-
// with a broadcast to the final shape.
5795-
if (broadcastSourceShape ==
5796-
shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5797-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5798-
shapeCastOp, shapeCastOp.getResultVectorType(),
5799-
broadcastOp.getSource());
5800-
return success();
5801-
}
5802-
5803-
// Otherwise, if the final result has the same element count, we can replace
5804-
// with a shape cast.
5805-
if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5806-
if (srcType.getNumElements() ==
5786+
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
5787+
bool srcIsScalar = !srcVectorType;
5788+
5789+
// Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
5790+
// Example:
5791+
// %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
5792+
// %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
5793+
// to
5794+
// %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
5795+
if (srcVectorType) {
5796+
if (srcVectorType.getNumElements() ==
58075797
shapeCastOp.getResultVectorType().getNumElements()) {
58085798
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
58095799
shapeCastOp, shapeCastOp.getResultVectorType(),
@@ -5812,6 +5802,19 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
58125802
}
58135803
}
58145804

5805+
// Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X).
5806+
// Example:
5807+
// %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
5808+
// %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
5809+
// to
5810+
// %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
5811+
VectorType dstVectorType = shapeCastOp.getResultVectorType();
5812+
if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
5813+
BroadcastableToResult::Success) {
5814+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5815+
shapeCastOp, dstVectorType, broadcastOp.getSource());
5816+
return success();
5817+
}
58155818
return failure();
58165819
}
58175820
};
@@ -6072,7 +6075,7 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
60726075
}
60736076
};
60746077

6075-
// Folds transpose(broadcast(<scalar>)) into brodcast(<scalar>).
6078+
// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
60766079
struct FoldTransposedScalarBroadcast final
60776080
: public OpRewritePattern<vector::TransposeOp> {
60786081
using OpRewritePattern::OpRewritePattern;

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,13 +1006,48 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast(%arg0: vector<3xf32>) -
10061006

10071007
// -----
10081008

1009-
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapecast
1009+
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_broadcast_ones
1010+
// CHECK: vector.broadcast {{.*}} vector<1x1xi8> to vector<1x1x6x1x4xi8>
1011+
// CHECK-NOT: vector.shape_cast
1012+
func.func @canonicalize_broadcast_shapecast_to_broadcast_ones(%arg0: vector<1x1xi8>) -> vector<1x1x6x1x4xi8> {
1013+
%0 = vector.broadcast %arg0 : vector<1x1xi8> to vector<6x4xi8>
1014+
%1 = vector.shape_cast %0 : vector<6x4xi8> to vector<1x1x6x1x4xi8>
1015+
return %1 : vector<1x1x6x1x4xi8>
1016+
}
1017+
1018+
// -----
1019+
1020+
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_broadcast_scalar
1021+
// CHECK: vector.broadcast {{.*}} f32 to vector<3x4x1xf32>
1022+
// CHECK-NOT: vector.shape_cast
1023+
func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> vector<3x4x1xf32> {
1024+
%0 = vector.broadcast %arg0 : f32 to vector<12xf32>
1025+
%1 = vector.shape_cast %0 : vector<12xf32> to vector<3x4x1xf32>
1026+
return %1 : vector<3x4x1xf32>
1027+
}
1028+
1029+
// -----
1030+
1031+
// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
1032+
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast
1033+
// CHECK-NOT: vector.broadcast
1034+
// CHECK: vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
1035+
func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
1036+
%0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
1037+
%1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
1038+
return %1 : vector<1x2x1xf32>
1039+
}
1040+
1041+
// -----
1042+
1043+
// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
1044+
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
10101045
// CHECK-NOT: vector.broadcast
1011-
// CHECK: vector.shape_cast {{.+}} : vector<3x4xf32> to vector<1x12xf32>
1012-
func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>) -> vector<1x12xf32> {
1013-
%0 = vector.broadcast %arg0 : vector<3x4xf32> to vector<1x1x3x4xf32>
1014-
%1 = vector.shape_cast %0 : vector<1x1x3x4xf32> to vector<1x12xf32>
1015-
return %1 : vector<1x12xf32>
1046+
// CHECK: vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
1047+
func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
1048+
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
1049+
%1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
1050+
return %1 : vector<1x1xf32>
10161051
}
10171052

10181053
// -----

0 commit comments

Comments
 (0)