Skip to content

Commit 7fd38bf

Browse files
lhutton1Tai78641
authored andcommitted
[mlir][tosa] Fold 'small' constant 1D slice operations
This commit extends the slice folder to fold constant slice operations consisting of all constant inputs where the number of output values does not exceed 6. Keeping the folder restricted to small inputs avoids a large folder runtime or increased memory requirements. This folder is useful in the context of legalizing dynamic models where the input shapes are resolved to static directly before legalization. In this context, constant shape operations are used over tensors of num elements <= 6 (tosa_level_8k MAX_RANK). Change-Id: I1e59e5919f8c2936e98788c5a9b44a691940b28a Signed-off-by: Luke Hutton <[email protected]>
1 parent d5cef39 commit 7fd38bf

File tree

2 files changed

+80
-7
lines changed

2 files changed

+80
-7
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,18 +1054,40 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
10541054
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
10551055
}
10561056

1057-
if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
1058-
outputTy.getNumElements() == 1) {
1059-
DenseElementsAttr startElems;
1060-
if (!matchPattern(getStart(), m_Constant(&startElems)))
1061-
return {};
1057+
if (!inputTy.hasStaticShape() || !outputTy.hasStaticShape())
1058+
return {};
1059+
1060+
DenseElementsAttr startElems;
1061+
if (!matchPattern(getStart(), m_Constant(&startElems)))
1062+
return {};
10621063

1063-
llvm::SmallVector<uint64_t> indices =
1064-
llvm::to_vector(startElems.getValues<uint64_t>());
1064+
auto indices = llvm::to_vector(startElems.getValues<uint64_t>());
1065+
1066+
if (outputTy.getNumElements() == 1) {
10651067
auto value = operand.getValues<Attribute>()[indices];
10661068
return SplatElementsAttr::get(outputTy, value);
10671069
}
10681070

1071+
DenseElementsAttr size_elems;
1072+
if (!matchPattern(getSize(), m_Constant(&size_elems)))
1073+
return {};
1074+
1075+
const auto sizes = llvm::to_vector(size_elems.getValues<uint64_t>());
1076+
1077+
// Fold slice when all operands are constant and the output is 'small'
1078+
// A 'small' output is currently defined as 1D and <= 6 elements
1079+
// (tosa_level_8k MAX_RANK)
1080+
if (inputTy.getRank() == 1 && outputTy.getRank() == 1 &&
1081+
outputTy.getNumElements() <= 6 && indices.size() == 1 &&
1082+
sizes.size() == 1) {
1083+
const auto begin = operand.value_begin<Attribute>();
1084+
const uint64_t offset = indices[0];
1085+
const uint64_t size = sizes[0];
1086+
const SmallVector<Attribute> slicedValues(begin + offset,
1087+
begin + offset + size);
1088+
return DenseElementsAttr::get(outputTy, slicedValues);
1089+
}
1090+
10691091
return {};
10701092
}
10711093

mlir/test/Dialect/Tosa/constant_folding.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,54 @@ func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tens
2121
%0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
2222
return
2323
}
24+
25+
// -----
26+
27+
// CHECK-LABEL: test_1d_slice
28+
func.func @test_1d_slice() -> tensor<6xi32> {
29+
// CHECK: %[[VAL_0:.+]] = "tosa.const"() <{values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>}> : () -> tensor<6xi32>
30+
// CHECK: return %[[VAL_0]] : tensor<6xi32>
31+
%0 = "tosa.const"() <{values = dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32>}> : () -> tensor<10xi32>
32+
%1 = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
33+
%2 = tosa.const_shape {values = dense<6> : tensor<1xindex>} : () -> !tosa.shape<1>
34+
%3 = tosa.slice %0, %1, %2 : (tensor<10xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<6xi32>
35+
return %3 : tensor<6xi32>
36+
}
37+
38+
// -----
39+
40+
// CHECK-LABEL: test_1d_slice_non_const_input
41+
func.func @test_1d_slice_non_const_input(%arg0 : tensor<10xi32>) -> tensor<6xi32> {
42+
// check that slice is not folded for non-constant input1
43+
// CHECK: tosa.slice
44+
%1 = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
45+
%2 = tosa.const_shape {values = dense<6> : tensor<1xindex>} : () -> !tosa.shape<1>
46+
%3 = tosa.slice %arg0, %1, %2 : (tensor<10xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<6xi32>
47+
return %3 : tensor<6xi32>
48+
}
49+
50+
// -----
51+
52+
// CHECK-LABEL: test_1d_slice_rank_2_input
53+
func.func @test_1d_slice_rank_2_input(%arg0 : tensor<1x10xi32>) -> tensor<1x6xi32> {
54+
// check that slice is not folded for input1 rank > 1
55+
// CHECK: tosa.slice
56+
%0 = "tosa.const"() <{values = dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]> : tensor<1x10xi32>}> : () -> tensor<1x10xi32>
57+
%1 = tosa.const_shape {values = dense<[0, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
58+
%2 = tosa.const_shape {values = dense<[1, 6]> : tensor<2xindex>} : () -> !tosa.shape<2>
59+
%3 = tosa.slice %arg0, %1, %2 : (tensor<1x10xi32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x6xi32>
60+
return %3 : tensor<1x6xi32>
61+
}
62+
63+
// -----
64+
65+
// CHECK-LABEL: test_1d_slice_more_than_6
66+
func.func @test_1d_slice_more_than_6() -> tensor<7xi32> {
67+
// check that slice is not folded because output has more than 6 elements
68+
// CHECK: tosa.slice
69+
%0 = "tosa.const"() <{values = dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32>}> : () -> tensor<10xi32>
70+
%1 = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
71+
%2 = tosa.const_shape {values = dense<7> : tensor<1xindex>} : () -> !tosa.shape<1>
72+
%3 = tosa.slice %0, %1, %2 : (tensor<10xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<7xi32>
73+
return %3 : tensor<7xi32>
74+
}

0 commit comments

Comments
 (0)