[mlir][memref]: Collapse strided unit dim even if strides are dynamic #157330
base: main
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir
Author: Maya Amrami (amrami)
Changes
…e dynamic
Full diff: https://github.com/llvm/llvm-project/pull/157330.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b59d73d1291c8..3bdeaea300659 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2401,6 +2401,11 @@ computeCollapsedLayoutMap(MemRefType srcType,
     auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
     auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
     for (int64_t idx : llvm::reverse(trailingReassocs)) {
+      // Dimensions of size 1 should be skipped, because their strides are
+      // meaningless and could have any arbitrary value.
+      if (srcShape[idx - 1] == 1)
+        continue;
+
       stride = stride * SaturatedInteger::wrap(srcShape[idx]);
 
       // Both source and result stride must have the same static value. In that
@@ -2415,11 +2420,6 @@ computeCollapsedLayoutMap(MemRefType srcType,
       if (strict && (stride.saturated || srcStride.saturated))
         return failure();
 
-      // Dimensions of size 1 should be skipped, because their strides are
-      // meaningless and could have any arbitrary value.
-      if (srcShape[idx - 1] == 1)
-        continue;
-
       if (!stride.saturated && !srcStride.saturated && stride != srcStride)
         return failure();
     }
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 6c2298a3f8acb..50683761db5bf 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -431,7 +431,8 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
%arg4: index,
%arg5: index,
%arg6: index,
- %arg7: memref<4x?x4xf32>) {
+ %arg7: memref<4x?x4xf32>,
+ %arg8: memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>>) {
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
%0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
@@ -480,6 +481,10 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
%4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2]
: memref<4x?x4xf32> into memref<2x2x?x2x2xf32>
+
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3]]
+// CHECK-SAME: memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xsi8, strided<[?, ?, 1], offset: ?>>
+ %5 = memref.collapse_shape %arg8 [[0, 1], [2], [3]] : memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xsi8, strided<[?, ?, 1], offset: ?>>
return
}
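The reordering in the first hunk can be modeled outside MLIR. The sketch below is a simplified, standalone approximation of the per-group check in `computeCollapsedLayoutMap`, not the upstream code: `kDynamic`, `SatInt`, and `groupIsContiguous` are hypothetical stand-ins for `ShapedType::kDynamic`, `SaturatedInteger`, and the inner loop. The collapsed dim's stride is passed in rather than computed, and the hypothetical `skipUnitDimsFirst` flag chooses whether the unit-dim skip runs before or after the `strict` dynamic-stride check.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical stand-in for mlir::ShapedType::kDynamic (a sentinel for "?").
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical, minimal model of SaturatedInteger: once a dynamic value
// enters a product, the result stays dynamic ("saturated").
struct SatInt {
  bool saturated = false;
  int64_t value = 0;

  static SatInt wrap(int64_t v) {
    return v == kDynamic ? SatInt{true, 0} : SatInt{false, v};
  }
  SatInt operator*(SatInt other) const {
    if (saturated || other.saturated)
      return SatInt{true, 0};
    return SatInt{false, value * other.value};
  }
};

// Checks whether one reassociation group (consecutive source dims, outermost
// first) may be collapsed into a single dim whose stride is `resultStride`.
// `skipUnitDimsFirst` selects where the unit-dim skip happens relative to the
// strict dynamic-stride check: false = before this PR, true = after it.
bool groupIsContiguous(const std::vector<int64_t> &shape,
                       const std::vector<int64_t> &strides,
                       const std::vector<int64_t> &group, int64_t resultStride,
                       bool strict, bool skipUnitDimsFirst) {
  assert(!group.empty() && "reassociation groups are never empty");
  SatInt stride = SatInt::wrap(resultStride);
  // Visit the trailing dims of the group from innermost to outermost.
  for (int64_t k = static_cast<int64_t>(group.size()) - 1; k >= 1; --k) {
    int64_t idx = group[k];
    // Stride that dim idx - 1 must have if the group is contiguous.
    stride = stride * SatInt::wrap(shape[idx]);
    bool unitDim = shape[idx - 1] == 1;
    // New ordering: a unit dim's stride is meaningless, so skip it before
    // the strict check can reject a dynamic stride.
    if (skipUnitDimsFirst && unitDim)
      continue;
    SatInt srcStride = SatInt::wrap(strides[idx - 1]);
    if (strict && (stride.saturated || srcStride.saturated))
      return false;
    // Old ordering: the unit dim is only skipped after the strict check.
    if (unitDim)
      continue;
    if (!stride.saturated && !srcStride.saturated &&
        stride.value != srcStride.value)
      return false;
  }
  return true;
}
```

With the dynamic strides from the new `%arg8` test case (group `[0, 1]` of `memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>>`, result stride `kDynamic`), calling the sketch with `strict = true` rejects the group under the old ordering, because the saturated stride is checked before the unit dim is skipped, and accepts it under the new ordering. Unlike the diff, the sketch keeps the multiplication ahead of the skip, so the running stride still includes the size of every inner dim. For example, with shape `[4, 1, 8]` and strides `[8, 8, 1]` it still expects stride 8 for dim 0.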
@matthias-springer Can you take a look? 😊
I don't see a problem with this PR, but I'm not working with collapse_shape/expand_shape anymore. @MaheshRavishankar is this safe to merge?
@MaheshRavishankar @Groverkss Can you take a look? :)
LGTM here