diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 09d2c5d35263c..8c9e96d5cf784 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1186,16 +1186,51 @@ OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) { } OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) { + auto operands = getOperands(); + const unsigned int numOperands = getNumOperands(); + + // Fold to a constant when all operands are constant and output is 'small'. + if (llvm::all_of(operands, [](Value v) { + return llvm::dyn_cast_or_null(v.getDefiningOp()); + })) { + const ShapedType outputType = dyn_cast(getOutput().getType()); + if (!outputType || !outputType.hasStaticShape()) + return {}; + + // A 'small' output is currently defined as rank 1 with <= 6 elements + // (the bound 6 is taken from tosa_level_8k MAX_RANK) + if (outputType.getRank() != 1) + return {}; + + const int64_t outputNumElements = outputType.getNumElements(); + if (outputNumElements > 6) + return {}; + + llvm::SmallVector constOperands; + constOperands.reserve(outputNumElements); + for (const Attribute &operand : adaptor.getOperands()) { + const auto elementsAttr = + llvm::dyn_cast_if_present(operand); + if (!elementsAttr) + return {}; + + constOperands.append( + llvm::to_vector(elementsAttr.getValues())); + } + + return DenseElementsAttr::get(outputType, constOperands); + } + // Fold consecutive concats on the same axis into a single op. // Keep track of the operands so we are able to construct a new concat // later. 
Conservatively assume that we double the number of operands when // folding SmallVector concatOperands; - concatOperands.reserve(2 * getNumOperands()); + concatOperands.reserve(2 * numOperands); // Find all operands that are foldable concats bool foundFoldableConcat = false; - for (Value operand : getOperands()) { + for (Value operand : operands) { concatOperands.emplace_back(operand); auto producer = dyn_cast_or_null(operand.getDefiningOp()); diff --git a/mlir/test/Dialect/Tosa/fold_concats.mlir b/mlir/test/Dialect/Tosa/fold_concats.mlir index ec54f27346c8b..f776b62a3da31 100644 --- a/mlir/test/Dialect/Tosa/fold_concats.mlir +++ b/mlir/test/Dialect/Tosa/fold_concats.mlir @@ -91,3 +91,58 @@ func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8x // CHECK: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32> // CHECK: return %[[VAL_3]] : tensor<1x4x8x8xf32> // CHECK: } + +// ----- + +// CHECK-LABEL: test_fold_small_const_concat +func.func @test_fold_small_const_concat() -> tensor<6xi8> { + // CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi8>}> : () -> tensor<6xi8> + // CHECK: return %[[VAL_0]] : tensor<6xi8> + %0 = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi8>}> : () -> tensor<2xi8> + %1 = "tosa.const"() <{values = dense<[3, 4, 5]> : tensor<3xi8>}> : () -> tensor<3xi8> + %2 = "tosa.const"() <{values = dense<6> : tensor<1xi8>}> : () -> tensor<1xi8> + %3 = "tosa.concat"(%0, %1, %2) <{axis = 0 : i32}> : (tensor<2xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<6xi8> + func.return %3 : tensor<6xi8> +} + +// ----- + +// CHECK-LABEL: test_no_fold_small_const_concat_with_non_const +func.func @test_no_fold_small_const_concat_with_non_const(%arg0: tensor<2xi8>, %arg1: tensor<3xi8>, %arg2: tensor<1xi8>) -> tensor<6xi8> { + // CHECK: %[[VAL_3:.*]] = tosa.concat %arg0, %arg1, %arg2 {axis = 0 
: i32} : (tensor<2xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<6xi8> + // CHECK: return %[[VAL_3]] : tensor<6xi8> + %1 = "tosa.concat"(%arg0, %arg1, %arg2) <{axis = 0 : i32}> : (tensor<2xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<6xi8> + func.return %1 : tensor<6xi8> +} + +// ----- + +// CHECK-LABEL: test_no_fold_small_const_concat_with_higher_dim +func.func @test_no_fold_small_const_concat_with_higher_dim() -> tensor<7xi8> { + // CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<[1, 2, 3]> : tensor<3xi8>}> : () -> tensor<3xi8> + // CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<[4, 5, 6]> : tensor<3xi8>}> : () -> tensor<3xi8> + // CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8> + // CHECK-DAG: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] {axis = 0 : i32} : (tensor<3xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<7xi8> + // CHECK: return %[[VAL_3]] : tensor<7xi8> + %0 = "tosa.const"() <{values = dense<[1, 2, 3]> : tensor<3xi8>}> : () -> tensor<3xi8> + %1 = "tosa.const"() <{values = dense<[4, 5, 6]> : tensor<3xi8>}> : () -> tensor<3xi8> + %2 = "tosa.const"() <{values = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8> + %3 = "tosa.concat"(%0, %1, %2) <{axis = 0 : i32}> : (tensor<3xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<7xi8> + func.return %3 : tensor<7xi8> +} + +// ----- + +// CHECK-LABEL: test_no_fold_small_const_concat_with_higher_rank +func.func @test_no_fold_small_const_concat_with_higher_rank() -> tensor<1x6xi8> { + // CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<{{\[\[}}1, 2]]> : tensor<1x2xi8>}> : () -> tensor<1x2xi8> + // CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{\[\[}}3, 4, 5]]> : tensor<1x3xi8>}> : () -> tensor<1x3xi8> + // CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<6> : tensor<1x1xi8>}> : () -> tensor<1x1xi8> + // CHECK-DAG: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2xi8>, 
tensor<1x3xi8>, tensor<1x1xi8>) -> tensor<1x6xi8> + // CHECK: return %[[VAL_3]] : tensor<1x6xi8> + %0 = "tosa.const"() <{values = dense<[[1, 2]]> : tensor<1x2xi8>}> : () -> tensor<1x2xi8> + %1 = "tosa.const"() <{values = dense<[[3, 4, 5]]> : tensor<1x3xi8>}> : () -> tensor<1x3xi8> + %2 = "tosa.const"() <{values = dense<[[6]]> : tensor<1x1xi8>}> : () -> tensor<1x1xi8> + %3 = "tosa.concat"(%0, %1, %2) <{axis = 1 : i32}> : (tensor<1x2xi8>, tensor<1x3xi8>, tensor<1x1xi8>) -> tensor<1x6xi8> + func.return %3 : tensor<1x6xi8> +}