File tree Expand file tree Collapse file tree 2 files changed +22
-1
lines changed
lib/Dialect/MemRef/Transforms Expand file tree Collapse file tree 2 files changed +22
-1
lines changed Original file line number Diff line number Diff line change @@ -959,7 +959,11 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
959959 PatternRewriter &rewriter) const override {
960960 auto viewLikeOp =
961961 extractOp.getSource ().getDefiningOp <ViewLikeOpInterface>();
962- if (!viewLikeOp || extractOp.getSource () != viewLikeOp.getViewDest ())
962+ // ViewLikeOpInterface by itself doesn't guarantee to preserve the base
963+ // pointer in general and `memref.view` is one such example, so just check
964+ // for a few specific cases.
965+ if (!viewLikeOp || extractOp.getSource () != viewLikeOp.getViewDest () ||
966+ !isa<memref::SubViewOp, memref::ReinterpretCastOp>(viewLikeOp))
963967 return rewriter.notifyMatchFailure (extractOp, " not a ViewLike source" );
964968 rewriter.modifyOpInPlace (extractOp, [&]() {
965969 extractOp.getSourceMutable ().assign (viewLikeOp.getViewSource ());
Original file line number Diff line number Diff line change @@ -1455,3 +1455,20 @@ func.func @extract_strided_metadata_of_memory_space_cast_no_base(%base: memref<2
14551455
14561456// CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast_no_base
14571457// CHECK-NOT: memref.memory_space_cast
1458+
1459+ // -----
1460+
1461+ func.func @negative_memref_view_extract_aligned_pointer (%arg0: memref <?xi8 >) -> index {
1462+ // `extract_aligned_pointer_as_index` must not be folded as `memref.view` can change the base pointer
1463+ // CHECK-LABEL: func @negative_memref_view_extract_aligned_pointer
1464+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xi8>)
1465+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
1466+ // CHECK: %[[VIEW:.*]] = memref.view %[[ARG0]][%[[C0]]][] : memref<?xi8> to memref<f32>
1467+ // CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[VIEW]] : memref<f32> -> index
1468+ // CHECK: return %[[PTR]] : index
1469+
1470+ %c0 = arith.constant 0 : index
1471+ %0 = memref.view %arg0 [%c0 ][] : memref <?xi8 > to memref <f32 >
1472+ %1 = memref.extract_aligned_pointer_as_index %0: memref <f32 > -> index
1473+ return %1 : index
1474+ }
You can’t perform that action at this time.
0 commit comments