Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,11 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
PatternRewriter &rewriter) const override {
auto viewLikeOp =
extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest())
// ViewLikeOpInterface by itself doesn't guarantee to preserve the base
// pointer in general and `memref.view` is one such example, so just check
// for a few specific cases.
if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest() ||
!isa<memref::SubViewOp, memref::ReinterpretCastOp>(viewLikeOp))
return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
rewriter.modifyOpInPlace(extractOp, [&]() {
extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
Expand Down
17 changes: 17 additions & 0 deletions mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1455,3 +1455,20 @@ func.func @extract_strided_metadata_of_memory_space_cast_no_base(%base: memref<2

// CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast_no_base
// CHECK-NOT: memref.memory_space_cast

// -----

func.func @negative_memref_view_extract_aligned_pointer(%arg0: memref<?xi8>) -> index {
// `extract_aligned_pointer_as_index` must not be folded as `memref.view` can change the base pointer
// CHECK-LABEL: func @negative_memref_view_extract_aligned_pointer
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xi8>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VIEW:.*]] = memref.view %[[ARG0]][%[[C0]]][] : memref<?xi8> to memref<f32>
// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[VIEW]] : memref<f32> -> index
// CHECK: return %[[PTR]] : index

%c0 = arith.constant 0 : index
%0 = memref.view %arg0[%c0][] : memref<?xi8> to memref<f32>
%1 = memref.extract_aligned_pointer_as_index %0: memref<f32> -> index
return %1 : index
}
Loading