Skip to content

Commit 1f2e303

Browse files
lhutton1Jerry-Ge
authored andcommitted
[mlir][tosa] Fold 'small' constant 1D concat operations
The commit improves the concat folder to cover operations consisting of all constant inputs where the number of output values does not exceed 6. Keeping the folder restricted to small inputs avoids a large folder runtime or increased memory requirements. This folder is useful in the context of legalizing dynamic models where the input shapes are resolved to static directly before legalization. In this context, constant shape operations are used over tensors of num elements <= 6 (tosa_level_8k MAX_RANK). Change-Id: Ieb522fc1d0d1ec4596ce060aa9ab76439322d6d6 Signed-off-by: Luke Hutton <[email protected]>
1 parent 8f4ee42 commit 1f2e303

File tree

2 files changed

+92
-2
lines changed

2 files changed

+92
-2
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,16 +1210,51 @@ OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
12101210
}
12111211

12121212
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1213+
auto operands = getOperands();
1214+
const unsigned int numOperands = getNumOperands();
1215+
1216+
// Fold concat when all operands are constant and the output is 'small'
1217+
if (llvm::all_of(operands, [](Value v) {
1218+
return llvm::dyn_cast_or_null<tosa::ConstOp>(v.getDefiningOp());
1219+
})) {
1220+
const ShapedType outputType = dyn_cast<ShapedType>(getOutput().getType());
1221+
if (!outputType || !outputType.hasStaticShape())
1222+
return {};
1223+
1224+
// A 'small' output is currently defined as 1D and <= 6 elements
1225+
// (tosa_level_8k MAX_RANK)
1226+
if (outputType.getRank() != 1)
1227+
return {};
1228+
1229+
const int64_t outputNumElements = outputType.getNumElements();
1230+
if (outputNumElements > 6)
1231+
return {};
1232+
1233+
llvm::SmallVector<Attribute> constOperands;
1234+
constOperands.reserve(outputNumElements);
1235+
for (const Attribute &operand : adaptor.getOperands()) {
1236+
const auto elementsAttr =
1237+
llvm::dyn_cast_if_present<DenseElementsAttr>(operand);
1238+
if (!elementsAttr)
1239+
return {};
1240+
1241+
constOperands.append(
1242+
llvm::to_vector(elementsAttr.getValues<Attribute>()));
1243+
}
1244+
1245+
return DenseElementsAttr::get(outputType, constOperands);
1246+
}
1247+
12131248
// Fold consecutive concats on the same axis into a single op.
12141249
// Keep track of the operands so we are able to construct a new concat
12151250
// later. Conservatively assume that we double the number of operands when
12161251
// folding
12171252
SmallVector<Value, 8> concatOperands;
1218-
concatOperands.reserve(2 * getNumOperands());
1253+
concatOperands.reserve(2 * numOperands);
12191254

12201255
// Find all operands that are foldable concats
12211256
bool foundFoldableConcat = false;
1222-
for (Value operand : getOperands()) {
1257+
for (Value operand : operands) {
12231258
concatOperands.emplace_back(operand);
12241259

12251260
auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());

mlir/test/Dialect/Tosa/fold_concats.mlir

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,58 @@ func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8x
9191
// CHECK: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
9292
// CHECK: return %[[VAL_3]] : tensor<1x4x8x8xf32>
9393
// CHECK: }
94+
95+
// -----
96+
97+
// CHECK-LABEL: test_fold_small_const_concat
98+
func.func @test_fold_small_const_concat() -> tensor<6xi8> {
99+
// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi8>}> : () -> tensor<6xi8>
100+
// CHECK: return %[[VAL_0]] : tensor<6xi8>
101+
%0 = "tosa.const"() <{value = dense<[1, 2]> : tensor<2xi8>}> : () -> tensor<2xi8>
102+
%1 = "tosa.const"() <{value = dense<[3, 4, 5]> : tensor<3xi8>}> : () -> tensor<3xi8>
103+
%2 = "tosa.const"() <{value = dense<6> : tensor<1xi8>}> : () -> tensor<1xi8>
104+
%3 = "tosa.concat"(%0, %1, %2) <{axis = 0 : i32}> : (tensor<2xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<6xi8>
105+
func.return %3 : tensor<6xi8>
106+
}
107+
108+
// -----
109+
110+
// CHECK-LABEL: test_no_fold_small_const_concat_with_non_const
111+
func.func @test_no_fold_small_const_concat_with_non_const(%arg0: tensor<2xi8>, %arg1: tensor<3xi8>, %arg2: tensor<1xi8>) -> tensor<6xi8> {
112+
// CHECK: %[[VAL_3:.*]] = tosa.concat %arg0, %arg1, %arg2 {axis = 0 : i32} : (tensor<2xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<6xi8>
113+
// CHECK: return %[[VAL_3]] : tensor<6xi8>
114+
%1 = "tosa.concat"(%arg0, %arg1, %arg2) <{axis = 0 : i32}> : (tensor<2xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<6xi8>
115+
func.return %1 : tensor<6xi8>
116+
}
117+
118+
// -----
119+
120+
// CHECK-LABEL: test_no_fold_small_const_concat_with_higher_dim
121+
func.func @test_no_fold_small_const_concat_with_higher_dim() -> tensor<7xi8> {
122+
// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<[1, 2, 3]> : tensor<3xi8>}> : () -> tensor<3xi8>
123+
// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<[4, 5, 6]> : tensor<3xi8>}> : () -> tensor<3xi8>
124+
// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8>
125+
// CHECK-DAG: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] {axis = 0 : i32} : (tensor<3xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<7xi8>
126+
// CHECK: return %[[VAL_3]] : tensor<7xi8>
127+
%0 = "tosa.const"() <{value = dense<[1, 2, 3]> : tensor<3xi8>}> : () -> tensor<3xi8>
128+
%1 = "tosa.const"() <{value = dense<[4, 5, 6]> : tensor<3xi8>}> : () -> tensor<3xi8>
129+
%2 = "tosa.const"() <{value = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8>
130+
%3 = "tosa.concat"(%0, %1, %2) <{axis = 0 : i32}> : (tensor<3xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<7xi8>
131+
func.return %3 : tensor<7xi8>
132+
}
133+
134+
// -----
135+
136+
// CHECK-LABEL: test_no_fold_small_const_concat_with_higher_rank
137+
func.func @test_no_fold_small_const_concat_with_higher_rank() -> tensor<1x6xi8> {
138+
// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 2]]> : tensor<1x2xi8>}> : () -> tensor<1x2xi8>
139+
// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{\[\[}}3, 4, 5]]> : tensor<1x3xi8>}> : () -> tensor<1x3xi8>
140+
// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<6> : tensor<1x1xi8>}> : () -> tensor<1x1xi8>
141+
// CHECK-DAG: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2xi8>, tensor<1x3xi8>, tensor<1x1xi8>) -> tensor<1x6xi8>
142+
// CHECK: return %[[VAL_3]] : tensor<1x6xi8>
143+
%0 = "tosa.const"() <{value = dense<[[1, 2]]> : tensor<1x2xi8>}> : () -> tensor<1x2xi8>
144+
%1 = "tosa.const"() <{value = dense<[[3, 4, 5]]> : tensor<1x3xi8>}> : () -> tensor<1x3xi8>
145+
%2 = "tosa.const"() <{value = dense<[[6]]> : tensor<1x1xi8>}> : () -> tensor<1x1xi8>
146+
%3 = "tosa.concat"(%0, %1, %2) <{axis = 1 : i32}> : (tensor<1x2xi8>, tensor<1x3xi8>, tensor<1x1xi8>) -> tensor<1x6xi8>
147+
func.return %3 : tensor<1x6xi8>
148+
}

0 commit comments

Comments
 (0)