[mlir] Allow folding dynamic full size subviews #140619
Signed-off-by: Max Dawkins <[email protected]>
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref
Author: None (Max191)
Changes
Supports folding subviews with dynamic sizes in TrivialSubViewOpFolder using the ReifyRankedShapedTypeOpInterface.
Full diff: https://github.com/llvm/llvm-project/pull/140619.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 82702789c2913..3ccdcbf8c3be5 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3105,7 +3105,7 @@ FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
/// is the case if the all offsets are zero, all strides are 1, and the source
/// shape is same as the size of the subview. In such cases, the subview can
/// be folded into its source.
-static bool isTrivialSubViewOp(SubViewOp subViewOp) {
+static bool isTrivialSubViewOp(OpBuilder &b, SubViewOp subViewOp) {
if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
return false;
@@ -3127,15 +3127,24 @@ static bool isTrivialSubViewOp(SubViewOp subViewOp) {
}))
return false;
- // Check all size values are static and matches the (static) source shape.
+ // Check all size values match the source shape.
ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
- for (const auto &size : llvm::enumerate(mixedSizes)) {
- std::optional<int64_t> intValue = getConstantIntValue(size.value());
- if (!intValue || *intValue != sourceShape[size.index()])
- return false;
+ if (llvm::all_of_zip(mixedSizes, sourceShape,
+ [](OpFoldResult mixedSize, int64_t staticSize) {
+ std::optional<int64_t> constSize =
+ getConstantIntValue(mixedSize);
+ return constSize.has_value() &&
+ *constSize == staticSize;
+ })) {
+ return true;
}
- // All conditions met. The `SubViewOp` is foldable as a no-op.
- return true;
+ auto sourceOpResult = dyn_cast<OpResult>(subViewOp.getSource());
+ if (!sourceOpResult)
+ return false;
+ ReifiedRankedShapedTypeDims resultDims;
+ if (failed(reifyResultShapes(b, sourceOpResult.getOwner(), resultDims)))
+ return false;
+ return llvm::equal(mixedSizes, resultDims[sourceOpResult.getResultNumber()]);
}
namespace {
@@ -3206,7 +3215,7 @@ class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
LogicalResult matchAndRewrite(SubViewOp subViewOp,
PatternRewriter &rewriter) const override {
- if (!isTrivialSubViewOp(subViewOp))
+ if (!isTrivialSubViewOp(rewriter, subViewOp))
return failure();
if (subViewOp.getSourceType() == subViewOp.getType()) {
rewriter.replaceOp(subViewOp, subViewOp.getSource());
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e7cee7cd85426..ebad9e3eab345 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -70,6 +70,20 @@ func.func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4
// -----
+// CHECK-LABEL: func @subview_of_dynamic_full_size
+// CHECK-SAME: %[[ARG0:.+]]: memref<?xi8>
+// CHECK-SAME: %[[SIZE:.+]]: index
+// CHECK: %[[EXPAND_SHAPE:.+]] = memref.expand_shape
+// CHECK-NOT: memref.subview
+// CHECK: return %[[EXPAND_SHAPE]] : memref<?x?xi8>
+func.func @subview_of_dynamic_full_size(%arg0 : memref<?xi8>, %size : index) -> memref<?x?xi8> {
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [%size, %size] : memref<?xi8> into memref<?x?xi8>
+ %1 = memref.subview %0[0, 0] [%size, %size] [1, 1] : memref<?x?xi8> to memref<?x?xi8>
+ return %1 : memref<?x?xi8>
+}
+
+// -----
+
// CHECK-LABEL: func @negative_subview_of_static_full_size
// CHECK-SAME: %[[ARG0:.+]]: memref<16x4xf32, strided<[4, 1], offset: ?>>
// CHECK-SAME: %[[IDX:.+]]: index
hanhanW
left a comment
Looks okay to me, but please wait a day in case other reviewers want to take a look.
I'm not sure whether this should be a canonicalization pattern, because reifyResultShapes could create operations. Whether it does depends on the implementation details of the source op.
    std::optional<int64_t> constSize =
        getConstantIntValue(mixedSize);
    return constSize.has_value() &&
           *constSize == staticSize;
I think you can use isConstantIntValue.
llvm-project/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
Lines 112 to 113 in a0c515a
    /// Return true if `ofr` is constant integer equal to `value`.
    bool isConstantIntValue(OpFoldResult ofr, int64_t value);
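To make the suggestion above concrete, here is a minimal, self-contained C++ sketch. It does not use MLIR: ToyOpFoldResult, toyGetConstantIntValue, toyIsConstantIntValue, and allSizesMatchStaticShape are hypothetical stand-ins, with a std::variant modeling "constant integer or SSA value". Only the name and documented meaning of isConstantIntValue are taken from the header quoted above; the rest is an assumption made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Toy stand-in for mlir::OpFoldResult (assumption): either a constant integer
// or the name of a dynamic SSA value.
using ToyOpFoldResult = std::variant<int64_t, std::string>;

// Toy analogue of getConstantIntValue: the constant if there is one.
std::optional<int64_t> toyGetConstantIntValue(const ToyOpFoldResult &ofr) {
  if (const int64_t *c = std::get_if<int64_t>(&ofr))
    return *c;
  return std::nullopt;
}

// Toy analogue of isConstantIntValue: true only for a constant equal to value.
bool toyIsConstantIntValue(const ToyOpFoldResult &ofr, int64_t value) {
  std::optional<int64_t> c = toyGetConstantIntValue(ofr);
  return c.has_value() && *c == value;
}

// The size check from the diff, written with the single-call helper instead
// of the two-step "get optional, then compare" lambda.
bool allSizesMatchStaticShape(const std::vector<ToyOpFoldResult> &mixedSizes,
                              const std::vector<int64_t> &sourceShape) {
  if (mixedSizes.size() != sourceShape.size())
    return false;
  for (std::size_t i = 0; i < mixedSizes.size(); ++i)
    if (!toyIsConstantIntValue(mixedSizes[i], sourceShape[i]))
      return false;
  return true;
}
```

In this toy model a dynamic size never compares equal to a static extent, which is why the static path alone cannot fold the dynamic test case in the diff.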
    /// Helper method to check if a `subview` operation is trivially a no-op. This
    /// is the case if the all offsets are zero, all strides are 1, and the source
    /// shape is same as the size of the subview. In such cases, the subview can
    /// be folded into its source.
Can you also update the comment?
PatternRewriter &rewriter) const override {
- if (!isTrivialSubViewOp(subViewOp))
+ if (!isTrivialSubViewOp(rewriter, subViewOp))
return failure();
This is problematic: isTrivialSubViewOp creates new IR, but then you return "failure". This is not allowed in a rewrite pattern. You have to make sure that there is no change to the input IR when you return "failure".
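The contract described in this comment can be sketched without MLIR. The toy model below is hypothetical throughout (ToyBlock, mutatingCheck, readOnlyCheck, and violatesFailureContract are invented names, not MLIR APIs): a "block" is just a list of op names, and a check either appends an op while deciding or leaves the block untouched.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy IR: a block is just an ordered list of op names.
struct ToyBlock {
  std::vector<std::string> ops;
};

// Mimics the shape of the check in this PR: it materializes a helper op while
// deciding, so the block changes even when the answer is "no match".
bool mutatingCheck(ToyBlock &block, bool sizesMatch) {
  block.ops.push_back("toy.dim"); // IR created as a side effect of the query
  return sizesMatch;
}

// Read-only variant: the const reference means it cannot add ops.
bool readOnlyCheck(const ToyBlock &block, bool sizesMatch) {
  (void)block; // a real check would only inspect the block
  return sizesMatch;
}

// Runs a check on a private copy of the block and reports whether it broke
// the rule "no IR change when the pattern fails to match".
template <typename Check>
bool violatesFailureContract(Check check, ToyBlock block, bool sizesMatch) {
  std::size_t before = block.ops.size();
  bool matched = check(block, sizesMatch);
  return !matched && block.ops.size() != before;
}
```

Under these assumptions, mutatingCheck violates the rule only on the failure path; changing the block is acceptable when the pattern goes on to succeed.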
Yeah, it is difficult to compare the dynamic sizes without doing this. The other option I considered was using the ValueBoundsOpInterface, but that seems like an expensive check to run in a canonicalization. It would be nice if there were a version of reifyResultShapes that did not create IR, and returned failure if creating new IR was necessary, but that seems like a lot of work to plumb through.
These are the ways I know of to compare the result sizes of a tensor, but if you have any better suggestions, then that would be very helpful!
After looking into it further, I do see some instances of ValueBoundsOpInterface being used in canonicalization patterns, so I suppose it is the lesser of the two evils. I will update the PR implementation to use it instead.
Actually, this won't work because I don't think it can capture equality of fully dynamic sizes that come from the same SSA value. It is more about computing bounds. So I am back to looking for something like reifyResultShapes, but without creating extra IR.
I wouldn't use ValueBoundsOpInterface during canonicalization. It is quite expensive and should probably be removed from the canonicalization patterns that use it today.
I don't have a good solution. Maybe we can build a new pass that folds various index computations and view-like ops based on a single ValueBoundsOpInterface analysis. Maybe that pass could then also replace the "reify ranked shaped ..." pass. Not sure if it's a good idea, just trying to think of a solution...
We have an interface that does what I need in downstream IREE, actually. It is like reifyResultShapes, but "read only", as in, it doesn't create new IR. I am going to close this PR and handle this downstream for the time being, since I don't see a great way to do this here in MLIR.
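The IREE interface mentioned here is not shown in this thread, so the following is only a guess at the general idea of a "read only" shape query, not its actual API. In this self-contained C++ sketch, ToySize, ToyProducer, readOnlyResultSizes, and isFullSizeSubview are all hypothetical names. The producer reports sizes it already carries (as with the output_shape operands of memref.expand_shape in the test case) or reports nothing, and never builds anything new.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Toy size (assumption): a constant extent or the name of an SSA value.
using ToySize = std::variant<int64_t, std::string>;

// Hypothetical producer op. The optional is empty when the result sizes are
// not already available without creating new IR.
struct ToyProducer {
  std::optional<std::vector<ToySize>> knownResultSizes;
};

// Read-only query: takes a const reference and only returns what is already
// known, failing (empty optional) instead of materializing new values.
std::optional<std::vector<ToySize>>
readOnlyResultSizes(const ToyProducer &producer) {
  return producer.knownResultSizes;
}

// A subview covers its whole source when its sizes equal the source sizes,
// element by element. Dynamic sizes compare equal only when they name the
// same SSA value.
bool isFullSizeSubview(const ToyProducer &source,
                       const std::vector<ToySize> &subviewSizes) {
  std::optional<std::vector<ToySize>> sizes = readOnlyResultSizes(source);
  return sizes.has_value() && *sizes == subviewSizes;
}
```

Because the query can fail instead of creating ops, a pattern built on it could return failure without having touched the input IR, which is the property requested earlier in the review.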
Supports folding subviews with dynamic sizes in TrivialSubViewOpFolder using the ReifyRankedShapedTypeOpInterface.