@Jerry-Ge
Member

The commit improves the concat folder to cover operations consisting of all constant inputs where the number of output values does not exceed 6. Keeping the folder restricted to small inputs avoids a large folder runtime or increased memory requirements.

This folder is useful when legalizing dynamic models whose input shapes are resolved to static shapes directly before legalization. In this context, constant shape operations act on tensors with at most 6 elements (the tosa_level_8k MAX_RANK).
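
For readers unfamiliar with the folder, the rule described above can be sketched in standalone C++ without any MLIR dependency. This is only an illustration under stated assumptions, not the code in the patch: `ConstTensor1D`, `kMaxFoldedElements`, and `foldSmallConstConcat` are hypothetical names, a 1-D constant tensor is modeled as a plain `std::vector`, and a non-constant operand is modeled as an empty `std::optional`.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a 1-D constant tensor; this is not an MLIR type.
using ConstTensor1D = std::vector<std::int8_t>;

// Mirrors the limit described above: fold only when the output has at most
// 6 elements.
constexpr std::size_t kMaxFoldedElements = 6;

// Returns the concatenated values when every operand is a known constant and
// the result is small enough. Returns std::nullopt otherwise, meaning the
// concat would be left as-is.
std::optional<ConstTensor1D> foldSmallConstConcat(
    const std::vector<std::optional<ConstTensor1D>> &operands) {
  std::size_t total = 0;
  for (const auto &operand : operands) {
    if (!operand)
      return std::nullopt; // a non-constant operand blocks the fold
    total += operand->size();
  }
  if (total > kMaxFoldedElements)
    return std::nullopt; // output too large; skip to bound time and memory

  // For 1-D inputs, concatenation is appending the operands in order.
  ConstTensor1D folded;
  folded.reserve(total);
  for (const auto &operand : operands)
    folded.insert(folded.end(), operand->begin(), operand->end());
  return folded;
}
```

Restricting the sketch to 1-D inputs keeps the fold a simple append; higher-rank concatenation along an arbitrary axis would need index arithmetic that the description above deliberately leaves out of scope.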

@llvmbot
Member

llvmbot commented Feb 20, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Jerry-Ge (Jerry-Ge)

Changes

The commit improves the concat folder to cover operations consisting of all constant inputs where the number of output values does not exceed 6. Keeping the folder restricted to small inputs avoids a large folder runtime or increased memory requirements.

This folder is useful when legalizing dynamic models whose input shapes are resolved to static shapes directly before legalization. In this context, constant shape operations act on tensors with at most 6 elements (the tosa_level_8k MAX_RANK).


Full diff: https://github.com/llvm/llvm-project/pull/128080.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+36-2)
  • (modified) mlir/test/Dialect/Tosa/fold_concats.mlir (+13)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9bfc2aae1d6a5..f31c388f71f19 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1226,16 +1226,50 @@ OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
 }
 
 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
+  const auto operands = getOperands();
+  const unsigned int numOperands = getNumOperands();
+
+  // Fold concat when all operands are constant and the output is 'small'
+  auto hasAllConstOperands = [](Value v){
+    return llvm::dyn_cast_or_null<tosa::ConstOp>(v.getDefiningOp());};
+  if (llvm::all_of(operands, hasAllConstOperands)) {
+    const ShapedType outputType = dyn_cast<ShapedType>(getOutput().getType());
+    if (!outputType || !outputType.hasStaticShape()) {
+      return {};
+    }
+
+    // A 'small' output is currently defined as 1D and <= 6 elements (tosa_level_8k MAX_RANK)
+    if (outputType.getRank() != 1) {
+      return {};
+    }
+    const int64_t outputNumElements = outputType.getNumElements();
+    if (outputNumElements > 6) {
+      return {};
+    }
+
+    llvm::SmallVector<Attribute> constOperands;
+    constOperands.reserve(outputNumElements);
+    for (const Attribute operand : adaptor.getOperands()) {
+      const auto elementsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(operand);
+      if (!elementsAttr) {
+        return {};
+      }
+      constOperands.append(llvm::to_vector(elementsAttr.getValues<Attribute>()));
+    }
+
+    return DenseElementsAttr::get(outputType, constOperands);
+  }
+
   // Fold consecutive concats on the same axis into a single op.
   // Keep track of the operands so we are able to construct a new concat
   // later. Conservatively assume that we double the number of operands when
   // folding
   SmallVector<Value, 8> concatOperands;
-  concatOperands.reserve(2 * getNumOperands());
+  concatOperands.reserve(2 * numOperands);
 
   // Find all operands that are foldable concats
   bool foundFoldableConcat = false;
-  for (Value operand : getOperands()) {
+  for (Value operand : operands) {
     concatOperands.emplace_back(operand);
 
     auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
diff --git a/mlir/test/Dialect/Tosa/fold_concats.mlir b/mlir/test/Dialect/Tosa/fold_concats.mlir
index ec54f27346c8b..6bfbeed81e88f 100644
--- a/mlir/test/Dialect/Tosa/fold_concats.mlir
+++ b/mlir/test/Dialect/Tosa/fold_concats.mlir
@@ -91,3 +91,16 @@ func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8x
 // CHECK:           %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
 // CHECK:           return %[[VAL_3]] : tensor<1x4x8x8xf32>
 // CHECK:         }
+
+// -----
+
+// CHECK-LABEL: test_fold_small_const_concat
+func.func @test_fold_small_const_concat() -> tensor<6xi8> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi8>}> : () -> tensor<6xi8>
+  // CHECK: return %[[VAL_0]] : tensor<6xi8>
+  %0 = "tosa.const"() <{value = dense<[1, 2]> : tensor<2xi8>}> : () -> tensor<2xi8>
+  %1 = "tosa.const"() <{value = dense<[3, 4, 5]> : tensor<3xi8>}> : () -> tensor<3xi8>
+  %2 = "tosa.const"() <{value = dense<6> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %3 = "tosa.concat"(%0, %1, %2) <{axis = 0 : i32}> : (tensor<2xi8>, tensor<3xi8>, tensor<1xi8>) -> tensor<6xi8>
+  func.return %3 : tensor<6xi8>
+}

@github-actions

github-actions bot commented Feb 20, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor

@FranklandJack FranklandJack left a comment

Should we remove the change ID from the commit message?

@Jerry-Ge
Member Author

Jerry-Ge commented Mar 3, 2025

> Should we remove the change ID from the commit message?

I can remove that once the PR is approved. We need it for the Gerrit patch copies.

@FranklandJack FranklandJack self-requested a review March 5, 2025 16:49
The commit improves the concat folder to cover operations consisting of
all constant inputs where the number of output values does not exceed 6.
Keeping the folder restricted to small inputs avoids a large folder
runtime or increased memory requirements.

This folder is useful in the context of legalizing dynamic models where
the input shapes are resolved to static directly before legalization.
In this context, constant shape operations are used over tensors of
num elements <= 6 (tosa_level_8k MAX_RANK).

Change-Id: Ieb522fc1d0d1ec4596ce060aa9ab76439322d6d6
Signed-off-by: Luke Hutton <[email protected]>
Contributor

@lhutton1 lhutton1 left a comment

Let's hold off on this for now as a result of the discussion in #128059

@Jerry-Ge
Member Author

No longer needed.

@Jerry-Ge Jerry-Ge closed this May 29, 2025