diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 3e99c1f717d09..ea37a76360ed2 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1054,18 +1054,40 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) { return SplatElementsAttr::get(outputTy, operand.getSplatValue()); } - if (inputTy.hasStaticShape() && outputTy.hasStaticShape() && - outputTy.getNumElements() == 1) { - DenseElementsAttr startElems; - if (!matchPattern(getStart(), m_Constant(&startElems))) - return {}; + if (!inputTy.hasStaticShape() || !outputTy.hasStaticShape()) + return {}; + + DenseElementsAttr startElems; + if (!matchPattern(getStart(), m_Constant(&startElems))) + return {}; - llvm::SmallVector indices = - llvm::to_vector(startElems.getValues()); + auto indices = llvm::to_vector(startElems.getValues()); + + if (outputTy.getNumElements() == 1) { auto value = operand.getValues()[indices]; return SplatElementsAttr::get(outputTy, value); } + DenseElementsAttr size_elems; + if (!matchPattern(getSize(), m_Constant(&size_elems))) + return {}; + + const auto sizes = llvm::to_vector(size_elems.getValues()); + + // Fold slice when all operands are constant and the output is 'small' + // A 'small' output is currently defined as 1D and <= 6 elements + // (tosa_level_8k MAX_RANK) + if (inputTy.getRank() == 1 && outputTy.getRank() == 1 && + outputTy.getNumElements() <= 6 && indices.size() == 1 && + sizes.size() == 1) { + const auto begin = operand.value_begin(); + const uint64_t offset = indices[0]; + const uint64_t size = sizes[0]; + const SmallVector slicedValues(begin + offset, + begin + offset + size); + return DenseElementsAttr::get(outputTy, slicedValues); + } + return {}; } diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir index 9b6ccdb54c107..3c1f2c5058b95 100644 --- 
a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -21,3 +21,55 @@ func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tens
   %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
   return
 }
+
+// -----
+
+// CHECK-LABEL: test_1d_slice
+func.func @test_1d_slice() -> tensor<6xi32> {
+  // check that a 6-element 1D slice of a constant input is folded into a constant
+  // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>}> : () -> tensor<6xi32>
+  // CHECK: return %[[VAL_0]] : tensor<6xi32>
+  %0 = "tosa.const"() <{values = dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32>}> : () -> tensor<10xi32>
+  %1 = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %2 = tosa.const_shape {values = dense<6> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %3 = tosa.slice %0, %1, %2 : (tensor<10xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<6xi32>
+  return %3 : tensor<6xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_1d_slice_non_const_input
+func.func @test_1d_slice_non_const_input(%arg0 : tensor<10xi32>) -> tensor<6xi32> {
+  // check that slice is not folded for non-constant input1
+  // CHECK: tosa.slice
+  %1 = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %2 = tosa.const_shape {values = dense<6> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %3 = tosa.slice %arg0, %1, %2 : (tensor<10xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<6xi32>
+  return %3 : tensor<6xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_1d_slice_rank_2_input
+func.func @test_1d_slice_rank_2_input(%arg0 : tensor<1x10xi32>) -> tensor<1x6xi32> {
+  // check that slice is not folded for a constant input1 with rank > 1 (the slice must consume the constant %0, not %arg0)
+  // CHECK: tosa.slice
+  %0 = "tosa.const"() <{values = dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]> : tensor<1x10xi32>}> : () -> tensor<1x10xi32>
+  %1 = tosa.const_shape {values = dense<[0, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %2 = tosa.const_shape {values = dense<[1, 6]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %3 = tosa.slice %0, %1, %2 : (tensor<1x10xi32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x6xi32>
+  return %3 : tensor<1x6xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_1d_slice_more_than_6
+func.func @test_1d_slice_more_than_6() -> tensor<7xi32> {
+  // check that slice is not folded because output has more than 6 elements
+  // CHECK: tosa.slice
+  %0 = "tosa.const"() <{values = dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32>}> : () -> tensor<10xi32>
+  %1 = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %2 = tosa.const_shape {values = dense<7> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %3 = tosa.slice %0, %1, %2 : (tensor<10xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<7xi32>
+  return %3 : tensor<7xi32>
+}