diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 134cca5800918..3edc2433c85ea 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1859,11 +1859,11 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
 ]> {
   let summary = "memref subview operation";
   let description = [{
-    The "subview" operation converts a memref type to another memref type
-    which represents a reduced-size view of the original memref as specified by
-    the operation's offsets, sizes and strides arguments.
+    The `subview` operation converts a memref type to a memref type which
+    represents a reduced-size view of the original memref as specified by the
+    operation's offsets, sizes and strides arguments.
 
-    The SubView operation supports the following arguments:
+    The `subview` operation supports the following arguments:
 
     * source: the "base" memref on which to create a "view" memref.
     * offsets: memref-rank number of offsets into the "base" memref at which to
@@ -1876,118 +1876,73 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     The representation based on offsets, sizes and strides support a
     partially-static specification via attributes specified through the
     `static_offsets`, `static_sizes` and `static_strides` arguments. A special
-    sentinel value ShapedType::kDynamic encodes that the corresponding entry has
-    a dynamic value.
+    sentinel value `ShapedType::kDynamic` encodes that the corresponding entry
+    has a dynamic value.
 
-    A subview operation may additionally reduce the rank of the resulting view
-    by removing dimensions that are statically known to be of size 1.
+    A `subview` operation may additionally reduce the rank of the resulting
+    view by removing dimensions that are statically known to be of size 1.
+
+    In the absence of rank reductions, the resulting memref type is computed
+    as follows:
+    ```
+    result_sizes[i] = size_operands[i]
+    result_strides[i] = src_strides[i] * stride_operands[i]
+    result_offset = src_offset + dot_product(offset_operands, src_strides)
+    ```
+
+    The offset, size and stride operands must be in-bounds with respect to the
+    source memref. When possible, the static operation verifier will detect
+    out-of-bounds subviews. Subviews that cannot be confirmed to be in-bounds
+    or out-of-bounds based on compile-time information are valid. However,
+    performing an out-of-bounds subview at runtime is undefined behavior.
 
     Example 1:
 
     ```mlir
-    %0 = memref.alloc() : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>
-
-    // Create a sub-view of "base" memref '%0' with offset arguments '%c0',
-    // dynamic sizes for each dimension, and stride arguments '%c1'.
-    %1 = memref.subview %0[%c0, %c0][%size0, %size1][%c1, %c1]
-      : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to
-        memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + d1 + s0)>>
+    // Subview of static memref with strided layout at static offsets, sizes
+    // and strides.
+    %1 = memref.subview %0[4, 2][8, 2][3, 2]
+        : memref<64x4xf32, strided<[7, 9], offset: 91>> to
+          memref<8x2xf32, strided<[21, 18], offset: 137>>
     ```
 
     Example 2:
 
     ```mlir
-    %0 = memref.alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>
-
-    // Create a sub-view of "base" memref '%0' with dynamic offsets, sizes,
+    // Subview of static memref with identity layout at dynamic offsets, sizes
     // and strides.
-    // Note that dynamic offsets are represented by the linearized dynamic
-    // offset symbol 's0' in the subview memref layout map, and that the
-    // dynamic strides operands, after being applied to the base memref
-    // strides in each dimension, are represented in the view memref layout
-    // map as symbols 's1', 's2' and 's3'.
-    %1 = memref.subview %0[%i, %j, %k][%size0, %size1, %size2][%x, %y, %z]
-      : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
-        memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+    %1 = memref.subview %0[%off0, %off1][%sz0, %sz1][%str0, %str1]
+        : memref<64x4xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
     ```
 
     Example 3:
 
     ```mlir
-    %0 = memref.alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>
-
-    // Subview with constant offsets, sizes and strides.
-    %1 = memref.subview %0[0, 2, 0][4, 4, 4][1, 1, 1]
-      : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
-        memref<4x4x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>>
+    // Subview of dynamic memref with strided layout at dynamic offsets and
+    // strides, but static sizes.
+    %1 = memref.subview %0[%off0, %off1][4, 4][%str0, %str1]
+        : memref<?x?xf32, strided<[?, ?], offset: ?>> to
+          memref<4x4xf32, strided<[?, ?], offset: ?>>
     ```
 
     Example 4:
 
     ```mlir
-    %0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
-
-    // Subview with constant size, but dynamic offsets and
-    // strides. The resulting memref has a static shape, but if the
-    // base memref has an affine map to describe the layout, the result
-    // memref also uses an affine map to describe the layout. The
-    // strides of the result memref is computed as follows:
-    //
-    // Let #map1 represents the layout of the base memref, and #map2
-    // represents the layout of the result memref. A #mapsubview can be
-    // constructed to map an index from the result memref to the base
-    // memref (note that the description below uses more convenient
-    // naming for symbols, while in affine maps, symbols are
-    // represented as unsigned numbers that identify that symbol in the
-    // given affine map.
-    //
-    // #mapsubview = (d0, d1)[o0, o1, t0, t1] -> (d0 * t0 + o0, d1 * t1 + o1)
-    //
-    // where, o0, o1, ... are offsets, and t0, t1, ... are strides. Then,
-    //
-    // #map2 = #map1.compose(#mapsubview)
-    //
-    // If the layout map is represented as
-    //
-    // #map1 = (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)
-    //
-    // then,
-    //
-    // #map2 = (d0, d1)[s0, s1, s2, o0, o1, t0, t1] ->
-    //              (d0 * s1 * t0 + d1 * s2 * t1 + o0 * s1 + o1 * s2 + s0)
-    //
-    // Representing this canonically
-    //
-    // #map2 = (d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0)
-    //
-    // where, r0 = o0 * s1 + o1 * s2 + s0, r1 = s1 * t0, r2 = s2 * t1.
-    %1 = memref.subview %0[%i, %j][4, 4][%x, %y] :
-      : memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>> to
-        memref<4x4xf32, affine_map<(d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0)>>
-
-    // Note that the subview op does not guarantee that the result
-    // memref is "inbounds" w.r.t to base memref. It is upto the client
-    // to ensure that the subview is accessed in a manner that is
-    // in-bounds.
+    // Rank-reducing subviews.
+    %1 = memref.subview %0[0, 0, 0][1, 16, 4][1, 1, 1]
+        : memref<8x16x4xf32> to memref<16x4xf32>
+    %3 = memref.subview %2[3, 4, 2][1, 6, 3][1, 1, 1]
+        : memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>>
     ```
 
     Example 5:
 
     ```mlir
-    // Rank-reducing subview.
-    %1 = memref.subview %0[0, 0, 0][1, 16, 4][1, 1, 1] :
-      memref<8x16x4xf32> to memref<16x4xf32>
-
-    // Original layout:
-    // (d0, d1, d2) -> (64 * d0 + 16 * d1 + d2)
-    // Subviewed layout:
-    // (d0, d1, d2) -> (64 * (d0 + 3) + 4 * (d1 + 4) + d2 + 2) = (64 * d0 + 4 * d1 + d2 + 210)
-    // After rank reducing:
-    // (d0, d1) -> (4 * d0 + d1 + 210)
-    %3 = memref.subview %2[3, 4, 2][1, 6, 3][1, 1, 1] :
-      memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>>
+    // Identity subview. The subview is the full source memref.
+    %1 = memref.subview %0[0, 0, 0] [8, 16, 4] [1, 1, 1]
+        : memref<8x16x4xf32> to memref<8x16x4xf32>
     ```
+
   }];
 
   let arguments = (ins AnyMemRef:$source,
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index e74326dba7c80..14427a97a5502 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -76,8 +76,7 @@ SliceBoundsVerificationResult verifyInBoundsSlice(
 /// returns the new result type of the op, based on the new offsets, sizes and
 /// strides. `CastOpFunc` is used to generate a cast op if the result type of
 /// the op has changed.
-template <typename OpType, typename ResultTypeFn, typename CastOpFunc,
-          bool CheckInBounds = false>
+template <typename OpType, typename ResultTypeFn, typename CastOpFunc>
 class OpWithOffsetSizesAndStridesConstantArgumentFolder final
     : public OpRewritePattern<OpType> {
 public:
@@ -95,14 +94,12 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
         failed(foldDynamicIndexList(mixedStrides)))
       return failure();
 
-    if (CheckInBounds) {
-      // Pattern does not apply if the produced op would not verify.
-      SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
-          cast<ShapedType>(op.getSource().getType()).getShape(), mixedOffsets,
-          mixedSizes, mixedStrides);
-      if (!sliceResult.isValid)
-        return failure();
-    }
+    // Pattern does not apply if the produced op would not verify.
+    SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
+        cast<ShapedType>(op.getSource().getType()).getShape(), mixedOffsets,
+        mixedSizes, mixedStrides);
+    if (!sliceResult.isValid)
+      return failure();
 
     // Compute the new result type.
     auto resultType =
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 59434dccc117b..123666848f83a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2977,6 +2977,9 @@ static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
 LogicalResult SubViewOp::verify() {
   MemRefType baseType = getSourceType();
   MemRefType subViewType = getType();
+  ArrayRef<int64_t> staticOffsets = getStaticOffsets();
+  ArrayRef<int64_t> staticSizes = getStaticSizes();
+  ArrayRef<int64_t> staticStrides = getStaticStrides();
 
   // The base memref and the view memref should be in the same memory space.
   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
@@ -2991,7 +2994,7 @@ LogicalResult SubViewOp::verify() {
   // Compute the expected result type, assuming that there are no rank
   // reductions.
   MemRefType expectedType = SubViewOp::inferResultType(
-      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
+      baseType, staticOffsets, staticSizes, staticStrides);
 
   // Verify all properties of a shaped type: rank, element type and dimension
   // sizes. This takes into account potential rank reductions.
@@ -3025,6 +3028,14 @@ LogicalResult SubViewOp::verify() {
     return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                   *this, expectedType);
 
+  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
+  // to the base memref.
+ SliceBoundsVerificationResult boundsResult = + verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes, + staticStrides, /*generateErrorMessage=*/true); + if (!boundsResult.isValid) + return getOperation()->emitError(boundsResult.errorMessage); + return success(); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 5f8493de991f3..d589f627d896e 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2617,10 +2617,10 @@ struct SliceCanonicalizer { void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add, - ExtractSliceOpCastFolder>(context); + results.add< + OpWithOffsetSizesAndStridesConstantArgumentFolder< + ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>, + ExtractSliceOpCastFolder>(context); } // diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir index 5517eafb588e8..fe91d26d5a251 100644 --- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir @@ -192,7 +192,7 @@ func.func @subview_const_stride(%0 : memref<64x4xf32, strided<[4, 1], offset: 0> // CHECK-LABEL: func @subview_const_stride_and_offset( // CHECK-SAME: %[[MEM:.*]]: memref<{{.*}}> -func.func @subview_const_stride_and_offset(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>) -> memref<62x3xf32, strided<[4, 1], offset: 8>> { +func.func @subview_const_stride_and_offset(%0 : memref<64x8xf32, strided<[8, 1], offset: 0>>) -> memref<62x3xf32, strided<[8, 1], offset: 2>> { // The last "insertvalue" that populates the memref descriptor from the function arguments. 
// CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] @@ -201,21 +201,21 @@ func.func @subview_const_stride_and_offset(%0 : memref<64x4xf32, strided<[4, 1], // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_ALIGNED]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[CST_OFF:.*]] = llvm.mlir.constant(8 : index) : i64 + // CHECK: %[[CST_OFF:.*]] = llvm.mlir.constant(2 : index) : i64 // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[CST_OFF]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[CST_SIZE0:.*]] = llvm.mlir.constant(62 : index) : i64 // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[CST_SIZE0]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[CST_STRIDE0:.*]] = llvm.mlir.constant(4 : index) : i64 + // CHECK: %[[CST_STRIDE0:.*]] = llvm.mlir.constant(8 : index) : i64 // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[CST_STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[CST_SIZE1:.*]] = llvm.mlir.constant(3 : index) : i64 // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[CST_SIZE1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[CST_STRIDE1:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[CST_STRIDE1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %1 = memref.subview %0[0, 8][62, 3][1, 1] : - memref<64x4xf32, strided<[4, 1], offset: 0>> - to memref<62x3xf32, strided<[4, 1], offset: 8>> - return %1 : memref<62x3xf32, strided<[4, 1], offset: 8>> + %1 = memref.subview %0[0, 2][62, 3][1, 1] : + memref<64x8xf32, strided<[8, 1], offset: 0>> + to memref<62x3xf32, strided<[8, 1], offset: 2>> + return %1 : memref<62x3xf32, strided<[8, 1], offset: 2>> } // ----- @@ -238,7 +238,7 @@ func.func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, strided<[4, 1], of // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[DESCSTRIDE0]] : i64 to index // CHECK: %[[DESCSTRIDE0_V2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64 // CHECK: %[[OFF0:.*]] = llvm.mul %[[ARG1]], %[[STRIDE0]] overflow : i64 - // CHECK: %[[BASE_OFF:.*]] = llvm.mlir.constant(8 : index) : i64 + // CHECK: %[[BASE_OFF:.*]] = llvm.mlir.constant(2 : index) : i64 // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF0]], %[[BASE_OFF]] : i64 // CHECK: %[[TMP:.*]] = builtin.unrealized_conversion_cast %[[OFF2]] : i64 to index // CHECK: %[[OFF2:.*]] = builtin.unrealized_conversion_cast %[[TMP]] : index to i64 @@ -253,7 +253,7 @@ func.func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, strided<[4, 1], of // CHECK: %[[CST_STRIDE1:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[CST_STRIDE1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %1 = memref.subview %0[%arg1, 8][62, %arg2][%arg0, 1] : + %1 = memref.subview %0[%arg1, 2][62, %arg2][%arg0, 1] : memref<64x4xf32, strided<[4, 1], offset: 0>> to memref<62x?xf32, strided<[?, 1], offset: ?>> return %1 : memref<62x?xf32, strided<[?, 1], offset: ?>> diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir index 
00b8c649b82c3..bab606c3a8169 100644 --- a/mlir/test/Dialect/Linalg/promote.mlir +++ b/mlir/test/Dialect/Linalg/promote.mlir @@ -287,18 +287,18 @@ module attributes {transform.with_named_sequence} { #map = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func.func @linalg_generic_update_all_function_inputs_outputs( - // CHECK-SAME: %[[VAL_0:.*]]: memref<3x4xf32, 1>, - // CHECK-SAME: %[[VAL_1:.*]]: memref<3x4xf32, 1>) -> memref<3x4xf32, 1> { -func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf32, 1>, %arg1: memref<3x4xf32, 1>) -> memref<3x4xf32, 1> { - // CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3x4xf32, 1> - // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]][0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> - // CHECK: %[[VAL_4:.*]] = memref.subview %[[VAL_1]][0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> - // CHECK: %[[VAL_5:.*]] = memref.subview %[[VAL_2]][0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> - - %alloc = memref.alloc() {alignment = 64 : i64} : memref<3x4xf32, 1> - %subview = memref.subview %arg0[0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> - %subview_0 = memref.subview %arg1[0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> - %subview_1 = memref.subview %alloc[0, 0] [4, 3] [1, 1] : memref<3x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> + // CHECK-SAME: %[[VAL_0:.*]]: memref<8x4xf32, 1>, + // CHECK-SAME: %[[VAL_1:.*]]: memref<8x4xf32, 1>) -> memref<8x4xf32, 1> { +func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<8x4xf32, 1>, %arg1: memref<8x4xf32, 1>) -> memref<8x4xf32, 1> { + // CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x4xf32, 1> + // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]][0, 0] [4, 3] [1, 1] : memref<8x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> + // CHECK: %[[VAL_4:.*]] = memref.subview %[[VAL_1]][0, 0] [4, 3] [1, 1] : memref<8x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> + // CHECK: %[[VAL_5:.*]] = memref.subview %[[VAL_2]][0, 0] [4, 3] [1, 1] : memref<8x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> + + %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4xf32, 1> + %subview = memref.subview %arg0[0, 0] [4, 3] [1, 1] : memref<8x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> + %subview_0 = memref.subview %arg1[0, 0] [4, 3] [1, 1] : memref<8x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> + %subview_1 = memref.subview %alloc[0, 0] [4, 3] [1, 1] : memref<8x4xf32, 1> to memref<4x3xf32, strided<[4, 1]>, 1> // CHECK: %[[VAL_6:.*]] = arith.constant 0 : index // CHECK: %[[VAL_7:.*]] = arith.constant 4 : index @@ -376,10 +376,10 @@ func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf // CHECK: memref.dealloc %[[VAL_22]] : memref<48xi8, #gpu.address_space> // CHECK: memref.dealloc %[[VAL_41]] : memref<48xi8, #gpu.address_space> // CHECK: memref.dealloc %[[VAL_60]] : memref<48xi8, #gpu.address_space> - // CHECK: return %[[VAL_2]] : memref<3x4xf32, 1> + // CHECK: return %[[VAL_2]] : memref<8x4xf32, 1> // CHECK: } - return %alloc : memref<3x4xf32, 1> + return %alloc : memref<8x4xf32, 1> } diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 02110bc2892d0..5d8a7d3f64e8f 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -635,9 +635,9 @@ func.func 
@fold_no_op_subview(%arg0 : memref<20x42xf32>) -> memref<20x42xf32, st
 
 // -----
 
-func.func @no_fold_subview_with_non_zero_offset(%arg0 : memref<20x42xf32>) -> memref<20x42xf32, strided<[42, 1], offset: 1>> {
-  %0 = memref.subview %arg0[0, 1] [20, 42] [1, 1] : memref<20x42xf32> to memref<20x42xf32, strided<[42, 1], offset: 1>>
-  return %0 : memref<20x42xf32, strided<[42, 1], offset: 1>>
+func.func @no_fold_subview_with_non_zero_offset(%arg0 : memref<20x42xf32>) -> memref<20x41xf32, strided<[42, 1], offset: 1>> {
+  %0 = memref.subview %arg0[0, 1] [20, 41] [1, 1] : memref<20x42xf32> to memref<20x41xf32, strided<[42, 1], offset: 1>>
+  return %0 : memref<20x41xf32, strided<[42, 1], offset: 1>>
 }
 // CHECK-LABEL: func @no_fold_subview_with_non_zero_offset(
 // CHECK: %[[SUBVIEW:.+]] = memref.subview
@@ -645,9 +645,9 @@ func.func @no_fold_subview_with_non_zero_offset(%arg0 : memref<20x42xf32>) -> me
 
 // -----
 
-func.func @no_fold_subview_with_non_unit_stride(%arg0 : memref<20x42xf32>) -> memref<20x42xf32, strided<[42, 2]>> {
-  %0 = memref.subview %arg0[0, 0] [20, 42] [1, 2] : memref<20x42xf32> to memref<20x42xf32, strided<[42, 2]>>
-  return %0 : memref<20x42xf32, strided<[42, 2]>>
+func.func @no_fold_subview_with_non_unit_stride(%arg0 : memref<20x42xf32>) -> memref<20x5xf32, strided<[42, 2]>> {
+  %0 = memref.subview %arg0[0, 0] [20, 5] [1, 2] : memref<20x42xf32> to memref<20x5xf32, strided<[42, 2]>>
+  return %0 : memref<20x5xf32, strided<[42, 2]>>
 }
 // CHECK-LABEL: func @no_fold_subview_with_non_unit_stride(
 // CHECK: %[[SUBVIEW:.+]] = memref.subview
@@ -655,6 +655,16 @@ func.func @no_fold_subview_with_non_unit_stride(%arg0 : memref<20x42xf32>) -> me
 
 // -----
 
+// CHECK-LABEL: func @no_fold_invalid_dynamic_slice
+// CHECK: memref.subview %arg0[2] [%{{.*}}] [1] : memref<10xf32> to memref<?xf32, strided<[1], offset: 2>>
+func.func @no_fold_invalid_dynamic_slice(%arg0: memref<10xf32>) -> memref<?xf32, strided<[1], offset: 2>> {
+  %c11 = arith.constant 11 : index
+  %0 = memref.subview %arg0 [2][%c11][1] : memref<10xf32> to memref<?xf32, strided<[1], offset: 2>>
+  func.return %0 : memref<?xf32, strided<[1], offset: 2>>
+}
+
+// -----
+
 func.func @no_fold_dynamic_no_op_subview(%arg0 : memref<?x?xf32>) -> memref<?x?xf32, strided<[?, 1]>> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 647731db439c0..1e6b0111fa4c7 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -119,39 +119,39 @@ func.func @extract_strided_metadata_of_subview(%base: memref<5x4xf32>)
 // when dynamic sizes are involved.
 // See extract_strided_metadata_of_subview for an explanation of the actual
 // expansion.
-// Orig strides: [64, 4, 1] +// Orig strides: [384, 24, 1] // Sub strides: [1, 1, 1] -// => New strides: [64, 4, 1] +// => New strides: [384, 24, 1] // // Orig offset: 0 // Sub offsets: [3, 4, 2] -// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210 +// => Final offset: 3 * 384 + 4 * 24 + 2 * 1 + 0 == 1250 // // Final sizes == subview sizes == [%size, 6, 3] // // CHECK-LABEL: func @extract_strided_metadata_of_subview_with_dynamic_size -// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>, +// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x24xf32>, // CHECK-SAME: %[[DYN_SIZE:.*]]: index) // -// CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index -// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK-DAG: %[[C1250:.*]] = arith.constant 1250 : index +// CHECK-DAG: %[[C384:.*]] = arith.constant 384 : index // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C24:.*]] = arith.constant 24 : index // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // // CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]] // -// CHECK: return %[[BASE]], %[[C210]], %[[DYN_SIZE]], %[[C6]], %[[C3]], %[[C64]], %[[C4]], %[[C1]] +// CHECK: return %[[BASE]], %[[C1250]], %[[DYN_SIZE]], %[[C6]], %[[C3]], %[[C384]], %[[C24]], %[[C1]] func.func @extract_strided_metadata_of_subview_with_dynamic_size( - %base: memref<8x16x4xf32>, %size: index) + %base: memref<8x16x24xf32>, %size: index) -> (memref, index, index, index, index, index, index, index) { %subview = memref.subview %base[3, 4, 2][%size, 6, 3][1, 1, 1] : - memref<8x16x4xf32> to memref> + memref<8x16x24xf32> to memref> %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : - memref> + memref> -> memref, index, index, index, index, index, index, index return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 : @@ -167,37 +167,37 @@ func.func @extract_strided_metadata_of_subview_with_dynamic_size( // See extract_strided_metadata_of_subview for an explanation of the actual // expansion. 
// -// Orig strides: [64, 4, 1] +// Orig strides: [384, 24, 1] // Sub strides: [1, 1, 1] -// => New strides: [64, 4, 1] -// Final strides == filterOutReducedDim(new strides, 0) == [4 , 1] +// => New strides: [384, 24, 1] +// Final strides == filterOutReducedDim(new strides, 0) == [24 , 1] // // Orig offset: 0 // Sub offsets: [3, 4, 2] -// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210 +// => Final offset: 3 * 384 + 4 * 24 + 2 * 1 + 0 == 1250 // // Final sizes == filterOutReducedDim(subview sizes, 0) == [6, 3] // // CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview -// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>) +// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x24xf32>) // -// CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index +// CHECK-DAG: %[[C1250:.*]] = arith.constant 1250 : index // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C24:.*]] = arith.constant 24 : index // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // // CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]] // -// CHECK: return %[[BASE]], %[[C210]], %[[C6]], %[[C3]], %[[C4]], %[[C1]] -func.func @extract_strided_metadata_of_rank_reduced_subview(%base: memref<8x16x4xf32>) +// CHECK: return %[[BASE]], %[[C1250]], %[[C6]], %[[C3]], %[[C24]], %[[C1]] +func.func @extract_strided_metadata_of_rank_reduced_subview(%base: memref<8x16x24xf32>) -> (memref, index, index, index, index, index) { %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, 1, 1] : - memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>> + memref<8x16x24xf32> to memref<6x3xf32, strided<[24, 1], offset: 1250>> %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview : - memref<6x3xf32, strided<[4,1], offset: 210>> + memref<6x3xf32, strided<[24, 1], offset: 1250>> -> memref, index, index, index, index, index return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : @@ -215,21 +215,21 @@ func.func @extract_strided_metadata_of_rank_reduced_subview(%base: memref<8x16x4 // See extract_strided_metadata_of_subview for an explanation of the actual // expansion. 
// -// Orig strides: [64, 4, 1] +// Orig strides: [384, 24, 1] // Sub strides: [1, %stride, 1] -// => New strides: [64, 4 * %stride, 1] -// Final strides == filterOutReducedDim(new strides, 0) == [4 * %stride , 1] +// => New strides: [384, 24 * %stride, 1] +// Final strides == filterOutReducedDim(new strides, 0) == [24 * %stride , 1] // // Orig offset: 0 // Sub offsets: [3, 4, 2] -// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210 +// => Final offset: 3 * 384 + 4 * 24 + 2 * 1 + 0 == 1250 // -// CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0] -> (s0 * 4)> +// CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0] -> (s0 * 24)> // CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides -// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>, +// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x24xf32>, // CHECK-SAME: %[[DYN_STRIDE:.*]]: index) // -// CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index +// CHECK-DAG: %[[C1250:.*]] = arith.constant 1250 : index // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index @@ -238,16 +238,16 @@ func.func @extract_strided_metadata_of_rank_reduced_subview(%base: memref<8x16x4 // // CHECK-DAG: %[[DIM1_STRIDE:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_STRIDE]]] // -// CHECK: return %[[BASE]], %[[C210]], %[[C6]], %[[C3]], %[[DIM1_STRIDE]], %[[C1]] +// CHECK: return %[[BASE]], %[[C1250]], %[[C6]], %[[C3]], %[[DIM1_STRIDE]], %[[C1]] func.func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides( - %base: memref<8x16x4xf32>, %stride: index) + %base: memref<8x16x24xf32>, %stride: index) -> (memref, index, index, index, index, index) { %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, %stride, 1] : - memref<8x16x4xf32> to memref<6x3xf32, strided<[?, 1], offset: 210>> + memref<8x16x24xf32> to memref<6x3xf32, strided<[?, 1], offset: 1250>> %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview : - memref<6x3xf32, strided<[?, 1], offset: 210>> + memref<6x3xf32, strided<[?, 1], offset: 1250>> -> memref, index, index, index, index, index return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir index 327cacf7d9a20..067cdb5c5fd20 100644 --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -632,17 +632,17 @@ func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape_with // CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 2)> // CHECK-LABEL: func @subview_of_subview( -// CHECK-SAME: %[[m:.*]]: memref<1x1024xf32, 3>, %[[pos:.*]]: index +// CHECK-SAME: %[[m:.*]]: memref<8x1024xf32, 3>, %[[pos:.*]]: index // CHECK: %[[add:.*]] = affine.apply #[[$map]]()[%arg1] -// CHECK: memref.subview %arg0[4, %[[add]]] [1, 1] [1, 1] : memref<1x1024xf32, 3> to memref, 3> -func.func @subview_of_subview(%m: memref<1x1024xf32, 3>, %pos: index) +// CHECK: memref.subview %arg0[4, %[[add]]] [1, 1] [1, 1] : memref<8x1024xf32, 3> to memref, 3> +func.func @subview_of_subview(%m: memref<8x1024xf32, 3>, %pos: index) -> memref, 3> { - %0 = memref.subview %m[3, %pos] [1, 2] [1, 1] - : memref<1x1024xf32, 3> - to memref<1x2xf32, strided<[1024, 1], offset: ?>, 3> + %0 = memref.subview %m[3, %pos] [5, 7] [1, 1] + : memref<8x1024xf32, 3> + to memref<5x7xf32, strided<[1024, 1], offset: ?>, 3> %1 = memref.subview %0[1, 
2] [1, 1] [1, 1] - : memref<1x2xf32, strided<[1024, 1], offset: ?>, 3> + : memref<5x7xf32, strided<[1024, 1], offset: ?>, 3> to memref, 3> return %1 : memref, 3> } diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir index f72ad48245f81..34fc4775924e7 100644 --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -723,6 +723,22 @@ func.func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { // ----- +func.func @invalid_subview(%arg0: memref<10xf32>) { + // expected-error@+1 {{offset 0 is out-of-bounds: 10 >= 10}} + %0 = memref.subview %arg0 [10][1][1] : memref<10xf32> to memref<1xf32, strided<[1], offset: 10>> + return +} + +// ----- + +func.func @invalid_subview(%arg0: memref<9xf32>) { + // expected-error@+1 {{slice along dimension 0 runs out-of-bounds: 9 >= 9}} + %0 = memref.subview %arg0 [3][4][2] : memref<9xf32> to memref<4xf32, strided<[2], offset: 3>> + return +} + +// ----- + func.func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = memref.alloc() : memref<8x16x4xf32> // expected-error@+1 {{expected result type to be 'memref<8x16x4xf32, strided<[64, 4, 1]>>' or a rank-reduced version. (mismatch of result sizes)}} diff --git a/mlir/test/Dialect/MemRef/subview.mlir b/mlir/test/Dialect/MemRef/subview.mlir index 135a1124066e4..fd8aaaf86b2d8 100644 --- a/mlir/test/Dialect/MemRef/subview.mlir +++ b/mlir/test/Dialect/MemRef/subview.mlir @@ -90,7 +90,7 @@ func.func @memref_subview(%arg0 : index, %arg1 : index, %arg2 : index) { // CHECK: memref.subview %{{.*}}[0, 0, 0] [1, 16, 4] [1, 1, 1] : memref<8x16x4xf32> to memref<16x4xf32> %21 = memref.subview %20[0, 0, 0][1, 16, 4][1, 1, 1] : memref<8x16x4xf32> to memref<16x4xf32> - %22 = memref.subview %20[3, 4, 2][1, 6, 3][1, 1, 1] : memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>> + %22 = memref.subview %20[3, 4, 1][1, 6, 3][1, 1, 1] : memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 209>> %23 = memref.alloc() : memref %78 = memref.subview %23[] [] [] : memref to memref diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 9b74362b6ee75..8e02c06a0a293 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -752,83 +752,83 @@ func.func @subview(%arg0 : index, %arg1 : index) -> (index, index) { %c15 = arith.constant 15 : index // CHECK: %[[ALLOC0:.*]] = memref.alloc() - %0 = memref.alloc() : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> + %0 = memref.alloc() : memref<128x96x64xf32, strided<[6144, 64, 1], offset: 0>> // Test: subview with constant base memref and constant operands is folded. // Note that the subview uses the base memrefs layout map because it used // zero offset and unit stride arguments. // CHECK: memref.subview %[[ALLOC0]][0, 0, 0] [7, 11, 2] [1, 1, 1] : - // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> - // CHECK-SAME: to memref<7x11x2xf32, strided<[64, 4, 1]>> + // CHECK-SAME: memref<128x96x64xf32, strided<[6144, 64, 1]>> + // CHECK-SAME: to memref<7x11x2xf32, strided<[6144, 64, 1]>> %1 = memref.subview %0[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1] - : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to + : memref<128x96x64xf32, strided<[6144, 64, 1], offset: 0>> to memref> %v0 = memref.load %1[%c0, %c0, %c0] : memref> // Test: subview with one dynamic operand can also be folded. 
// CHECK: memref.subview %[[ALLOC0]][0, %[[ARG0]], 0] [7, 11, 15] [1, 1, 1] : - // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> - // CHECK-SAME: to memref<7x11x15xf32, strided<[64, 4, 1], offset: ?>> + // CHECK-SAME: memref<128x96x64xf32, strided<[6144, 64, 1]>> + // CHECK-SAME: to memref<7x11x15xf32, strided<[6144, 64, 1], offset: ?>> %2 = memref.subview %0[%c0, %arg0, %c0] [%c7, %c11, %c15] [%c1, %c1, %c1] - : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to + : memref<128x96x64xf32, strided<[6144, 64, 1], offset: 0>> to memref> memref.store %v0, %2[%c0, %c0, %c0] : memref> // CHECK: %[[ALLOC1:.*]] = memref.alloc(%[[ARG0]]) %3 = memref.alloc(%arg0) : memref> // Test: subview with constant operands but dynamic base memref is folded as long as the strides and offset of the base memref are static. - // CHECK: memref.subview %[[ALLOC1]][0, 0, 0] [7, 11, 15] [1, 1, 1] : + // CHECK: memref.subview %[[ALLOC1]][0, 0, 0] [7, 11, 2] [1, 1, 1] : // CHECK-SAME: memref> - // CHECK-SAME: to memref<7x11x15xf32, strided<[64, 4, 1]>> - %4 = memref.subview %3[%c0, %c0, %c0] [%c7, %c11, %c15] [%c1, %c1, %c1] + // CHECK-SAME: to memref<7x11x2xf32, strided<[64, 4, 1]>> + %4 = memref.subview %3[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1] : memref> to memref> memref.store %v0, %4[%c0, %c0, %c0] : memref> // Test: subview offset operands are folded correctly w.r.t. base strides. // CHECK: memref.subview %[[ALLOC0]][1, 2, 7] [7, 11, 2] [1, 1, 1] : - // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> to - // CHECK-SAME: memref<7x11x2xf32, strided<[64, 4, 1], offset: 79>> + // CHECK-SAME: memref<128x96x64xf32, strided<[6144, 64, 1]>> to + // CHECK-SAME: memref<7x11x2xf32, strided<[6144, 64, 1], offset: 6279>> %5 = memref.subview %0[%c1, %c2, %c7] [%c7, %c11, %c2] [%c1, %c1, %c1] - : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to + : memref<128x96x64xf32, strided<[6144, 64, 1], offset: 0>> to memref> memref.store %v0, %5[%c0, %c0, %c0] : memref> // Test: subview stride operands are folded correctly w.r.t. base strides. 
// CHECK: memref.subview %[[ALLOC0]][0, 0, 0] [7, 11, 2] [2, 7, 11] : - // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> - // CHECK-SAME: to memref<7x11x2xf32, strided<[128, 28, 11]>> + // CHECK-SAME: memref<128x96x64xf32, strided<[6144, 64, 1]>> + // CHECK-SAME: to memref<7x11x2xf32, strided<[12288, 448, 11]>> %6 = memref.subview %0[%c0, %c0, %c0] [%c7, %c11, %c2] [%c2, %c7, %c11] - : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to + : memref<128x96x64xf32, strided<[6144, 64, 1], offset: 0>> to memref> memref.store %v0, %6[%c0, %c0, %c0] : memref> // Test: subview shape are folded, but offsets and strides are not even if base memref is static // CHECK: memref.subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [7, 11, 2] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : - // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> to + // CHECK-SAME: memref<128x96x64xf32, strided<[6144, 64, 1]>> to // CHECK-SAME: memref<7x11x2xf32, strided<[?, ?, ?], offset: ?>> %10 = memref.subview %0[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] : - memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to + memref<128x96x64xf32, strided<[6144, 64, 1], offset: 0>> to memref> memref.store %v0, %10[%arg1, %arg1, %arg1] : memref> // Test: subview strides are folded, but offsets and shape are not even if base memref is static // CHECK: memref.subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [2, 7, 11] : - // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> to - // CHECK-SAME: memref> + // CHECK-SAME: memref<128x96x64xf32, strided<[6144, 64, 1]>> to + // CHECK-SAME: memref> %11 = memref.subview %0[%arg0, %arg0, %arg0] [%arg1, %arg1, %arg1] [%c2, %c7, %c11] : - memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to + memref<128x96x64xf32, strided<[6144, 64, 1], offset: 0>> to memref> memref.store %v0, %11[%arg0, %arg0, %arg0] : memref> // Test: subview offsets are folded, but strides and shape are not even if base memref is static // CHECK: memref.subview %[[ALLOC0]][1, 2, 7] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] : - // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> to - // CHECK-SAME: memref> + // CHECK-SAME: memref<128x96x64xf32, strided<[6144, 64, 1]>> to + // CHECK-SAME: memref> %13 = memref.subview %0[%c1, %c2, %c7] [%arg1, %arg1, %arg1] [%arg0, %arg0, %arg0] : - memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to + memref<128x96x64xf32, strided<[6144, 64, 1], offset: 0>> to memref> memref.store %v0, %13[%arg1, %arg1, %arg1] : memref> @@ -862,27 +862,27 @@ func.func @subview(%arg0 : index, %arg1 : index) -> (index, index) { memref> memref.store %v0, %17[%arg0, %arg0, %arg0] : memref> - // CHECK: %[[ALLOC3:.*]] = memref.alloc() : memref<12x4xf32> - %18 = memref.alloc() : memref<12x4xf32> + // CHECK: %[[ALLOC3:.*]] = memref.alloc() : memref<128x64xf32> + %18 = memref.alloc() : memref<128x64xf32> %c4 = arith.constant 4 : index // TEST: subview strides are maintained when sizes are folded // CHECK: memref.subview %[[ALLOC3]][%arg1, %arg1] [2, 4] [1, 1] : - // CHECK-SAME: memref<12x4xf32> to - // CHECK-SAME: memref<2x4xf32, strided<[4, 1], offset: ?> + // CHECK-SAME: memref<128x64xf32> to + // CHECK-SAME: memref<2x4xf32, strided<[64, 1], offset: ?> %19 = memref.subview %18[%arg1, %arg1] [%c2, %c4] [1, 1] : - memref<12x4xf32> to - memref> - memref.store %v0, %19[%arg1, %arg1] : memref> + memref<128x64xf32> to + memref> + memref.store %v0, %19[%arg1, %arg1] : memref> // TEST: subview strides and sizes are maintained when offsets 
are folded // CHECK: memref.subview %[[ALLOC3]][2, 4] [12, 4] [1, 1] : - // CHECK-SAME: memref<12x4xf32> to - // CHECK-SAME: memref<12x4xf32, strided<[4, 1], offset: 12>> + // CHECK-SAME: memref<128x64xf32> to + // CHECK-SAME: memref<12x4xf32, strided<[64, 1], offset: 132>> %20 = memref.subview %18[%c2, %c4] [12, 4] [1, 1] : - memref<12x4xf32> to - memref<12x4xf32, strided<[4, 1], offset: ?>> - memref.store %v0, %20[%arg1, %arg1] : memref<12x4xf32, strided<[4, 1], offset: ?>> + memref<128x64xf32> to + memref<12x4xf32, strided<[64, 1], offset: ?>> + memref.store %v0, %20[%arg1, %arg1] : memref<12x4xf32, strided<[64, 1], offset: ?>> // Test: dim on subview is rewritten to size operand. %7 = memref.dim %4, %c0 : memref> diff --git a/mlir/test/Transforms/compose-subview.mlir b/mlir/test/Transforms/compose-subview.mlir index 22ffd836c68ed..53fbb8a356def 100644 --- a/mlir/test/Transforms/compose-subview.mlir +++ b/mlir/test/Transforms/compose-subview.mlir @@ -53,10 +53,10 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, stri // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> { -func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> { - // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][4, 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>> - %0 = memref.subview %input[2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: 2304>> +// CHECK-SAME: %[[VAL_0:.*]]: memref<8x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> { +func.func @subview_strided(%input: memref<8x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> { + // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][4, 384] [1, 64] [4, 4] : memref<8x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>> + %0 = memref.subview %input[2, 256] [2, 256] [2, 2] : memref<8x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: 2304>> %1 = memref.subview %0[1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: 2304>> to memref<1x64xf32, strided<[4096, 4], offset: 4480>> return %1 : memref<1x64xf32, strided<[4096, 4], offset: 4480>> }
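As a worked illustration of the result-type arithmetic and in-bounds rule described in the updated `subview` documentation above (this snippet is not part of the patch; the function and value names are made up): for a `memref<16x8xf32>` source with strides `[8, 1]`, a subview at offsets `[4, 1]`, sizes `[6, 3]` and strides `[2, 2]` gets strides `[8*2, 1*2] = [16, 2]` and offset `4*8 + 1*1 = 33`, and stays in-bounds because `4 + (6-1)*2 = 14 < 16` and `1 + (3-1)*2 = 5 < 8`.

```mlir
func.func @subview_bounds_sketch(%src: memref<16x8xf32>) {
  // In-bounds: result strides are the source strides times the subview
  // strides, and the result offset is dot(offsets, source strides) = 33.
  %ok = memref.subview %src[4, 1][6, 3][2, 2]
      : memref<16x8xf32> to memref<6x3xf32, strided<[16, 2], offset: 33>>

  // With size 7 along dimension 0 the slice would reach index
  // 4 + (7 - 1) * 2 = 16 >= 16, so the new static verifier rejects it with an
  // error in the style of "slice along dimension 0 runs out-of-bounds: 16 >= 16".
  // %oob = memref.subview %src[4, 1][7, 3][2, 2]
  //     : memref<16x8xf32> to memref<7x3xf32, strided<[16, 2], offset: 33>>
  return
}
```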