-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][vector] shape_cast(broadcast) -> broadcast canonicalization #134939
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] shape_cast(broadcast) -> broadcast canonicalization #134939
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: James Newling (newling) Changes: Add additional cases of this canonicalization, by checking the 'source of truth' function `isBroadcastableTo`. Full diff: https://github.com/llvm/llvm-project/pull/134939.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 98d98f067de14..c6d8ec1e1cf69 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5778,8 +5778,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
/// This only applies when the shape of the broadcast source
-/// 1. is a suffix of the shape of the result (i.e. when broadcast without
-/// reshape is expressive enough to capture the result in a single op), or
+/// 1. can be broadcast directly to the final shape, or
/// 2. has the same element count as the shape cast result.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
@@ -5792,24 +5791,20 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
if (!broadcastOp)
return failure();
- ArrayRef<int64_t> broadcastSourceShape;
- if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
- broadcastSourceShape = srcType.getShape();
- ArrayRef<int64_t> shapeCastTargetShape =
- shapeCastOp.getResultVectorType().getShape();
-
- // If `broadcastSourceShape` is a suffix of the result, we can just replace
- // with a broadcast to the final shape.
- if (broadcastSourceShape ==
- shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- shapeCastOp, shapeCastOp.getResultVectorType(),
- broadcastOp.getSource());
- return success();
+ {
+ VectorType dstType = shapeCastOp.getResultVectorType();
+ auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+ bool isScalar = !srcType;
+ if (isScalar || isBroadcastableTo(srcType, dstType) ==
+ BroadcastableToResult::Success) {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ shapeCastOp, dstType, broadcastOp.getSource());
+ return success();
+ }
}
- // Otherwise, if the final result has the same element count, we can replace
- // with a shape cast.
+ // If the final result has the same element count, we can replace with a
+ // shape cast.
if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
if (srcType.getNumElements() ==
shapeCastOp.getResultVectorType().getNumElements()) {
@@ -6079,7 +6074,7 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
}
};
-// Folds transpose(broadcast(<scalar>)) into brodcast(<scalar>).
+// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
struct FoldTransposedScalarBroadcast final
: public OpRewritePattern<vector::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..d7617d79b5cbf 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1017,6 +1017,31 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
// -----
+
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar
+// CHECK: vector.broadcast
+// CHECK-SAME: f32 to vector<3x4x1xf32>
+// CHECK-NOT: vector.shape_cast
+func.func @canonicalize_broadcast_shapecast_scalar(%arg0: f32) -> vector<3x4x1xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<12xf32>
+ %1 = vector.shape_cast %0 : vector<12xf32> to vector<3x4x1xf32>
+ return %1 : vector<3x4x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_ones
+// CHECK: vector.broadcast
+// CHECK-SAME: vector<1x1xi8> to vector<1x1x6x1x4xi8>
+// CHECK-NOT: vector.shape_cast
+func.func @canonicalize_broadcast_shapecast_ones(%arg0: vector<1x1xi8>) -> vector<1x1x6x1x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<1x1xi8> to vector<6x4xi8>
+ %1 = vector.shape_cast %0 : vector<6x4xi8> to vector<1x1x6x1x4xi8>
+ return %1 : vector<1x1x6x1x4xi8>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
|
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
dcaballe
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
d5d59c2 to
b58b837
Compare
|
Thanks for your suggested improvements @dcaballe and @banach-space. I've hopefully addressed them all |
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great clean-up, thank you!
I've left one [nit], but that's non-blocking so approving as is. LGTM
…lization Signed-off-by: James Newling <[email protected]>
Signed-off-by: James Newling <[email protected]>
151bbfd to
408842c
Compare
dcaballe
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool, thanks!
| auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType()); | ||
| bool srcIsScalar = !srcVectorType; | ||
|
|
||
| // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X). | ||
| // Example: | ||
| // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32> | ||
| // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32> | ||
| // to | ||
| // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32> | ||
| if (srcVectorType) { | ||
| if (srcVectorType.getNumElements() == | ||
| shapeCastOp.getResultVectorType().getNumElements()) { | ||
| rewriter.replaceOpWithNewOp<vector::ShapeCastOp>( | ||
| shapeCastOp, shapeCastOp.getResultVectorType(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's ok to land for now, but this should be a folder not a canonicalization pattern.
Add additional cases of this canonicalization by using the 'source of truth' function
`isBroadcastableTo` to check when it is possible to broadcast directly to the shape resulting from the shape_cast.