diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index b0fb5b0785142..09bb3932ef293 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -2333,6 +2333,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [ let hasCanonicalizer = 1; let hasVerifier = 1; + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 3c4d2562e6999..4ce251d6dd224 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -3463,6 +3463,16 @@ LogicalResult ViewOp::verify() { Value ViewOp::getViewSource() { return getSource(); } +OpFoldResult ViewOp::fold(FoldAdaptor adaptor) { + MemRefType sourceMemrefType = getSource().getType(); + MemRefType resultMemrefType = getResult().getType(); + + if (resultMemrefType == sourceMemrefType && resultMemrefType.hasStaticShape()) + return getViewSource(); + + return {}; +} + namespace { struct ViewOpShapeFolder : public OpRewritePattern { diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index a91e54a126100..16b7a5c8bcb08 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -1208,3 +1208,24 @@ func.func @fold_assume_alignment_chain(%0: memref<128xf32>) -> memref<128xf32> { // CHECK: return %[[ALIGN]] return %2 : memref<128xf32> } + +// ----- + +// CHECK-LABEL: func @fold_view_same_source_result_types +func.func @fold_view_same_source_result_types(%0: memref<128xi8>) -> memref<128xi8> { + %c0 = arith.constant 0: index + // CHECK-NOT: memref.view + %res = memref.view %0[%c0][] : memref<128xi8> to memref<128xi8> + return %res : memref<128xi8> +} + +// ----- + +// CHECK-LABEL: func @non_fold_view_same_source_res_types +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +func.func @non_fold_view_same_source_res_types(%0: memref, %arg0 : index) -> memref { + %c0 = arith.constant 0: index + // CHECK: memref.view + %res = memref.view %0[%c0][%arg0] : memref to memref + return %res : memref +}