[mlir][vector] Canonicalize broadcast of shape_cast #150523
Conversation
Fold `broadcast(shape_cast(x))` into `broadcast(x)` if the type of x is compatible with broadcast's result type.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Min-Yih Hsu (mshockwave)

Changes: Fold `broadcast(shape_cast(x))` into `broadcast(x)` if the type of x is compatible with broadcast's result type.

Full diff: https://github.com/llvm/llvm-project/pull/150523.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8c97aed6e7742..ad908319d8584 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2938,13 +2938,35 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
+
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type.
+struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ if (auto srcShapeCast =
+ broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
+ VectorType srcType = srcShapeCast.getSourceVectorType();
+ VectorType destType = broadcastOp.getResultVectorType();
+ if (vector::isBroadcastableTo(srcType, destType) ==
+ BroadcastableToResult::Success) {
+ rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
+ srcShapeCast.getSource());
+ return success();
+ }
+ }
+ return failure();
+ }
+};
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by
// calling `populateCastAwayVectorLeadingOneDimPatterns`
- results.add<BroadcastFolder>(context);
+ results.add<BroadcastFolder, FoldBroadcastOfShapeCast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1461c30162c5f..0fd2acd06c8ec 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1168,6 +1168,28 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
// -----
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast
+// CHECK-NOT: vector.shape_cast
+// CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
+func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
+ %1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
+ return %1 : vector<32x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_invalid_shape
+// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32
+// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
+func.func @canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
+ %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
+ %1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
+ return %1 : vector<2x4x16xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
Thanks, makes sense modulo minor suggestions.
    if (auto srcShapeCast =
            broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
      VectorType srcType = srcShapeCast.getSourceVectorType();
      VectorType destType = broadcastOp.getResultVectorType();
      if (vector::isBroadcastableTo(srcType, destType) ==
          BroadcastableToResult::Success) {
        rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
                                                 srcShapeCast.getSource());
        return success();
      }
    }
    return failure();
[nit] Prefer early exits - helps reduce indentation.
Suggested change:

    auto srcShapeCast =
        broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
    if (!srcShapeCast)
      return failure();
    VectorType srcType = srcShapeCast.getSourceVectorType();
    VectorType destType = broadcastOp.getResultVectorType();
    if (vector::isBroadcastableTo(srcType, destType) !=
        BroadcastableToResult::Success)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
                                             srcShapeCast.getSource());
    return success();
Fixed.
Co-authored-by: Andrzej Warzyński <[email protected]>
Can you reuse `isBroadcastLike`?

General rule is that if something can be a folder, it should be (i.e. on BroadcastOp::fold): https://mlir.llvm.org/docs/Canonicalization/#when-to-use-the-fold-method-vs-rewriterpatterns-for-canonicalizations
I don't think this is always valid?
(2,1) -> shape_cast -> (1,2) -> broadcast (2,2)
and
(2,1) -> broadcast (2,2)
are different.
For example, if the input is [[5], [6]], then the first one's output is [[5, 6], [5, 6]] but the second one's is [[5, 5], [6, 6]].
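This counterexample is easy to verify with a small Python sketch. Note this is purely illustrative: `broadcast_2d` is a hypothetical stand-in for `vector.broadcast`'s stretch-the-unit-dim semantics on nested lists, not a real MLIR API.

```python
# Minimal sketch of vector.broadcast semantics on 2-D nested lists:
# a dimension of size 1 is stretched to the target size.
def broadcast_2d(v, rows, cols):
    # Stretch the inner (column) dim first, then the outer (row) dim.
    stretched = [row * cols if len(row) == 1 else row for row in v]
    return stretched * rows if len(stretched) == 1 else stretched

x = [[5], [6]]                      # a <2x1> vector

# Path 1: shape_cast <2x1> -> <1x2>, then broadcast to <2x2>.
casted = [[5, 6]]                   # shape_cast only reshapes the data
via_cast = broadcast_2d(casted, 2, 2)

# Path 2: broadcast <2x1> directly to <2x2>.
direct = broadcast_2d(x, 2, 2)

print(via_cast)  # [[5, 6], [5, 6]]
print(direct)    # [[5, 5], [6, 6]]
```

The two paths disagree, so the fold is not valid for this shape combination.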
Yeah, you're right. Let me turn this PR into a draft and think about this.
// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
// with broadcast's result type.
struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
    if (!srcShapeCast)
      return failure();

    VectorType srcType = srcShapeCast.getSourceVectorType();
    VectorType destType = broadcastOp.getResultVectorType();
    if (vector::isBroadcastableTo(srcType, destType) !=
        BroadcastableToResult::Success)
      return failure();

    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
                                             srcShapeCast.getSource());
    return success();
  }
};
This should be a folder, not a rewrite pattern.
I just rewrote it into a folder.

I updated the algorithm to add a condition that the replicating dimensions have to be the same before and after the transformation.
Thanks! I think this is correct now. But added a suggestion which might simplify it.
// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
// with broadcast's result type and the broadcasted dimensions are the same.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
  auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
I think this is the same as saying (where srcShape -> shapeCastShape -> destShape):
- rank(srcShape) <= rank(destShape)
- srcShape and shapeCastShape are the same, except that one has some 1's prepended, i.e. where R = min(rank(srcShape), rank(shapeCastShape)), the last R dimensions of srcShape and shapeCastShape are the same.

If so, that would be more intuitive I think. If not, can you please provide a counterexample?
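As a sanity check, here is a small Python sketch of this proposed condition. The name `fold_is_valid` and the shape-tuple representation are hypothetical, and the sketch only models the extra legality condition described above; it assumes the original shape_cast/broadcast pair already type-checks.

```python
# Hypothetical check for the proposed rule (not an MLIR API):
# the fold broadcast(shape_cast(x)) -> broadcast(x) is allowed when
#  - rank(src) <= rank(dest), and
#  - src and the shape_cast result agree on their last R dims, and
#  - any extra leading dims on the longer of the two are all 1.
def fold_is_valid(src_shape, shapecast_shape, dest_shape):
    if len(src_shape) > len(dest_shape):
        return False
    r = min(len(src_shape), len(shapecast_shape))
    if src_shape[-r:] != shapecast_shape[-r:]:
        return False
    longer = (src_shape if len(src_shape) >= len(shapecast_shape)
              else shapecast_shape)
    return all(d == 1 for d in longer[:len(longer) - r])

# <2> -> <1x2> -> <32x2>: a prepended unit dim, fold is fine.
print(fold_is_valid((2,), (1, 2), (32, 2)))    # True
# <2x1> -> <1x2> -> <2x2>: the counterexample above, fold is invalid.
print(fold_is_valid((2, 1), (1, 2), (2, 2)))   # False
# <64> -> <4x16> -> <2x4x16>: dims are broken up, fold is invalid.
print(fold_is_valid((64,), (4, 16), (2, 4, 16)))  # False
```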
We can roughly break this down into five cases by how we shape_cast:
(1) srcShape is "broken up" into multiple non-one dimensions, e.g. <4x1> -> <2x2>
(2) srcShape is prepended by one or more ones
(3) srcShape is appended by one or more ones
(4) One or more leading dimensions in srcShape were removed
(5) One or more trailing dimensions in srcShape were removed
Note that multiple cases can apply at the same time. For instance, <2x1> -> <1x2> removes the trailing dimension before appending a new one.
Case (1) is easy: srcShape will never be broadcastable w.r.t. destShape, because the broadcast rule effectively mandates the source dimensions to be a "subset" of the destination dimensions, modulo dimensions that are one, and changing the dimension values violates that.
I think cases (2) and (4) are conjugate, because broadcasting at those prepended dimensions that are one is the same as broadcasting toward missing (leading) dimensions; similarly, broadcasting at missing leading dimensions is the same as broadcasting at ones that were once there. Therefore, they are allowed.
Cases (3) and (5) are similar: both of them change the "neighboring" elements in the highest dimension -- an element either becomes or does not become a 'singleton'. For instance, [A, B] turns into [[A], [B]] when we cast from <2> to <2x1>, in which case element A turns from having a neighbor B into being a singleton. Whether it's a singleton or not is important, because an element that is not a singleton will always be broadcast together with its neighbor, while a singleton can be replicated on its own. Since this alters the broadcasting behavior, once this appears -- even combined with other cases, like the <1x2> -> <2x1> mentioned earlier -- we cannot do the folding. Note that this also coincides with my current rule -- the original replicated dimensions have to match the new replicated dimensions.
The bottom line is: I think your new rule is correct; I'm going to update to it.
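Case (3) can be made concrete with a tiny Python sketch (plain nested lists standing in for vectors, not MLIR code): appending a unit dim via shape_cast makes each element a singleton that replicates on its own, which differs from broadcasting the source directly.

```python
# Case (3): shape_cast <2> -> <2x1>, then broadcast to <2x2>.
x = [5, 6]
casted = [[v] for v in x]               # <2x1>: each element is a singleton
via_cast = [row * 2 for row in casted]  # stretch the unit dim to size 2

# Direct broadcast <2> -> <2x2> replicates the whole vector instead:
direct = [x[:] for _ in range(2)]

print(via_cast)  # [[5, 5], [6, 6]]
print(direct)    # [[5, 6], [5, 6]]
```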
The algorithm is now updated.
Thanks for the analysis, looks good to me as does the new impl. I think `rank(srcShape) <= rank(destShape)` is sufficient, but actually the way you check with `isBroadcastableTo` will probably be more intuitive to future readers.
func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
  %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
  %1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
  return %1 : vector<32x2xf32>
}
What about dim-1 broadcasting?
I added more test cases for that.
LGTM. Thanks for the comprehensive testing.
    assert(all_of(srcShape.drop_back(numTrailingDims),
                  [](int64_t E) { return E == 1; }) &&
           all_of(shapecastShape.drop_back(numTrailingDims),
                  [](int64_t E) { return E == 1; }) &&
           "ill-formed shape_cast");
[nit] Unlike LLVM, we use camelCase in MLIR for variable names. So, `E` -> `e` (rather confusing, I know). If you want to avoid `e` (less readable than `E` IMHO), you could try `E` -> `dim` 🤷🏻