Skip to content

Commit 33d3ba6

Browse files
committed
Revert "[mlir][memref]: Collapse strided unit dim even if strides are dynamic (llvm#157330)"
This reverts commit f74e909.
1 parent 19f43a5 commit 33d3ba6

File tree

2 files changed

+6
-11
lines changed

2 files changed

+6
-11
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2568,11 +2568,6 @@ computeCollapsedLayoutMap(MemRefType srcType,
25682568
auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
25692569
auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
25702570
for (int64_t idx : llvm::reverse(trailingReassocs)) {
2571-
// Dimensions of size 1 should be skipped, because their strides are
2572-
// meaningless and could have any arbitrary value.
2573-
if (srcShape[idx - 1] == 1)
2574-
continue;
2575-
25762571
stride = stride * SaturatedInteger::wrap(srcShape[idx]);
25772572

25782573
// Both source and result stride must have the same static value. In that
@@ -2587,6 +2582,11 @@ computeCollapsedLayoutMap(MemRefType srcType,
25872582
if (strict && (stride.saturated || srcStride.saturated))
25882583
return failure();
25892584

2585+
// Dimensions of size 1 should be skipped, because their strides are
2586+
// meaningless and could have any arbitrary value.
2587+
if (srcShape[idx - 1] == 1)
2588+
continue;
2589+
25902590
if (!stride.saturated && !srcStride.saturated && stride != srcStride)
25912591
return failure();
25922592
}

mlir/test/Dialect/MemRef/ops.mlir

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -440,8 +440,7 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
440440
%arg4: index,
441441
%arg5: index,
442442
%arg6: index,
443-
%arg7: memref<4x?x4xf32>,
444-
%arg8: memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>>) {
443+
%arg7: memref<4x?x4xf32>) {
445444
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
446445
// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
447446
%0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
@@ -490,10 +489,6 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
490489
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
491490
%4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2]
492491
: memref<4x?x4xf32> into memref<2x2x?x2x2xf32>
493-
494-
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3]]
495-
// CHECK-SAME: memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xsi8, strided<[?, ?, 1], offset: ?>>
496-
%5 = memref.collapse_shape %arg8 [[0, 1], [2], [3]] : memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xsi8, strided<[?, ?, 1], offset: ?>>
497492
return
498493
}
499494

0 commit comments

Comments
 (0)