Skip to content

Commit 0a34d37

Browse files
authored
[mlir][memref] Remove invalid extract_aligned_pointer_as_index folding in ExpandStridedMetadata (llvm#167615)
`RewriteExtractAlignedPointerAsIndexOfViewLikeOp` tries to propagate `extract_aligned_pointer_as_index` through the view ops. `ViewLikeOpInterface` by itself doesn't guarantee to preserve the base pointer and `memref.view` is one such example, so limit pattern to a few specific ops.
1 parent 87da620 commit 0a34d37

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -959,7 +959,11 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
959959
PatternRewriter &rewriter) const override {
960960
auto viewLikeOp =
961961
extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
962-
if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest())
962+
// ViewLikeOpInterface by itself doesn't guarantee to preserve the base
963+
// pointer in general and `memref.view` is one such example, so just check
964+
// for a few specific cases.
965+
if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest() ||
966+
!isa<memref::SubViewOp, memref::ReinterpretCastOp>(viewLikeOp))
963967
return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
964968
rewriter.modifyOpInPlace(extractOp, [&]() {
965969
extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());

mlir/test/Dialect/MemRef/expand-strided-metadata.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1455,3 +1455,20 @@ func.func @extract_strided_metadata_of_memory_space_cast_no_base(%base: memref<2
14551455

14561456
// CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast_no_base
14571457
// CHECK-NOT: memref.memory_space_cast
1458+
1459+
// -----
1460+
1461+
func.func @negative_memref_view_extract_aligned_pointer(%arg0: memref<?xi8>) -> index {
1462+
// `extract_aligned_pointer_as_index` must not be folded as `memref.view` can change the base pointer
1463+
// CHECK-LABEL: func @negative_memref_view_extract_aligned_pointer
1464+
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xi8>)
1465+
// CHECK: %[[C10:.*]] = arith.constant 10 : index
1466+
// CHECK: %[[VIEW:.*]] = memref.view %[[ARG0]][%[[C10]]][] : memref<?xi8> to memref<f32>
1467+
// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[VIEW]] : memref<f32> -> index
1468+
// CHECK: return %[[PTR]] : index
1469+
1470+
%c10 = arith.constant 10 : index
1471+
%0 = memref.view %arg0[%c10][] : memref<?xi8> to memref<f32>
1472+
%1 = memref.extract_aligned_pointer_as_index %0: memref<f32> -> index
1473+
return %1 : index
1474+
}

0 commit comments

Comments
 (0)