Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 29 additions & 26 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5776,11 +5776,10 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
}
};

/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
/// This only applies when the shape of the broadcast source
/// 1. is a suffix of the shape of the result (i.e. when broadcast without
/// reshape is expressive enough to capture the result in a single op), or
/// 2. has the same element count as the shape cast result.
/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
/// i) Y = ShapeCast(X), or
/// ii) Y = Broadcast(X)
/// If both (i) and (ii) are possible, (i) is chosen.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
Expand All @@ -5792,26 +5791,17 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
if (!broadcastOp)
return failure();

ArrayRef<int64_t> broadcastSourceShape;
if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
broadcastSourceShape = srcType.getShape();
ArrayRef<int64_t> shapeCastTargetShape =
shapeCastOp.getResultVectorType().getShape();

// If `broadcastSourceShape` is a suffix of the result, we can just replace
// with a broadcast to the final shape.
if (broadcastSourceShape ==
shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
shapeCastOp, shapeCastOp.getResultVectorType(),
broadcastOp.getSource());
return success();
}

// Otherwise, if the final result has the same element count, we can replace
// with a shape cast.
if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
if (srcType.getNumElements() ==
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
bool srcIsScalar = !srcVectorType;

// Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
// Example:
// %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
// %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
// to
// %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
if (srcVectorType) {
if (srcVectorType.getNumElements() ==
shapeCastOp.getResultVectorType().getNumElements()) {
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
shapeCastOp, shapeCastOp.getResultVectorType(),
Comment on lines +5794 to 5807
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's OK to land for now, but this should be a folder, not a canonicalization pattern.

Expand All @@ -5820,6 +5810,19 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
}
}

// Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X).
// Example:
// %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
// %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
// to
// %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
VectorType dstVectorType = shapeCastOp.getResultVectorType();
if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
BroadcastableToResult::Success) {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
shapeCastOp, dstVectorType, broadcastOp.getSource());
return success();
}
return failure();
}
};
Expand Down Expand Up @@ -6079,7 +6082,7 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
}
};

// Folds transpose(broadcast(<scalar>)) into brodcast(<scalar>).
// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
struct FoldTransposedScalarBroadcast final
: public OpRewritePattern<vector::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
Expand Down
47 changes: 41 additions & 6 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1006,13 +1006,48 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast(%arg0: vector<3xf32>) -

// -----

// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapecast
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_broadcast_ones
// CHECK: vector.broadcast {{.*}} vector<1x1xi8> to vector<1x1x6x1x4xi8>
// CHECK-NOT: vector.shape_cast
// Test input: a vector<1x1xi8> is broadcast to vector<6x4xi8> and then
// shape_cast to vector<1x1x6x1x4xi8>. The source holds 1 element while the
// final result holds 24, so the pair cannot be replaced by a lone shape_cast;
// the expected output is a single broadcast from the source to the final type.
func.func @canonicalize_broadcast_shapecast_to_broadcast_ones(%arg0: vector<1x1xi8>) -> vector<1x1x6x1x4xi8> {
// Intermediate broadcast that is expected to be folded away.
%0 = vector.broadcast %arg0 : vector<1x1xi8> to vector<6x4xi8>
// Shape cast that only inserts unit dimensions around the 6 and 4 dims.
%1 = vector.shape_cast %0 : vector<6x4xi8> to vector<1x1x6x1x4xi8>
return %1 : vector<1x1x6x1x4xi8>
}

// -----

// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_broadcast_scalar
// CHECK: vector.broadcast {{.*}} f32 to vector<3x4x1xf32>
// CHECK-NOT: vector.shape_cast
// Test input: a scalar f32 is broadcast to vector<12xf32> and then shape_cast
// to vector<3x4x1xf32>. The source is not a vector, so the expected output is
// a single broadcast from the scalar directly to the final type.
func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> vector<3x4x1xf32> {
// Intermediate scalar-to-vector broadcast that is expected to be folded away.
%0 = vector.broadcast %arg0 : f32 to vector<12xf32>
// Reshapes the 12 elements to 3x4x1.
%1 = vector.shape_cast %0 : vector<12xf32> to vector<3x4x1xf32>
return %1 : vector<3x4x1xf32>
}

// -----

// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast
// CHECK-NOT: vector.broadcast
// CHECK: vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
// Test input: a vector<2xf32> is broadcast to vector<1x2xf32> and then
// shape_cast to vector<1x2x1xf32>. Source and final result both hold 2
// elements, so the expected output is a single shape_cast from the source.
// NOTE(review): "shapcast" in the function name looks like a typo for
// "shapecast"; renaming it would also require updating the matching label
// check directly above this function.
func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
// Intermediate broadcast that only prepends a unit dimension.
%0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
// Appends a trailing unit dimension.
%1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
return %1 : vector<1x2x1xf32>
}

// -----

// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
// CHECK-NOT: vector.broadcast
// CHECK: vector.shape_cast {{.+}} : vector<3x4xf32> to vector<1x12xf32>
func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>) -> vector<1x12xf32> {
%0 = vector.broadcast %arg0 : vector<3x4xf32> to vector<1x1x3x4xf32>
%1 = vector.shape_cast %0 : vector<1x1x3x4xf32> to vector<1x12xf32>
return %1 : vector<1x12xf32>
// CHECK: vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
%1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
return %1 : vector<1x1xf32>
}

// -----
Expand Down