From 83560310ce3d05e05399e5283be961752b56774e Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Mon, 4 Nov 2024 15:03:17 +0000 Subject: [PATCH] [mlir][tosa] Fold 'small' constant 1D concat operations The commit improves the concat folder to cover operations consisting of all constant inputs where the number of output values does not exceed 6. Keeping the folder restricted to small inputs avoids a large folder runtime or increased memory requirements. This folder is useful in the context of legalizing dynamic models where the input shapes are resolved to static directly before legalization. In this context, constant shape operations are used over tensors of num elements <= 6 (tosa_level_8k MAX_RANK). Change-Id: Ieb522fc1d0d1ec4596ce060aa9ab76439322d6d6 Signed-off-by: Luke Hutton --- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 39 ++++++++++++- mlir/test/Dialect/Tosa/fold_concats.mlir | 55 +++++++++++++++++++ 2 files changed, 92 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 09d2c5d35263c..8c9e96d5cf784 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1186,16 +1186,51 @@ OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) { } OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) { + auto operands = getOperands(); + const unsigned int numOperands = getNumOperands(); + + // Fold concat when all operands are constant and the output is 'small' + if (llvm::all_of(operands, [](Value v) { + return llvm::dyn_cast_or_null(v.getDefiningOp()); + })) { + const ShapedType outputType = dyn_cast(getOutput().getType()); + if (!outputType || !outputType.hasStaticShape()) + return {}; + + // A 'small' output is currently defined as 1D and <= 6 elements + // (tosa_level_8k MAX_RANK) + if (outputType.getRank() != 1) + return {}; + + const int64_t outputNumElements = outputType.getNumElements(); + if (outputNumElements > 6) + return {}; + + llvm::SmallVector constOperands; + constOperands.reserve(outputNumElements); + for (const Attribute &operand : adaptor.getOperands()) { + const auto elementsAttr = + llvm::dyn_cast_if_present(operand); + if (!elementsAttr) + return {}; + + constOperands.append( + llvm::to_vector(elementsAttr.getValues())); + } + + return DenseElementsAttr::get(outputType, constOperands); + } + // Fold consecutive concats on the same axis into a single op. // Keep track of the operands so we are able to construct a new concat // later. Conservatively assume that we double the number of operands when // folding SmallVector concatOperands; - concatOperands.reserve(2 * getNumOperands()); + concatOperands.reserve(2 * numOperands); // Find all operands that are foldable concats bool foundFoldableConcat = false; - for (Value operand : getOperands()) { + for (Value operand : operands) { concatOperands.emplace_back(operand); auto producer = dyn_cast_or_null(operand.getDefiningOp()); diff --git a/mlir/test/Dialect/Tosa/fold_concats.mlir b/mlir/test/Dialect/Tosa/fold_concats.mlir index ec54f27346c8b..f776b62a3da31 100644 --- a/mlir/test/Dialect/Tosa/fold_concats.mlir +++ b/mlir/test/Dialect/Tosa/fold_concats.mlir @@ -91,3 +91,58 @@ func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8x // CHECK: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32> // CHECK: return %[[VAL_3]] : tensor<1x4x8x8xf32> // CHECK: } + +// ----- + +// CHECK-LABEL: test_fold_small_const_concat +func.func @test_fold_small_const_concat() -> tensor<6xi8> { + // CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi8>}> : () -> tensor<6xi8> + // CHECK: return %[[VAL_0]] : tensor<6xi8> + %0 = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi8>}> : () -> tensor<2xi8> + %1 = "tosa.const"() <{values = dense<[3, 4, 5]> : tensor<3xi8>}> : () -> tensor<3xi8> + %2 = "tosa.const"() <{values = dense<6> : tensor<1xi8>}> : () -> tensor<1xi8> + %3 = "tosa.concat"(%0, %1, %2) <{axis = 0 : i32}> : (tensor<2xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<6xi8> + func.return %3 : tensor<6xi8> +} + +// ----- + +// CHECK-LABEL: test_no_fold_small_const_concat_with_non_const +func.func @test_no_fold_small_const_concat_with_non_const(%arg0: tensor<2xi8>, %arg1: tensor<3xi8>, %arg2: tensor<1xi8>) -> tensor<6xi8> { + // CHECK: %[[VAL_3:.*]] = tosa.concat %arg0, %arg1, %arg2 {axis = 0 : i32} : (tensor<2xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<6xi8> + // CHECK: return %[[VAL_3]] : tensor<6xi8> + %1 = "tosa.concat"(%arg0, %arg1, %arg2) <{axis = 0 : i32}> : (tensor<2xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<6xi8> + func.return %1 : tensor<6xi8> +} + +// ----- + +// CHECK-LABEL: test_no_fold_small_const_concat_with_higher_dim +func.func @test_no_fold_small_const_concat_with_higher_dim() -> tensor<7xi8> { + // CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<[1, 2, 3]> : tensor<3xi8>}> : () -> tensor<3xi8> + // CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<[4, 5, 6]> : tensor<3xi8>}> : () -> tensor<3xi8> + // CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8> + // CHECK-DAG: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] {axis = 0 : i32} : (tensor<3xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<7xi8> + // CHECK: return %[[VAL_3]] : tensor<7xi8> + %0 = "tosa.const"() <{values = dense<[1, 2, 3]> : tensor<3xi8>}> : () -> tensor<3xi8> + %1 = "tosa.const"() <{values = dense<[4, 5, 6]> : tensor<3xi8>}> : () -> tensor<3xi8> + %2 = "tosa.const"() <{values = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8> + %3 = "tosa.concat"(%0, %1, %2) <{axis = 0 : i32}> : (tensor<3xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<7xi8> + func.return %3 : tensor<7xi8> +} + +// ----- + +// CHECK-LABEL: test_no_fold_small_const_concat_with_higher_rank +func.func @test_no_fold_small_const_concat_with_higher_rank() -> tensor<1x6xi8> { + // CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<{{\[\[}}1, 2]]> : tensor<1x2xi8>}> : () -> tensor<1x2xi8> + // CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{\[\[}}3, 4, 5]]> : tensor<1x3xi8>}> : () -> tensor<1x3xi8> + // CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<6> : tensor<1x1xi8>}> : () -> tensor<1x1xi8> + // CHECK-DAG: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2xi8>, tensor<1x3xi8>, tensor<1x1xi8>) -> tensor<1x6xi8> + // CHECK: return %[[VAL_3]] : tensor<1x6xi8> + %0 = "tosa.const"() <{values = dense<[[1, 2]]> : tensor<1x2xi8>}> : () -> tensor<1x2xi8> + %1 = "tosa.const"() <{values = dense<[[3, 4, 5]]> : tensor<1x3xi8>}> : () -> tensor<1x3xi8> + %2 = "tosa.const"() <{values = dense<[[6]]> : tensor<1x1xi8>}> : () -> tensor<1x1xi8> + %3 = "tosa.concat"(%0, %1, %2) <{axis = 1 : i32}> : (tensor<1x2xi8>, tensor<1x3xi8>, tensor<1x1xi8>) -> tensor<1x6xi8> + func.return %3 : tensor<1x6xi8> +}