Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1054,18 +1054,40 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
}

if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
outputTy.getNumElements() == 1) {
DenseElementsAttr startElems;
if (!matchPattern(getStart(), m_Constant(&startElems)))
return {};
if (!inputTy.hasStaticShape() || !outputTy.hasStaticShape())
return {};

DenseElementsAttr startElems;
if (!matchPattern(getStart(), m_Constant(&startElems)))
return {};

llvm::SmallVector<uint64_t> indices =
llvm::to_vector(startElems.getValues<uint64_t>());
auto indices = llvm::to_vector(startElems.getValues<uint64_t>());

if (outputTy.getNumElements() == 1) {
auto value = operand.getValues<Attribute>()[indices];
return SplatElementsAttr::get(outputTy, value);
}

DenseElementsAttr size_elems;
if (!matchPattern(getSize(), m_Constant(&size_elems)))
return {};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we put a newline after this return, it makes it a wee bit easier to read.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


const auto sizes = llvm::to_vector(size_elems.getValues<uint64_t>());

// Fold slice when all operands are constant and the output is 'small'
// A 'small' output is currently defined as 1D and <= 6 elements
// (tosa_level_8k MAX_RANK)
if (inputTy.getRank() == 1 && outputTy.getRank() == 1 &&
outputTy.getNumElements() <= 6 && indices.size() == 1 &&
sizes.size() == 1) {
const auto begin = operand.value_begin<Attribute>();
const uint64_t offset = indices[0];
const uint64_t size = sizes[0];
const SmallVector<Attribute> slicedValues(begin + offset,
begin + offset + size);
return DenseElementsAttr::get(outputTy, slicedValues);
}

return {};
}

Expand Down
51 changes: 51 additions & 0 deletions mlir/test/Dialect/Tosa/constant_folding.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,54 @@ func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tens
%0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
return
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add some negative tests here as well?

The condition for this transformation is:


  // Fold slice when all operands are constant and the output is 'small'
  // A 'small' output is currently defined as 1D and <= 6 elements
  // (tosa_level_8k MAX_RANK)

So this suggests the following negative cases:

  1. 1 or more non-const operands
  2. Tensor of rank > 1
  3. > 6 element in extracted tensor.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

// -----

// CHECK-LABEL: test_1d_slice
func.func @test_1d_slice() -> tensor<6xi32> {
// CHECK: %[[VAL_0:.+]] = "tosa.const"() <{values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>}> : () -> tensor<6xi32>
// CHECK: return %[[VAL_0]] : tensor<6xi32>
%0 = "tosa.const"() <{values = dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32>}> : () -> tensor<10xi32>
%1 = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
%2 = tosa.const_shape {values = dense<6> : tensor<1xindex>} : () -> !tosa.shape<1>
%3 = tosa.slice %0, %1, %2 : (tensor<10xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<6xi32>
return %3 : tensor<6xi32>
}

// -----

// CHECK-LABEL: test_1d_slice_non_const_input
func.func @test_1d_slice_non_const_input(%arg0 : tensor<10xi32>) -> tensor<6xi32> {
// check that slice is not folded for non-constant input1
// CHECK: tosa.slice
%1 = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
%2 = tosa.const_shape {values = dense<6> : tensor<1xindex>} : () -> !tosa.shape<1>
%3 = tosa.slice %arg0, %1, %2 : (tensor<10xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<6xi32>
return %3 : tensor<6xi32>
}

// -----

// CHECK-LABEL: test_1d_slice_rank_2_input
func.func @test_1d_slice_rank_2_input(%arg0 : tensor<1x10xi32>) -> tensor<1x6xi32> {
// check that slice is not folded for input1 rank > 1
// CHECK: tosa.slice
%0 = "tosa.const"() <{values = dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]> : tensor<1x10xi32>}> : () -> tensor<1x10xi32>
%1 = tosa.const_shape {values = dense<[0, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%2 = tosa.const_shape {values = dense<[1, 6]> : tensor<2xindex>} : () -> !tosa.shape<2>
%3 = tosa.slice %arg0, %1, %2 : (tensor<1x10xi32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x6xi32>
return %3 : tensor<1x6xi32>
}

// -----

// CHECK-LABEL: test_1d_slice_more_than_6
func.func @test_1d_slice_more_than_6() -> tensor<7xi32> {
// check that slice is not folded because output has more than 6 elements
// CHECK: tosa.slice
%0 = "tosa.const"() <{values = dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32>}> : () -> tensor<10xi32>
%1 = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
%2 = tosa.const_shape {values = dense<7> : tensor<1xindex>} : () -> !tosa.shape<1>
%3 = tosa.slice %0, %1, %2 : (tensor<10xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<7xi32>
return %3 : tensor<7xi32>
}