From fe448a514831837fe236b733834cdf1eae477a6f Mon Sep 17 00:00:00 2001 From: Karlo Basioli <68535415+basioli-k@users.noreply.github.com> Date: Tue, 25 Mar 2025 14:41:34 +0000 Subject: [PATCH] =?UTF-8?q?Revert=20"[mlir][memref]=20Verify=20out-of-boun?= =?UTF-8?q?ds=20access=20for=20`memref.subview`=20(#131=E2=80=A6"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit d4304d85f26984aa772fcddf1f34604e526a6683. --- .../mlir/Dialect/MemRef/IR/MemRefOps.td | 133 ++++++++++++------ .../mlir/Interfaces/ViewLikeInterface.h | 17 ++- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 13 +- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 8 +- .../expand-then-convert-to-llvm.mlir | 18 +-- mlir/test/Dialect/Linalg/promote.mlir | 28 ++-- mlir/test/Dialect/MemRef/canonicalize.mlir | 22 +-- .../MemRef/expand-strided-metadata.mlir | 66 ++++----- .../Dialect/MemRef/fold-memref-alias-ops.mlir | 14 +- mlir/test/Dialect/MemRef/invalid.mlir | 16 --- mlir/test/Dialect/MemRef/subview.mlir | 2 +- mlir/test/Transforms/canonicalize.mlir | 72 +++++----- mlir/test/Transforms/compose-subview.mlir | 8 +- 13 files changed, 214 insertions(+), 203 deletions(-) diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 3edc2433c85ea..134cca5800918 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1859,11 +1859,11 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [ ]> { let summary = "memref subview operation"; let description = [{ - The `subview` operation converts a memref type to a memref type which - represents a reduced-size view of the original memref as specified by the - operation's offsets, sizes and strides arguments. + The "subview" operation converts a memref type to another memref type + which represents a reduced-size view of the original memref as specified by + the operation's offsets, sizes and strides arguments. - The `subview` operation supports the following arguments: + The SubView operation supports the following arguments: * source: the "base" memref on which to create a "view" memref. * offsets: memref-rank number of offsets into the "base" memref at which to @@ -1876,73 +1876,118 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [ The representation based on offsets, sizes and strides support a partially-static specification via attributes specified through the `static_offsets`, `static_sizes` and `static_strides` arguments. A special - sentinel value `ShapedType::kDynamic` encodes that the corresponding entry - has a dynamic value. + sentinel value ShapedType::kDynamic encodes that the corresponding entry has + a dynamic value. - A `subview` operation may additionally reduce the rank of the resulting - view by removing dimensions that are statically known to be of size 1. - - In the absence of rank reductions, the resulting memref type is computed - as follows: - ``` - result_sizes[i] = size_operands[i] - result_strides[i] = src_strides[i] * stride_operands[i] - result_offset = src_offset + dot_product(offset_operands, src_strides) - ``` - - The offset, size and stride operands must be in-bounds with respect to the - source memref. When possible, the static operation verifier will detect - out-of-bounds subviews. Subviews that cannot be confirmed to be in-bounds - or out-of-bounds based on compile-time information are valid. 
However,
-    performing an out-of-bounds subview at runtime is undefined behavior.
+    A subview operation may additionally reduce the rank of the resulting view
+    by removing dimensions that are statically known to be of size 1.
 
     Example 1:
 
     ```mlir
-    // Subview of static memref with strided layout at static offsets, sizes
-    // and strides.
-    %1 = memref.subview %0[4, 2][8, 2][3, 2]
-        : memref<64x4xf32, strided<[7, 9], offset: 91>> to
-          memref<8x2xf32, strided<[21, 18], offset: 137>>
+    %0 = memref.alloc() : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>
+
+    // Create a sub-view of "base" memref '%0' with offset arguments '%c0',
+    // dynamic sizes for each dimension, and stride arguments '%c1'.
+    %1 = memref.subview %0[%c0, %c0][%size0, %size1][%c1, %c1]
+      : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to
+        memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + d1 + s0)>>
     ```
 
     Example 2:
 
     ```mlir
-    // Subview of static memref with identity layout at dynamic offsets, sizes
+    %0 = memref.alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>
+
+    // Create a sub-view of "base" memref '%0' with dynamic offsets, sizes,
     // and strides.
-    %1 = memref.subview %0[%off0, %off1][%sz0, %sz1][%str0, %str1]
-        : memref<64x4xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+    // Note that dynamic offsets are represented by the linearized dynamic
+    // offset symbol 's0' in the subview memref layout map, and that the
+    // dynamic strides operands, after being applied to the base memref
+    // strides in each dimension, are represented in the view memref layout
+    // map as symbols 's1', 's2' and 's3'.
+    %1 = memref.subview %0[%i, %j, %k][%size0, %size1, %size2][%x, %y, %z]
+      : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
+        memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
     ```
 
     Example 3:
 
     ```mlir
-    // Subview of dynamic memref with strided layout at dynamic offsets and
-    // strides, but static sizes.
-    %1 = memref.subview %0[%off0, %off1][4, 4][%str0, %str1]
-        : memref<?x?xf32, strided<[?, ?], offset: ?>> to
-          memref<4x4xf32, strided<[?, ?], offset: ?>>
+    %0 = memref.alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>
+
+    // Subview with constant offsets, sizes and strides.
+    %1 = memref.subview %0[0, 2, 0][4, 4, 4][1, 1, 1]
+      : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
+        memref<4x4x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>>
     ```
 
     Example 4:
 
     ```mlir
-    // Rank-reducing subviews.
-    %1 = memref.subview %0[0, 0, 0][1, 16, 4][1, 1, 1]
-        : memref<8x16x4xf32> to memref<16x4xf32>
-    %3 = memref.subview %2[3, 4, 2][1, 6, 3][1, 1, 1]
-        : memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>>
+    %0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
+
+    // Subview with constant size, but dynamic offsets and
+    // strides. The resulting memref has a static shape, but if the
+    // base memref has an affine map to describe the layout, the result
+    // memref also uses an affine map to describe the layout. The
+    // strides of the result memref is computed as follows:
+    //
+    // Let #map1 represents the layout of the base memref, and #map2
+    // represents the layout of the result memref. A #mapsubview can be
+    // constructed to map an index from the result memref to the base
+    // memref (note that the description below uses more convenient
+    // naming for symbols, while in affine maps, symbols are
+    // represented as unsigned numbers that identify that symbol in the
+    // given affine map.
+    //
+    // #mapsubview = (d0, d1)[o0, o1, t0, t1] -> (d0 * t0 + o0, d1 * t1 + o1)
+    //
+    // where, o0, o1, ... are offsets, and t0, t1, ... are strides. Then,
+    //
+    // #map2 = #map1.compose(#mapsubview)
+    //
+    // If the layout map is represented as
+    //
+    // #map1 = (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)
+    //
+    // then,
+    //
+    // #map2 = (d0, d1)[s0, s1, s2, o0, o1, t0, t1] ->
+    //           (d0 * s1 * t0 + d1 * s2 * t1 + o0 * s1 + o1 * s2 + s0)
+    //
+    // Representing this canonically
+    //
+    // #map2 = (d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0)
+    //
+    // where, r0 = o0 * s1 + o1 * s2 + s0, r1 = s1 * t0, r2 = s2 * t1.
+    %1 = memref.subview %0[%i, %j][4, 4][%x, %y] :
+      : memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>> to
+        memref<4x4xf32, affine_map<(d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0)>>
+
+    // Note that the subview op does not guarantee that the result
+    // memref is "inbounds" w.r.t to base memref. It is upto the client
+    // to ensure that the subview is accessed in a manner that is
+    // in-bounds.
     ```
 
     Example 5:
 
     ```mlir
-    // Identity subview. The subview is the full source memref.
-    %1 = memref.subview %0[0, 0, 0] [8, 16, 4] [1, 1, 1]
-        : memref<8x16x4xf32> to memref<8x16x4xf32>
+    // Rank-reducing subview.
+    %1 = memref.subview %0[0, 0, 0][1, 16, 4][1, 1, 1] :
+      memref<8x16x4xf32> to memref<16x4xf32>
+
+    // Original layout:
+    // (d0, d1, d2) -> (64 * d0 + 16 * d1 + d2)
+    // Subviewed layout:
+    // (d0, d1, d2) -> (64 * (d0 + 3) + 4 * (d1 + 4) + d2 + 2) = (64 * d0 + 4 * d1 + d2 + 210)
+    // After rank reducing:
+    // (d0, d1) -> (4 * d0 + d1 + 210)
+    %3 = memref.subview %2[3, 4, 2][1, 6, 3][1, 1, 1] :
+      memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>>
     ```
-
   }];
 
   let arguments = (ins AnyMemRef:$source,
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 14427a97a5502..e74326dba7c80 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -76,7 +76,8 @@ SliceBoundsVerificationResult verifyInBoundsSlice(
 /// returns the new result type of the op, based on the new offsets, sizes and
 /// strides. `CastOpFunc` is used to generate a cast op if the result type of
 /// the op has changed.
-template <typename OpType, typename ResultTypeFn, typename CastOpFunc>
+template <typename OpType, typename ResultTypeFn, typename CastOpFunc,
+          bool CheckInBounds = false>
 class OpWithOffsetSizesAndStridesConstantArgumentFolder final
     : public OpRewritePattern<OpType> {
 public:
@@ -94,12 +95,14 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
         failed(foldDynamicIndexList(mixedStrides)))
       return failure();
 
-    // Pattern does not apply if the produced op would not verify.
-    SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
-        cast<ShapedType>(op.getSource().getType()).getShape(), mixedOffsets,
-        mixedSizes, mixedStrides);
-    if (!sliceResult.isValid)
-      return failure();
+    if (CheckInBounds) {
+      // Pattern does not apply if the produced op would not verify.
+      SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
+          cast<ShapedType>(op.getSource().getType()).getShape(), mixedOffsets,
+          mixedSizes, mixedStrides);
+      if (!sliceResult.isValid)
+        return failure();
+    }
 
     // Compute the new result type.
     auto resultType =
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 123666848f83a..59434dccc117b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2977,9 +2977,6 @@ static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
 LogicalResult SubViewOp::verify() {
   MemRefType baseType = getSourceType();
   MemRefType subViewType = getType();
-  ArrayRef<int64_t> staticOffsets = getStaticOffsets();
-  ArrayRef<int64_t> staticSizes = getStaticSizes();
-  ArrayRef<int64_t> staticStrides = getStaticStrides();
 
   // The base memref and the view memref should be in the same memory space.
   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
@@ -2994,7 +2991,7 @@ LogicalResult SubViewOp::verify() {
   // Compute the expected result type, assuming that there are no rank
   // reductions.
   MemRefType expectedType = SubViewOp::inferResultType(
-      baseType, staticOffsets, staticSizes, staticStrides);
+      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
 
   // Verify all properties of a shaped type: rank, element type and dimension
   // sizes. This takes into account potential rank reductions.
@@ -3028,14 +3025,6 @@ LogicalResult SubViewOp::verify() {
     return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                   *this, expectedType);
 
-  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
-  // to the base memref.
-  SliceBoundsVerificationResult boundsResult =
-      verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
-                          staticStrides, /*generateErrorMessage=*/true);
-  if (!boundsResult.isValid)
-    return getOperation()->emitError(boundsResult.errorMessage);
-
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d589f627d896e..5f8493de991f3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2617,10 +2617,10 @@ struct SliceCanonicalizer {
 
 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results.add<
-      OpWithOffsetSizesAndStridesConstantArgumentFolder<
-          ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
-      ExtractSliceOpCastFolder>(context);
+  results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
+                  ExtractSliceOp, SliceReturnTypeCanonicalizer,
+                  SliceCanonicalizer, /*CheckInBounds=*/true>,
+              ExtractSliceOpCastFolder>(context);
 }
 
 //
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index fe91d26d5a251..5517eafb588e8 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -192,7 +192,7 @@ func.func @subview_const_stride(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>
 
 // CHECK-LABEL: func @subview_const_stride_and_offset(
 // CHECK-SAME: %[[MEM:.*]]: memref<{{.*}}>
-func.func @subview_const_stride_and_offset(%0 : memref<64x8xf32, strided<[8, 1], offset: 0>>) -> memref<62x3xf32, strided<[8, 1], offset: 2>> {
+func.func @subview_const_stride_and_offset(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>) -> memref<62x3xf32, strided<[4, 1], offset: 8>> {
   // The last "insertvalue" that populates the memref descriptor from the function arguments.
// CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] @@ -201,21 +201,21 @@ func.func @subview_const_stride_and_offset(%0 : memref<64x8xf32, strided<[8, 1], // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_ALIGNED]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[CST_OFF:.*]] = llvm.mlir.constant(2 : index) : i64 + // CHECK: %[[CST_OFF:.*]] = llvm.mlir.constant(8 : index) : i64 // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[CST_OFF]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[CST_SIZE0:.*]] = llvm.mlir.constant(62 : index) : i64 // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[CST_SIZE0]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[CST_STRIDE0:.*]] = llvm.mlir.constant(8 : index) : i64 + // CHECK: %[[CST_STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64 // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[CST_STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[CST_SIZE1:.*]] = llvm.mlir.constant(3 : index) : i64 // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[CST_SIZE1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[CST_STRIDE1:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[CST_STRIDE1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %1 = memref.subview %0[0, 2][62, 3][1, 1] : - memref<64x8xf32, strided<[8, 1], offset: 0>> - to memref<62x3xf32, strided<[8, 1], offset: 2>> - return %1 : memref<62x3xf32, strided<[8, 1], offset: 2>> + %1 = memref.subview %0[0, 8][62, 3][1, 1] : + memref<64x4xf32, strided<[4, 1], offset: 0>> + to memref<62x3xf32, strided<[4, 1], offset: 8>> + return %1 : memref<62x3xf32, strided<[4, 1], offset: 8>> } // ----- @@ -238,7 +238,7 @@ func.func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, strided<[4, 1], of // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index // CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64 // CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG1]], %[[STRIDE0]] overflow : i64 - // CHECK: %[[BASE_OFF:.*]] = llvm.mlir.constant(2 : index) : i64 + // CHECK: %[[BASE_OFF:.*]] = llvm.mlir.constant(8 : index) : i64 // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[BASE_OFF]] : i64 // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index // CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64 @@ -253,7 +253,7 @@ func.func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, strided<[4, 1], of // CHECK: %[[CST_STRIDE1:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[CST_STRIDE1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %1 = memref.subview %0[%arg1, 2][62, %arg2][%arg0, 1] : + %1 = memref.subview %0[%arg1, 8][62, %arg2][%arg0, 1] : memref<64x4xf32, strided<[4, 1], offset: 0>> to memref<62x?xf32, strided<[?, 1], offset: ?>> return %1 : memref<62x?xf32, strided<[?, 1], offset: ?>> diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir index 
bab606c3a8169..00b8c649b82c3 100644 --- a/mlir/test/Dialect/Linalg/promote.mlir +++ b/mlir/test/Dialect/Linalg/promote.mlir @@ -287,18 +287,18 @@ module attributes {transform.with_named_sequence} { #map = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func.func @linalg_generic_update_all_function_inputs_outputs( - // CHECK-SAME: %[[VAL_0:.*]]: memref<8x4xf32, 1>, - // CHECK-SAME: %[[VAL_1:.*]]: memref<8x4xf32, 1>) -> memref<8x4xf32, 1> { -func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<8x4xf32, 1>, %arg1: memref<8x4xf32, 1>) -> memref<8x4xf32, 1> { - // CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x4xf32, 1> - // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]][0, 0] [4, 3] [1, 1] : memref<8x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> - // CHECK: %[[VAL_4:.*]] = memref.subview %[[VAL_1]][0, 0] [4, 3] [1, 1] : memref<8x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> - // CHECK: %[[VAL_5:.*]] = memref.subview %[[VAL_2]][0, 0] [4, 3] [1, 1] : memref<8x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> - - %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4xf32, 1> - %subview = memref.subview %arg0[0, 0] [4, 3] [1, 1] : memref<8x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> - %subview_0 = memref.subview %arg1[0, 0] [4, 3] [1, 1] : memref<8x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> - %subview_1 = memref.subview %alloc[0, 0] [4, 3] [1, 1] : memref<8x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> + // CHECK-SAME: %[[VAL_0:.*]]: memref<3x4xf32, 1>, + // CHECK-SAME: %[[VAL_1:.*]]: memref<3x4xf32, 1>) -> memref<3x4xf32, 1> { +func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf32, 1>, %arg1: memref<3x4xf32, 1>) -> memref<3x4xf32, 1> { + // CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3x4xf32, 1> + // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]][0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> + // CHECK: %[[VAL_4:.*]] = memref.subview %[[VAL_1]][0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> + // CHECK: %[[VAL_5:.*]] = memref.subview %[[VAL_2]][0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> + + %alloc = memref.alloc() {alignment = 64 : i64} : memref<3x4xf32, 1> + %subview = memref.subview %arg0[0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> + %subview_0 = memref.subview %arg1[0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> + %subview_1 = memref.subview %alloc[0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> // CHECK: %[[VAL_6:.*]] = arith.constant 0 : index // CHECK: %[[VAL_7:.*]] = arith.constant 4 : index @@ -376,10 +376,10 @@ func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<8x4xf // CHECK: memref.dealloc %[[VAL_22]] : memref<48xi8, #gpu.address_space> // CHECK: memref.dealloc %[[VAL_41]] : memref<48xi8, #gpu.address_space> // CHECK: memref.dealloc %[[VAL_60]] : memref<48xi8, #gpu.address_space> - // CHECK: return %[[VAL_2]] : memref<8x4xf32, 1> + // CHECK: return %[[VAL_2]] : memref<3x4xf32, 1> // CHECK: } - return %alloc : memref<8x4xf32, 1> + return %alloc : memref<3x4xf32, 1> } diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 5d8a7d3f64e8f..02110bc2892d0 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -635,9 +635,9 @@ func.func 
@fold_no_op_subview(%arg0 : memref<20x42xf32>) -> memref<20x42xf32, st // ----- -func.func @no_fold_subview_with_non_zero_offset(%arg0 : memref<20x42xf32>) -> memref<20x41xf32, strided<[42, 1], offset: 1>> { - %0 = memref.subview %arg0[0, 1] [20, 41] [1, 1] : memref<20x42xf32> to memref<20x41xf32, strided<[42, 1], offset: 1>> - return %0 : memref<20x41xf32, strided<[42, 1], offset: 1>> +func.func @no_fold_subview_with_non_zero_offset(%arg0 : memref<20x42xf32>) -> memref<20x42xf32, strided<[42, 1], offset: 1>> { + %0 = memref.subview %arg0[0, 1] [20, 42] [1, 1] : memref<20x42xf32> to memref<20x42xf32, strided<[42, 1], offset: 1>> + return %0 : memref<20x42xf32, strided<[42, 1], offset: 1>> } // CHECK-LABEL: func @no_fold_subview_with_non_zero_offset( // CHECK: %[[SUBVIEW:.+]] = memref.subview @@ -645,9 +645,9 @@ func.func @no_fold_subview_with_non_zero_offset(%arg0 : memref<20x42xf32>) -> me // ----- -func.func @no_fold_subview_with_non_unit_stride(%arg0 : memref<20x42xf32>) -> memref<20x5xf32, strided<[42, 2]>> { - %0 = memref.subview %arg0[0, 0] [20, 5] [1, 2] : memref<20x42xf32> to memref<20x5xf32, strided<[42, 2]>> - return %0 : memref<20x5xf32, strided<[42, 2]>> +func.func @no_fold_subview_with_non_unit_stride(%arg0 : memref<20x42xf32>) -> memref<20x42xf32, strided<[42, 2]>> { + %0 = memref.subview %arg0[0, 0] [20, 42] [1, 2] : memref<20x42xf32> to memref<20x42xf32, strided<[42, 2]>> + return %0 : memref<20x42xf32, strided<[42, 2]>> } // CHECK-LABEL: func @no_fold_subview_with_non_unit_stride( // CHECK: %[[SUBVIEW:.+]] = memref.subview @@ -655,16 +655,6 @@ func.func @no_fold_subview_with_non_unit_stride(%arg0 : memref<20x42xf32>) -> me // ----- -// CHECK-LABEL: func @no_fold_invalid_dynamic_slice -// CHECK: memref.subview %arg0[2] [%{{.*}}] [1] : memref<10xf32> to memref> -func.func @no_fold_invalid_dynamic_slice(%arg0: memref<10xf32>) -> memref> { - %c11 = arith.constant 11 : index - %0 = memref.subview %arg0 [2][%c11][1] : memref<10xf32> to memref> - func.return %0 : memref> -} - -// ----- - func.func @no_fold_dynamic_no_op_subview(%arg0 : memref) -> memref> { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir index 1e6b0111fa4c7..647731db439c0 100644 --- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir @@ -119,39 +119,39 @@ func.func @extract_strided_metadata_of_subview(%base: memref<5x4xf32>) // when dynamic sizes are involved. // See extract_strided_metadata_of_subview for an explanation of the actual // expansion. 
-// Orig strides: [384, 24, 1] +// Orig strides: [64, 4, 1] // Sub strides: [1, 1, 1] -// => New strides: [384, 24, 1] +// => New strides: [64, 4, 1] // // Orig offset: 0 // Sub offsets: [3, 4, 2] -// => Final offset: 3 * 384 + 4 * 24 + 2 * 1 + 0 == 1250 +// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210 // // Final sizes == subview sizes == [%size, 6, 3] // // CHECK-LABEL: func @extract_strided_metadata_of_subview_with_dynamic_size -// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x24xf32>, +// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>, // CHECK-SAME: %[[DYN_SIZE:.*]]: index) // -// CHECK-DAG: %[[C1250:.*]] = arith.constant 1250 : index -// CHECK-DAG: %[[C384:.*]] = arith.constant 384 : index +// CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index -// CHECK-DAG: %[[C24:.*]] = arith.constant 24 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // // CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]] // -// CHECK: return %[[BASE]], %[[C1250]], %[[DYN_SIZE]], %[[C6]], %[[C3]], %[[C384]], %[[C24]], %[[C1]] +// CHECK: return %[[BASE]], %[[C210]], %[[DYN_SIZE]], %[[C6]], %[[C3]], %[[C64]], %[[C4]], %[[C1]] func.func @extract_strided_metadata_of_subview_with_dynamic_size( - %base: memref<8x16x24xf32>, %size: index) + %base: memref<8x16x4xf32>, %size: index) -> (memref, index, index, index, index, index, index, index) { %subview = memref.subview %base[3, 4, 2][%size, 6, 3][1, 1, 1] : - memref<8x16x24xf32> to memref> + memref<8x16x4xf32> to memref> %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : - memref> + memref> -> memref, index, index, index, index, index, index, index return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 : @@ -167,37 +167,37 @@ func.func @extract_strided_metadata_of_subview_with_dynamic_size( // See extract_strided_metadata_of_subview for an explanation of the actual // expansion. 
// -// Orig strides: [384, 24, 1] +// Orig strides: [64, 4, 1] // Sub strides: [1, 1, 1] -// => New strides: [384, 24, 1] -// Final strides == filterOutReducedDim(new strides, 0) == [24 , 1] +// => New strides: [64, 4, 1] +// Final strides == filterOutReducedDim(new strides, 0) == [4 , 1] // // Orig offset: 0 // Sub offsets: [3, 4, 2] -// => Final offset: 3 * 384 + 4 * 24 + 2 * 1 + 0 == 1250 +// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210 // // Final sizes == filterOutReducedDim(subview sizes, 0) == [6, 3] // // CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview -// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x24xf32>) +// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>) // -// CHECK-DAG: %[[C1250:.*]] = arith.constant 1250 : index +// CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index -// CHECK-DAG: %[[C24:.*]] = arith.constant 24 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // // CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]] // -// CHECK: return %[[BASE]], %[[C1250]], %[[C6]], %[[C3]], %[[C24]], %[[C1]] -func.func @extract_strided_metadata_of_rank_reduced_subview(%base: memref<8x16x24xf32>) +// CHECK: return %[[BASE]], %[[C210]], %[[C6]], %[[C3]], %[[C4]], %[[C1]] +func.func @extract_strided_metadata_of_rank_reduced_subview(%base: memref<8x16x4xf32>) -> (memref, index, index, index, index, index) { %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, 1, 1] : - memref<8x16x24xf32> to memref<6x3xf32, strided<[24, 1], offset: 1250>> + memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>> %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview : - memref<6x3xf32, strided<[24, 1], offset: 1250>> + memref<6x3xf32, strided<[4,1], offset: 210>> -> memref, index, index, index, index, index return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : @@ -215,21 +215,21 @@ func.func @extract_strided_metadata_of_rank_reduced_subview(%base: memref<8x16x2 // See extract_strided_metadata_of_subview for an explanation of the actual // expansion. 
// -// Orig strides: [384, 24, 1] +// Orig strides: [64, 4, 1] // Sub strides: [1, %stride, 1] -// => New strides: [384, 24 * %stride, 1] -// Final strides == filterOutReducedDim(new strides, 0) == [24 * %stride , 1] +// => New strides: [64, 4 * %stride, 1] +// Final strides == filterOutReducedDim(new strides, 0) == [4 * %stride , 1] // // Orig offset: 0 // Sub offsets: [3, 4, 2] -// => Final offset: 3 * 384 + 4 * 24 + 2 * 1 + 0 == 1250 +// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210 // -// CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0] -> (s0 * 24)> +// CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0] -> (s0 * 4)> // CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides -// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x24xf32>, +// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>, // CHECK-SAME: %[[DYN_STRIDE:.*]]: index) // -// CHECK-DAG: %[[C1250:.*]] = arith.constant 1250 : index +// CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index @@ -238,16 +238,16 @@ func.func @extract_strided_metadata_of_rank_reduced_subview(%base: memref<8x16x2 // // CHECK-DAG: %[[DIM1_STRIDE:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_STRIDE]]] // -// CHECK: return %[[BASE]], %[[C1250]], %[[C6]], %[[C3]], %[[DIM1_STRIDE]], %[[C1]] +// CHECK: return %[[BASE]], %[[C210]], %[[C6]], %[[C3]], %[[DIM1_STRIDE]], %[[C1]] func.func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides( - %base: memref<8x16x24xf32>, %stride: index) + %base: memref<8x16x4xf32>, %stride: index) -> (memref, index, index, index, index, index) { %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, %stride, 1] : - memref<8x16x24xf32> to memref<6x3xf32, strided<[?, 1], offset: 1250>> + memref<8x16x4xf32> to memref<6x3xf32, strided<[?, 1], offset: 210>> %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview : - memref<6x3xf32, strided<[?, 1], offset: 1250>> + memref<6x3xf32, strided<[?, 1], offset: 210>> -> memref, index, index, index, index, index return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir index 067cdb5c5fd20..327cacf7d9a20 100644 --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -632,17 +632,17 @@ func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape_with // CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 2)> // CHECK-LABEL: func @subview_of_subview( -// CHECK-SAME: %[[m:.*]]: memref<8x1024xf32, 3>, %[[pos:.*]]: index +// CHECK-SAME: %[[m:.*]]: memref<1x1024xf32, 3>, %[[pos:.*]]: index // CHECK: %[[add:.*]] = affine.apply #[[$map]]()[%arg1] -// CHECK: memref.subview %arg0[4, %[[add]]] [1, 1] [1, 1] : memref<8x1024xf32, 3> to memref, 3> -func.func @subview_of_subview(%m: memref<8x1024xf32, 3>, %pos: index) +// CHECK: memref.subview %arg0[4, %[[add]]] [1, 1] [1, 1] : memref<1x1024xf32, 3> to memref, 3> +func.func @subview_of_subview(%m: memref<1x1024xf32, 3>, %pos: index) -> memref, 3> { - %0 = memref.subview %m[3, %pos] [5, 7] [1, 1] - : memref<8x1024xf32, 3> - to memref<5x7xf32, strided<[1024, 1], offset: ?>, 3> + %0 = memref.subview %m[3, %pos] [1, 2] [1, 1] + : memref<1x1024xf32, 3> + to memref<1x2xf32, strided<[1024, 1], offset: ?>, 3> %1 = memref.subview %0[1, 
2] [1, 1] [1, 1] - : memref<5x7xf32, strided<[1024, 1], offset: ?>, 3> + : memref<1x2xf32, strided<[1024, 1], offset: ?>, 3> to memref, 3> return %1 : memref, 3> } diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir index 34fc4775924e7..f72ad48245f81 100644 --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -723,22 +723,6 @@ func.func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { // ----- -func.func @invalid_subview(%arg0: memref<10xf32>) { - // expected-error@+1 {{offset 0 is out-of-bounds: 10 >= 10}} - %0 = memref.subview %arg0 [10][1][1] : memref<10xf32> to memref<1xf32, strided<[1], offset: 10>> - return -} - -// ----- - -func.func @invalid_subview(%arg0: memref<9xf32>) { - // expected-error@+1 {{slice along dimension 0 runs out-of-bounds: 9 >= 9}} - %0 = memref.subview %arg0 [3][4][2] : memref<9xf32> to memref<4xf32, strided<[2], offset: 3>> - return -} - -// ----- - func.func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = memref.alloc() : memref<8x16x4xf32> // expected-error@+1 {{expected result type to be 'memref<8x16x4xf32, strided<[64, 4, 1]>>' or a rank-reduced version. (mismatch of result sizes)}} diff --git a/mlir/test/Dialect/MemRef/subview.mlir b/mlir/test/Dialect/MemRef/subview.mlir index fd8aaaf86b2d8..135a1124066e4 100644 --- a/mlir/test/Dialect/MemRef/subview.mlir +++ b/mlir/test/Dialect/MemRef/subview.mlir @@ -90,7 +90,7 @@ func.func @memref_subview(%arg0 : index, %arg1 : index, %arg2 : index) { // CHECK: memref.subview %{{.*}}[0, 0, 0] [1, 16, 4] [1, 1, 1] : memref<8x16x4xf32> to memref<16x4xf32> %21 = memref.subview %20[0, 0, 0][1, 16, 4][1, 1, 1] : memref<8x16x4xf32> to memref<16x4xf32> - %22 = memref.subview %20[3, 4, 1][1, 6, 3][1, 1, 1] : memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 209>> + %22 = memref.subview %20[3, 4, 2][1, 6, 3][1, 1, 1] : memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>> %23 = memref.alloc() : memref %78 = memref.subview %23[] [] [] : memref to memref diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 8e02c06a0a293..9b74362b6ee75 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -752,83 +752,83 @@ func.func @subview(%arg0 : index, %arg1 : index) -> (index, index) { %c15 = arith.constant 15 : index // CHECK: %[[ALLOC0:.*]] = memref.alloc() - %0 = memref.alloc() : memref<128x96x64xf32, strided<[6144, 64, 1], offset: 0>> + %0 = memref.alloc() : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> // Test: subview with constant base memref and constant operands is folded. // Note that the subview uses the base memrefs layout map because it used // zero offset and unit stride arguments. // CHECK: memref.subview %[[ALLOC0]][0, 0, 0] [7, 11, 2] [1, 1, 1] : - // CHECK-SAME: memref<128x96x64xf32, strided<[6144, 64, 1]>> - // CHECK-SAME: to memref<7x11x2xf32, strided<[6144, 64, 1]>> + // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> + // CHECK-SAME: to memref<7x11x2xf32, strided<[64, 4, 1]>> %1 = memref.subview %0[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1] - : memref<128x96x64xf32, strided<[6144, 64, 1], offset: 0>> to + : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to memref> %v0 = memref.load %1[%c0, %c0, %c0] : memref> // Test: subview with one dynamic operand can also be folded. 
// CHECK: memref.subview %[[ALLOC0]][0, %[[ARG0]], 0] [7, 11, 15] [1, 1, 1] : - // CHECK-SAME: memref<128x96x64xf32, strided<[6144, 64, 1]>> - // CHECK-SAME: to memref<7x11x15xf32, strided<[6144, 64, 1], offset: ?>> + // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> + // CHECK-SAME: to memref<7x11x15xf32, strided<[64, 4, 1], offset: ?>> %2 = memref.subview %0[%c0, %arg0, %c0] [%c7, %c11, %c15] [%c1, %c1, %c1] - : memref<128x96x64xf32, strided<[6144, 64, 1], offset: 0>> to + : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to memref> memref.store %v0, %2[%c0, %c0, %c0] : memref> // CHECK: %[[ALLOC1:.*]] = memref.alloc(%[[ARG0]]) %3 = memref.alloc(%arg0) : memref> // Test: subview with constant operands but dynamic base memref is folded as long as the strides and offset of the base memref are static. - // CHECK: memref.subview %[[ALLOC1]][0, 0, 0] [7, 11, 2] [1, 1, 1] : + // CHECK: memref.subview %[[ALLOC1]][0, 0, 0] [7, 11, 15] [1, 1, 1] : // CHECK-SAME: memref> - // CHECK-SAME: to memref<7x11x2xf32, strided<[64, 4, 1]>> - %4 = memref.subview %3[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1] + // CHECK-SAME: to memref<7x11x15xf32, strided<[64, 4, 1]>> + %4 = memref.subview %3[%c0, %c0, %c0] [%c7, %c11, %c15] [%c1, %c1, %c1] : memref> to memref> memref.store %v0, %4[%c0, %c0, %c0] : memref> // Test: subview offset operands are folded correctly w.r.t. base strides. // CHECK: memref.subview %[[ALLOC0]][1, 2, 7] [7, 11, 2] [1, 1, 1] : - // CHECK-SAME: memref<128x96x64xf32, strided<[6144, 64, 1]>> to - // CHECK-SAME: memref<7x11x2xf32, strided<[6144, 64, 1], offset: 6279>> + // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> to + // CHECK-SAME: memref<7x11x2xf32, strided<[64, 4, 1], offset: 79>> %5 = memref.subview %0[%c1, %c2, %c7] [%c7, %c11, %c2] [%c1, %c1, %c1] - : memref<128x96x64xf32, strided<[6144, 64, 1], offset: 0>> to + : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to memref> memref.store %v0, %5[%c0, %c0, %c0] : memref> // Test: subview stride operands are folded correctly w.r.t. base strides. 
// CHECK: memref.subview %[[ALLOC0]][0, 0, 0] [7, 11, 2] [2, 7, 11] : - // CHECK-SAME: memref<128x96x64xf32, strided<[6144, 64, 1]>> - // CHECK-SAME: to memref<7x11x2xf32, strided<[12288, 448, 11]>> + // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> + // CHECK-SAME: to memref<7x11x2xf32, strided<[128, 28, 11]>> %6 = memref.subview %0[%c0, %c0, %c0] [%c7, %c11, %c2] [%c2, %c7, %c11] - : memref<128x96x64xf32, strided<[6144, 64, 1], offset: 0>> to + : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to memref> memref.store %v0, %6[%c0, %c0, %c0] : memref> // Test: subview shape are folded, but offsets and strides are not even if base memref is static // CHECK: memref.subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [7, 11, 2] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : - // CHECK-SAME: memref<128x96x64xf32, strided<[6144, 64, 1]>> to + // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> to // CHECK-SAME: memref<7x11x2xf32, strided<[?, ?, ?], offset: ?>> %10 = memref.subview %0[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] : - memref<128x96x64xf32, strided<[6144, 64, 1], offset: 0>> to + memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to memref> memref.store %v0, %10[%arg1, %arg1, %arg1] : memref> // Test: subview strides are folded, but offsets and shape are not even if base memref is static // CHECK: memref.subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [2, 7, 11] : - // CHECK-SAME: memref<128x96x64xf32, strided<[6144, 64, 1]>> to - // CHECK-SAME: memref> + // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> to + // CHECK-SAME: memref> %11 = memref.subview %0[%arg0, %arg0, %arg0] [%arg1, %arg1, %arg1] [%c2, %c7, %c11] : - memref<128x96x64xf32, strided<[6144, 64, 1], offset: 0>> to + memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to memref> memref.store %v0, %11[%arg0, %arg0, %arg0] : memref> // Test: subview offsets are folded, but strides and shape are not even if base memref is static // CHECK: memref.subview %[[ALLOC0]][1, 2, 7] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] : - // CHECK-SAME: memref<128x96x64xf32, strided<[6144, 64, 1]>> to - // CHECK-SAME: memref> + // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> to + // CHECK-SAME: memref> %13 = memref.subview %0[%c1, %c2, %c7] [%arg1, %arg1, %arg1] [%arg0, %arg0, %arg0] : - memref<128x96x64xf32, strided<[6144, 64, 1], offset: 0>> to + memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to memref> memref.store %v0, %13[%arg1, %arg1, %arg1] : memref> @@ -862,27 +862,27 @@ func.func @subview(%arg0 : index, %arg1 : index) -> (index, index) { memref> memref.store %v0, %17[%arg0, %arg0, %arg0] : memref> - // CHECK: %[[ALLOC3:.*]] = memref.alloc() : memref<128x64xf32> - %18 = memref.alloc() : memref<128x64xf32> + // CHECK: %[[ALLOC3:.*]] = memref.alloc() : memref<12x4xf32> + %18 = memref.alloc() : memref<12x4xf32> %c4 = arith.constant 4 : index // TEST: subview strides are maintained when sizes are folded // CHECK: memref.subview %[[ALLOC3]][%arg1, %arg1] [2, 4] [1, 1] : - // CHECK-SAME: memref<128x64xf32> to - // CHECK-SAME: memref<2x4xf32, strided<[64, 1], offset: ?> + // CHECK-SAME: memref<12x4xf32> to + // CHECK-SAME: memref<2x4xf32, strided<[4, 1], offset: ?> %19 = memref.subview %18[%arg1, %arg1] [%c2, %c4] [1, 1] : - memref<128x64xf32> to - memref> - memref.store %v0, %19[%arg1, %arg1] : memref> + memref<12x4xf32> to + memref> + memref.store %v0, %19[%arg1, %arg1] : memref> // TEST: subview strides and sizes are maintained when offsets 
are folded // CHECK: memref.subview %[[ALLOC3]][2, 4] [12, 4] [1, 1] : - // CHECK-SAME: memref<128x64xf32> to - // CHECK-SAME: memref<12x4xf32, strided<[64, 1], offset: 132>> + // CHECK-SAME: memref<12x4xf32> to + // CHECK-SAME: memref<12x4xf32, strided<[4, 1], offset: 12>> %20 = memref.subview %18[%c2, %c4] [12, 4] [1, 1] : - memref<128x64xf32> to - memref<12x4xf32, strided<[64, 1], offset: ?>> - memref.store %v0, %20[%arg1, %arg1] : memref<12x4xf32, strided<[64, 1], offset: ?>> + memref<12x4xf32> to + memref<12x4xf32, strided<[4, 1], offset: ?>> + memref.store %v0, %20[%arg1, %arg1] : memref<12x4xf32, strided<[4, 1], offset: ?>> // Test: dim on subview is rewritten to size operand. %7 = memref.dim %4, %c0 : memref> diff --git a/mlir/test/Transforms/compose-subview.mlir b/mlir/test/Transforms/compose-subview.mlir index 53fbb8a356def..22ffd836c68ed 100644 --- a/mlir/test/Transforms/compose-subview.mlir +++ b/mlir/test/Transforms/compose-subview.mlir @@ -53,10 +53,10 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, stri // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<8x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> { -func.func @subview_strided(%input: memref<8x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> { - // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][4, 384] [1, 64] [4, 4] : memref<8x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>> - %0 = memref.subview %input[2, 256] [2, 256] [2, 2] : memref<8x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: 2304>> +// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> { +func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> { + // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][4, 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>> + %0 = memref.subview %input[2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: 2304>> %1 = memref.subview %0[1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: 2304>> to memref<1x64xf32, strided<[4096, 4], offset: 4480>> return %1 : memref<1x64xf32, strided<[4096, 4], offset: 4480>> }