Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1186,16 +1186,51 @@ OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
}

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
auto operands = getOperands();
const unsigned int numOperands = getNumOperands();

// Fold concat when all operands are constant and the output is 'small'
if (llvm::all_of(operands, [](Value v) {
return llvm::dyn_cast_or_null<tosa::ConstOp>(v.getDefiningOp());
})) {
const ShapedType outputType = dyn_cast<ShapedType>(getOutput().getType());
if (!outputType || !outputType.hasStaticShape())
return {};

// A 'small' output is currently defined as 1D and <= 6 elements
// (tosa_level_8k MAX_RANK)
if (outputType.getRank() != 1)
return {};

const int64_t outputNumElements = outputType.getNumElements();
if (outputNumElements > 6)
return {};

llvm::SmallVector<Attribute> constOperands;
constOperands.reserve(outputNumElements);
for (const Attribute &operand : adaptor.getOperands()) {
const auto elementsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(operand);
if (!elementsAttr)
return {};

constOperands.append(
llvm::to_vector(elementsAttr.getValues<Attribute>()));
}

return DenseElementsAttr::get(outputType, constOperands);
}

// Fold consecutive concats on the same axis into a single op.
// Keep track of the operands so we are able to construct a new concat
// later. Conservatively assume that we double the number of operands when
// folding
SmallVector<Value, 8> concatOperands;
concatOperands.reserve(2 * getNumOperands());
concatOperands.reserve(2 * numOperands);

// Find all operands that are foldable concats
bool foundFoldableConcat = false;
for (Value operand : getOperands()) {
for (Value operand : operands) {
concatOperands.emplace_back(operand);

auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
Expand Down
55 changes: 55 additions & 0 deletions mlir/test/Dialect/Tosa/fold_concats.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,58 @@ func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8x
// CHECK: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
// CHECK: return %[[VAL_3]] : tensor<1x4x8x8xf32>
// CHECK: }

// -----

// CHECK-LABEL: test_fold_small_const_concat
func.func @test_fold_small_const_concat() -> tensor<6xi8> {
// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi8>}> : () -> tensor<6xi8>
// CHECK: return %[[VAL_0]] : tensor<6xi8>
%0 = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi8>}> : () -> tensor<2xi8>
%1 = "tosa.const"() <{values = dense<[3, 4, 5]> : tensor<3xi8>}> : () -> tensor<3xi8>
%2 = "tosa.const"() <{values = dense<6> : tensor<1xi8>}> : () -> tensor<1xi8>
%3 = "tosa.concat"(%0, %1, %2) <{axis = 0 : i32}> : (tensor<2xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<6xi8>
func.return %3 : tensor<6xi8>
}

// -----

// CHECK-LABEL: test_no_fold_small_const_concat_with_non_const
func.func @test_no_fold_small_const_concat_with_non_const(%arg0: tensor<2xi8>, %arg1: tensor<3xi8>, %arg2: tensor<1xi8>) -> tensor<6xi8> {
// CHECK: %[[VAL_3:.*]] = tosa.concat %arg0, %arg1, %arg2 {axis = 0 : i32} : (tensor<2xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<6xi8>
// CHECK: return %[[VAL_3]] : tensor<6xi8>
%1 = "tosa.concat"(%arg0, %arg1, %arg2) <{axis = 0 : i32}> : (tensor<2xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<6xi8>
func.return %1 : tensor<6xi8>
}

// -----

// CHECK-LABEL: test_no_fold_small_const_concat_with_higher_dim
func.func @test_no_fold_small_const_concat_with_higher_dim() -> tensor<7xi8> {
// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<[1, 2, 3]> : tensor<3xi8>}> : () -> tensor<3xi8>
// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<[4, 5, 6]> : tensor<3xi8>}> : () -> tensor<3xi8>
// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK-DAG: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] {axis = 0 : i32} : (tensor<3xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<7xi8>
// CHECK: return %[[VAL_3]] : tensor<7xi8>
%0 = "tosa.const"() <{values = dense<[1, 2, 3]> : tensor<3xi8>}> : () -> tensor<3xi8>
%1 = "tosa.const"() <{values = dense<[4, 5, 6]> : tensor<3xi8>}> : () -> tensor<3xi8>
%2 = "tosa.const"() <{values = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8>
%3 = "tosa.concat"(%0, %1, %2) <{axis = 0 : i32}> : (tensor<3xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<7xi8>
func.return %3 : tensor<7xi8>
}

// -----

// CHECK-LABEL: test_no_fold_small_const_concat_with_higher_rank
func.func @test_no_fold_small_const_concat_with_higher_rank() -> tensor<1x6xi8> {
// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<{{\[\[}}1, 2]]> : tensor<1x2xi8>}> : () -> tensor<1x2xi8>
// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{\[\[}}3, 4, 5]]> : tensor<1x3xi8>}> : () -> tensor<1x3xi8>
// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<6> : tensor<1x1xi8>}> : () -> tensor<1x1xi8>
// CHECK-DAG: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2xi8>, tensor<1x3xi8>, tensor<1x1xi8>) -> tensor<1x6xi8>
// CHECK: return %[[VAL_3]] : tensor<1x6xi8>
%0 = "tosa.const"() <{values = dense<[[1, 2]]> : tensor<1x2xi8>}> : () -> tensor<1x2xi8>
%1 = "tosa.const"() <{values = dense<[[3, 4, 5]]> : tensor<1x3xi8>}> : () -> tensor<1x3xi8>
%2 = "tosa.const"() <{values = dense<[[6]]> : tensor<1x1xi8>}> : () -> tensor<1x1xi8>
%3 = "tosa.concat"(%0, %1, %2) <{axis = 1 : i32}> : (tensor<1x2xi8>, tensor<1x3xi8>, tensor<1x1xi8>) -> tensor<1x6xi8>
func.return %3 : tensor<1x6xi8>
}