diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 82702789c2913..3ccdcbf8c3be5 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -3105,7 +3105,7 @@ FailureOr SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc, /// is the case if the all offsets are zero, all strides are 1, and the source /// shape is same as the size of the subview. In such cases, the subview can /// be folded into its source. -static bool isTrivialSubViewOp(SubViewOp subViewOp) { +static bool isTrivialSubViewOp(OpBuilder &b, SubViewOp subViewOp) { if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank()) return false; @@ -3127,15 +3127,24 @@ static bool isTrivialSubViewOp(SubViewOp subViewOp) { })) return false; - // Check all size values are static and matches the (static) source shape. + // Check all size values match the source shape. ArrayRef sourceShape = subViewOp.getSourceType().getShape(); - for (const auto &size : llvm::enumerate(mixedSizes)) { - std::optional intValue = getConstantIntValue(size.value()); - if (!intValue || *intValue != sourceShape[size.index()]) - return false; + if (llvm::all_of_zip(mixedSizes, sourceShape, + [](OpFoldResult mixedSize, int64_t staticSize) { + std::optional constSize = + getConstantIntValue(mixedSize); + return constSize.has_value() && + *constSize == staticSize; + })) { + return true; } - // All conditions met. The `SubViewOp` is foldable as a no-op. - return true; + auto sourceOpResult = dyn_cast(subViewOp.getSource()); + if (!sourceOpResult) + return false; + ReifiedRankedShapedTypeDims resultDims; + if (failed(reifyResultShapes(b, sourceOpResult.getOwner(), resultDims))) + return false; + return llvm::equal(mixedSizes, resultDims[sourceOpResult.getResultNumber()]); } namespace { @@ -3206,7 +3215,7 @@ class TrivialSubViewOpFolder final : public OpRewritePattern { LogicalResult matchAndRewrite(SubViewOp subViewOp, PatternRewriter &rewriter) const override { - if (!isTrivialSubViewOp(subViewOp)) + if (!isTrivialSubViewOp(rewriter, subViewOp)) return failure(); if (subViewOp.getSourceType() == subViewOp.getType()) { rewriter.replaceOp(subViewOp, subViewOp.getSource()); diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index e7cee7cd85426..ebad9e3eab345 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -70,6 +70,20 @@ func.func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4 // ----- +// CHECK-LABEL: func @subview_of_dynamic_full_size +// CHECK-SAME: %[[ARG0:.+]]: memref +// CHECK-SAME: %[[SIZE:.+]]: index +// CHECK: %[[EXPAND_SHAPE:.+]] = memref.expand_shape +// CHECK-NOT: memref.subview +// CHECK: return %[[EXPAND_SHAPE]] : memref +func.func @subview_of_dynamic_full_size(%arg0 : memref, %size : index) -> memref { + %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [%size, %size] : memref into memref + %1 = memref.subview %0[0, 0] [%size, %size] [1, 1] : memref to memref + return %1 : memref +} + +// ----- + // CHECK-LABEL: func @negative_subview_of_static_full_size // CHECK-SAME: %[[ARG0:.+]]: memref<16x4xf32, strided<[4, 1], offset: ?>> // CHECK-SAME: %[[IDX:.+]]: index