Skip to content

Commit b7e34f8

Browse files
committed
[mlir][memref] Remove invalid extract_aligned_pointer_as_index folding in ExpandStridedMetadata
ViewLikeOpInterface by itself doesn't guarantee to preserve the base pointer in general and `memref.view` is one such example, so limit folder to a few specific ops.
1 parent ea10026 commit b7e34f8

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -959,7 +959,11 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
959959
PatternRewriter &rewriter) const override {
960960
auto viewLikeOp =
961961
extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
962-
if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest())
962+
// ViewLikeOpInterface by itself doesn't guarantee to preserve the base
963+
// pointer in general and `memref.view` is one such example, so just check
964+
// for a few specific cases.
965+
if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest() ||
966+
!isa<memref::SubViewOp, memref::ReinterpretCastOp>(viewLikeOp))
963967
return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
964968
rewriter.modifyOpInPlace(extractOp, [&]() {
965969
extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());

mlir/test/Dialect/MemRef/expand-strided-metadata.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1455,3 +1455,20 @@ func.func @extract_strided_metadata_of_memory_space_cast_no_base(%base: memref<2
14551455

14561456
// CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast_no_base
14571457
// CHECK-NOT: memref.memory_space_cast
1458+
1459+
// -----
1460+
1461+
func.func @negative_memref_view_extract_aligned_pointer(%arg0: memref<?xi8>) -> index {
1462+
// `extract_aligned_pointer_as_index` must not be folded as `memref.view` can change the base pointer
1463+
// CHECK-LABEL: func @negative_memref_view_extract_aligned_pointer
1464+
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xi8>)
1465+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
1466+
// CHECK: %[[VIEW:.*]] = memref.view %[[ARG0]][%[[C0]]][] : memref<?xi8> to memref<f32>
1467+
// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[VIEW]] : memref<f32> -> index
1468+
// CHECK: return %[[PTR]] : index
1469+
1470+
%c0 = arith.constant 0 : index
1471+
%0 = memref.view %arg0[%c0][] : memref<?xi8> to memref<f32>
1472+
%1 = memref.extract_aligned_pointer_as_index %0: memref<f32> -> index
1473+
return %1 : index
1474+
}

0 commit comments

Comments
 (0)