@Max191 Max191 commented May 19, 2025

Supports folding subviews with dynamic sizes in TrivialSubViewOpFolder using the ReifyRankedShapedTypeOpInterface.

llvmbot commented May 19, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: None (Max191)


Full diff: https://github.com/llvm/llvm-project/pull/140619.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+18-9)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+14)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 82702789c2913..3ccdcbf8c3be5 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3105,7 +3105,7 @@ FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
 /// is the case if the all offsets are zero, all strides are 1, and the source
 /// shape is same as the size of the subview. In such cases, the subview can
 /// be folded into its source.
-static bool isTrivialSubViewOp(SubViewOp subViewOp) {
+static bool isTrivialSubViewOp(OpBuilder &b, SubViewOp subViewOp) {
   if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
     return false;
 
@@ -3127,15 +3127,24 @@ static bool isTrivialSubViewOp(SubViewOp subViewOp) {
       }))
     return false;
 
-  // Check all size values are static and matches the (static) source shape.
+  // Check all size values match the source shape.
   ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
-  for (const auto &size : llvm::enumerate(mixedSizes)) {
-    std::optional<int64_t> intValue = getConstantIntValue(size.value());
-    if (!intValue || *intValue != sourceShape[size.index()])
-      return false;
+  if (llvm::all_of_zip(mixedSizes, sourceShape,
+                       [](OpFoldResult mixedSize, int64_t staticSize) {
+                         std::optional<int64_t> constSize =
+                             getConstantIntValue(mixedSize);
+                         return constSize.has_value() &&
+                                *constSize == staticSize;
+                       })) {
+    return true;
   }
-  // All conditions met. The `SubViewOp` is foldable as a no-op.
-  return true;
+  auto sourceOpResult = dyn_cast<OpResult>(subViewOp.getSource());
+  if (!sourceOpResult)
+    return false;
+  ReifiedRankedShapedTypeDims resultDims;
+  if (failed(reifyResultShapes(b, sourceOpResult.getOwner(), resultDims)))
+    return false;
+  return llvm::equal(mixedSizes, resultDims[sourceOpResult.getResultNumber()]);
 }
 
 namespace {
@@ -3206,7 +3215,7 @@ class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
 
   LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                 PatternRewriter &rewriter) const override {
-    if (!isTrivialSubViewOp(subViewOp))
+    if (!isTrivialSubViewOp(rewriter, subViewOp))
       return failure();
     if (subViewOp.getSourceType() == subViewOp.getType()) {
       rewriter.replaceOp(subViewOp, subViewOp.getSource());
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e7cee7cd85426..ebad9e3eab345 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -70,6 +70,20 @@ func.func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4
 
 // -----
 
+// CHECK-LABEL: func @subview_of_dynamic_full_size
+//  CHECK-SAME:   %[[ARG0:.+]]: memref<?xi8>
+//  CHECK-SAME:   %[[SIZE:.+]]: index
+//       CHECK:   %[[EXPAND_SHAPE:.+]] = memref.expand_shape
+//   CHECK-NOT:   memref.subview
+//       CHECK:   return %[[EXPAND_SHAPE]] : memref<?x?xi8>
+func.func @subview_of_dynamic_full_size(%arg0 : memref<?xi8>, %size : index) -> memref<?x?xi8> {
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [%size, %size] : memref<?xi8> into memref<?x?xi8>
+  %1 = memref.subview %0[0, 0] [%size, %size] [1, 1] : memref<?x?xi8> to memref<?x?xi8>
+  return %1 : memref<?x?xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func @negative_subview_of_static_full_size
 //  CHECK-SAME:   %[[ARG0:.+]]: memref<16x4xf32,  strided<[4, 1], offset: ?>>
 //  CHECK-SAME:   %[[IDX:.+]]: index

@hanhanW hanhanW left a comment

Looks okay to me, but please wait a day in case other reviewers want to take a look.

I'm not sure whether this should be a canonicalization pattern, because reifyResultShapes could create operations. Whether it does depends on the implementation details of the source op.

Comment on lines +3134 to +3137
std::optional<int64_t> constSize =
getConstantIntValue(mixedSize);
return constSize.has_value() &&
*constSize == staticSize;
Contributor

I think you can use isConstantIntValue.

/// Return true if `ofr` is constant integer equal to `value`.
bool isConstantIntValue(OpFoldResult ofr, int64_t value);
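
To illustrate the suggestion, here is a minimal standalone sketch of how such a helper shortens the per-dimension comparison. It is a toy model and not MLIR code: `ToyFoldResult` is a hypothetical stand-in for `OpFoldResult`, the two helpers are simplified re-implementations that only borrow the names quoted above, and `sizesMatchStaticShape` is an invented name for the loop in the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Toy stand-in for mlir::OpFoldResult (hypothetical): either a constant
// integer or the name of a dynamic SSA value.
using ToyFoldResult = std::variant<int64_t, std::string>;

// Simplified analogue of getConstantIntValue: the constant, if there is one.
std::optional<int64_t> getConstantIntValue(const ToyFoldResult &ofr) {
  if (const auto *constant = std::get_if<int64_t>(&ofr))
    return *constant;
  return std::nullopt;
}

// Simplified analogue of isConstantIntValue: true if `ofr` is a constant
// equal to `value`.
bool isConstantIntValue(const ToyFoldResult &ofr, int64_t value) {
  std::optional<int64_t> constant = getConstantIntValue(ofr);
  return constant.has_value() && *constant == value;
}

// True if every mixed size is a constant equal to the matching static extent.
bool sizesMatchStaticShape(const std::vector<ToyFoldResult> &mixedSizes,
                           const std::vector<int64_t> &sourceShape) {
  // A subview carries one size per source dimension.
  assert(mixedSizes.size() == sourceShape.size() && "rank mismatch");
  for (std::size_t i = 0; i < mixedSizes.size(); ++i)
    if (!isConstantIntValue(mixedSizes[i], sourceShape[i]))
      return false;
  return true;
}
```

With the helper, the lambda body in the diff reduces to a single call per dimension; a dynamic size such as `%size` never compares equal to a static extent, which is why the PR needs a second check for the dynamic case.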

/// Helper method to check if a `subview` operation is trivially a no-op. This
/// is the case if the all offsets are zero, all strides are 1, and the source
/// shape is same as the size of the subview. In such cases, the subview can
/// be folded into its source.
Member

Can you also update the comment?

                                 PatternRewriter &rewriter) const override {
-    if (!isTrivialSubViewOp(subViewOp))
+    if (!isTrivialSubViewOp(rewriter, subViewOp))
       return failure();
Member

This is problematic: isTrivialSubViewOp creates new IR, but then you return "failure". This is not allowed in a rewrite pattern. You have to make sure that there is no change to the input IR when you return "failure".
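
The rule described in this comment can be sketched with a small standalone model. Nothing here is MLIR API: `ToyIR`, `ToyPattern`, and all function names are hypothetical, and the IR is reduced to a list of op names. The sketch only shows the shape of the contract, namely that a pattern returning failure must leave the IR exactly as it found it.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Toy IR (hypothetical): just a list of operation names.
struct ToyIR {
  std::vector<std::string> ops;
};

// A toy pattern returns true on success and false on failure.
using ToyPattern = std::function<bool(ToyIR &)>;

// Runs `pattern` on a copy of the IR and reports whether it honored the
// contract: on failure, the IR must be unchanged.
bool honorsFailureContract(const ToyPattern &pattern, ToyIR ir) {
  const ToyIR before = ir;
  bool succeeded = pattern(ir);
  return succeeded || ir.ops == before.ops;
}

// Mirrors the problem raised above: the check materializes a helper op (as a
// shape reification may) and the pattern then bails out.
bool createThenFail(ToyIR &ir) {
  ir.ops.push_back("memref.dim");
  return false;
}

// Inspects the IR first and mutates it only once success is certain.
// "identity_subview" is an invented marker for a foldable subview.
bool inspectThenRewrite(ToyIR &ir) {
  bool foldable = !ir.ops.empty() && ir.ops.back() == "identity_subview";
  if (!foldable)
    return false;
  assert(!ir.ops.empty() && "guaranteed by the foldable check");
  ir.ops.pop_back();
  return true;
}
```

In this model `createThenFail` violates the contract on every input, while `inspectThenRewrite` satisfies it whether or not the fold applies, because its failure path is read-only.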

Contributor Author

Yeah, it is difficult to compare the dynamic sizes without doing this. The other option I considered was using the ValueBoundsOpInterface, but that seems like an expensive check to run in a canonicalization. It would be nice if there were a version of reifyResultShapes that did not create IR, and returned failure if creating new IR was necessary, but that seems like a lot of work to plumb through.

These are the ways I know of to compare the result sizes of a tensor, but if you have any better suggestions, then that would be very helpful!

Contributor Author

After looking into it further, I do see some instances of ValueBoundsOpInterface being used in canonicalization patterns, so I suppose it is the lesser of the two evils. I will update the PR implementation to use it instead.

Contributor Author

Actually, this won't work because I don't think it can capture equality of fully dynamic sizes that come from the same SSA value. It is more about computing bounds. So I am back to looking for something like reifyResultShapes, but without creating extra IR.

Member

I wouldn't use ValueBoundsOpInterface during canonicalization. It is quite expensive and should probably be removed from the canonicalization patterns that use it today.

I don't have a good solution. Maybe we can build a new pass that folds various index computations and view-like ops based on a single ValueBoundsOpInterface analysis. Maybe that pass could then also replace the "reify ranked shaped ..." pass. Not sure if it's a good idea, just trying to think of a solution...

Contributor Author

We have an interface that does what I need in downstream IREE, actually. It is like reifyResultShapes, but "read only", as in, it doesn't create new IR. I am going to close this PR and handle this downstream for the time being, since I don't see a great way to do this here in MLIR.
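
The thread does not show the IREE interface, so the following is only a rough standalone sketch of what a "read only" shape query could look like, under the assumption that each result dimension is either already available (a constant or an existing SSA value) or would need new ops to compute. All names (`ToyDim`, `NeedsNewIR`, `reifyWithoutCreatingIR`, `isTrivialSubView`) are hypothetical and unrelated to the real IREE or MLIR APIs.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Toy stand-in for OpFoldResult (hypothetical): a constant or the name of an
// SSA value that already exists in the IR.
using ToyFoldResult = std::variant<int64_t, std::string>;

// Marker for a dimension whose size could only be expressed by creating ops.
struct NeedsNewIR {};

// What the producer op knows about one result dimension.
using ToyDim = std::variant<int64_t, std::string, NeedsNewIR>;

// Read-only shape query: succeeds only if every dimension is already
// available, and never creates anything.
std::optional<std::vector<ToyFoldResult>>
reifyWithoutCreatingIR(const std::vector<ToyDim> &producerDims) {
  std::vector<ToyFoldResult> dims;
  for (const ToyDim &dim : producerDims) {
    if (std::holds_alternative<NeedsNewIR>(dim))
      return std::nullopt; // Would have to build IR: report failure instead.
    if (const auto *constant = std::get_if<int64_t>(&dim))
      dims.emplace_back(*constant);
    else
      dims.emplace_back(std::get<std::string>(dim));
  }
  return dims;
}

// A zero-offset, unit-stride subview is a no-op if its sizes are exactly the
// producer's result sizes (same constants or same SSA values).
bool isTrivialSubView(const std::vector<ToyFoldResult> &subviewSizes,
                      const std::vector<ToyDim> &producerDims) {
  assert(subviewSizes.size() == producerDims.size() && "rank mismatch");
  auto dims = reifyWithoutCreatingIR(producerDims);
  return dims.has_value() && *dims == subviewSizes;
}
```

For the test case in this PR, the producer is `memref.expand_shape` with `output_shape [%size, %size]`, so a subview with sizes `[%size, %size]` matches by SSA-value identity and folds, while a producer with any `NeedsNewIR` dimension makes the check fail without touching the IR.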

Max191 commented May 21, 2025

#140619 (comment)

@Max191 Max191 closed this May 21, 2025