Skip to content

Commit d5d59c2

Browse files
committed
cover additional cases of shape_cast(broadcast) -> broadcast canonicalization
Signed-off-by: James Newling <[email protected]>
1 parent 1d0f835 commit d5d59c2

File tree

2 files changed

+39
-19
lines changed

2 files changed

+39
-19
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5778,8 +5778,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
57785778

57795779
/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
57805780
/// This only applies when the shape of the broadcast source
5781-
/// 1. is a suffix of the shape of the result (i.e. when broadcast without
5782-
/// reshape is expressive enough to capture the result in a single op), or
5781+
/// 1. can be broadcast directly to the final shape, or
57835782
/// 2. has the same element count as the shape cast result.
57845783
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
57855784
public:
@@ -5792,24 +5791,20 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
57925791
if (!broadcastOp)
57935792
return failure();
57945793

5795-
ArrayRef<int64_t> broadcastSourceShape;
5796-
if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5797-
broadcastSourceShape = srcType.getShape();
5798-
ArrayRef<int64_t> shapeCastTargetShape =
5799-
shapeCastOp.getResultVectorType().getShape();
5800-
5801-
// If `broadcastSourceShape` is a suffix of the result, we can just replace
5802-
// with a broadcast to the final shape.
5803-
if (broadcastSourceShape ==
5804-
shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5805-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5806-
shapeCastOp, shapeCastOp.getResultVectorType(),
5807-
broadcastOp.getSource());
5808-
return success();
5794+
{
5795+
VectorType dstType = shapeCastOp.getResultVectorType();
5796+
auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
5797+
bool isScalar = !srcType;
5798+
if (isScalar || isBroadcastableTo(srcType, dstType) ==
5799+
BroadcastableToResult::Success) {
5800+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5801+
shapeCastOp, dstType, broadcastOp.getSource());
5802+
return success();
5803+
}
58095804
}
58105805

5811-
// Otherwise, if the final result has the same element count, we can replace
5812-
// with a shape cast.
5806+
// If the final result has the same element count, we can replace with a
5807+
// shape cast.
58135808
if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
58145809
if (srcType.getNumElements() ==
58155810
shapeCastOp.getResultVectorType().getNumElements()) {
@@ -6079,7 +6074,7 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
60796074
}
60806075
};
60816076

6082-
// Folds transpose(broadcast(<scalar>)) into brodcast(<scalar>).
6077+
// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
60836078
struct FoldTransposedScalarBroadcast final
60846079
: public OpRewritePattern<vector::TransposeOp> {
60856080
using OpRewritePattern::OpRewritePattern;

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,31 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
10171017

10181018
// -----
10191019

1020+
1021+
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar
1022+
// CHECK: vector.broadcast
1023+
// CHECK-SAME: f32 to vector<3x4x1xf32>
1024+
// CHECK-NOT: vector.shape_cast
1025+
func.func @canonicalize_broadcast_shapecast_scalar(%arg0: f32) -> vector<3x4x1xf32> {
1026+
%0 = vector.broadcast %arg0 : f32 to vector<12xf32>
1027+
%1 = vector.shape_cast %0 : vector<12xf32> to vector<3x4x1xf32>
1028+
return %1 : vector<3x4x1xf32>
1029+
}
1030+
1031+
// -----
1032+
1033+
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_ones
1034+
// CHECK: vector.broadcast
1035+
// CHECK-SAME: vector<1x1xi8> to vector<1x1x6x1x4xi8>
1036+
// CHECK-NOT: vector.shape_cast
1037+
func.func @canonicalize_broadcast_shapecast_ones(%arg0: vector<1x1xi8>) -> vector<1x1x6x1x4xi8> {
1038+
%0 = vector.broadcast %arg0 : vector<1x1xi8> to vector<6x4xi8>
1039+
%1 = vector.shape_cast %0 : vector<6x4xi8> to vector<1x1x6x1x4xi8>
1040+
return %1 : vector<1x1x6x1x4xi8>
1041+
}
1042+
1043+
// -----
1044+
10201045
// CHECK-LABEL: fold_vector_transfer_masks
10211046
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
10221047
// CHECK: %[[C0:.+]] = arith.constant 0 : index

0 commit comments

Comments
 (0)