Skip to content

Commit 285fc64

Browse files
committed
[mlir][memref]: Allow collapse of strided unit dim even if strides are dynamic
1 parent c75c136 commit 285fc64

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2401,6 +2401,11 @@ computeCollapsedLayoutMap(MemRefType srcType,
24012401
auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
24022402
auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
24032403
for (int64_t idx : llvm::reverse(trailingReassocs)) {
2404+
// Dimensions of size 1 should be skipped, because their strides are
2405+
// meaningless and could have any arbitrary value.
2406+
if (srcShape[idx - 1] == 1)
2407+
continue;
2408+
24042409
stride = stride * SaturatedInteger::wrap(srcShape[idx]);
24052410

24062411
// Both source and result stride must have the same static value. In that
@@ -2415,11 +2420,6 @@ computeCollapsedLayoutMap(MemRefType srcType,
24152420
if (strict && (stride.saturated || srcStride.saturated))
24162421
return failure();
24172422

2418-
// Dimensions of size 1 should be skipped, because their strides are
2419-
// meaningless and could have any arbitrary value.
2420-
if (srcShape[idx - 1] == 1)
2421-
continue;
2422-
24232423
if (!stride.saturated && !srcStride.saturated && stride != srcStride)
24242424
return failure();
24252425
}

mlir/test/Dialect/MemRef/ops.mlir

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,8 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
431431
%arg4: index,
432432
%arg5: index,
433433
%arg6: index,
434-
%arg7: memref<4x?x4xf32>) {
434+
%arg7: memref<4x?x4xf32>,
435+
%arg8: memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>>) {
435436
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
436437
// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
437438
%0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
@@ -480,6 +481,10 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
480481
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
481482
%4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2]
482483
: memref<4x?x4xf32> into memref<2x2x?x2x2xf32>
484+
485+
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3]]
486+
// CHECK-SAME: memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xsi8, strided<[?, ?, 1], offset: ?>>
487+
%5 = memref.collapse_shape %arg8 [[0, 1], [2], [3]] : memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xsi8, strided<[?, ?, 1], offset: ?>>
483488
return
484489
}
485490

0 commit comments

Comments
 (0)