Skip to content

Conversation

@Tai78641
Copy link
Contributor

This commit extends the slice folder to fold constant slice operations consisting of all constant inputs where the number of output values does not exceed 6. Keeping the folder restricted to small inputs avoids a large folder runtime or increased memory requirements.

This folder is useful in the context of legalizing dynamic models where the input shapes are resolved to static directly before legalization. In this context, constant shape operations are used over tensors of num elements <= 6 (tosa_level_8k MAX_RANK).

@llvmbot
Copy link
Member

llvmbot commented Feb 21, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

Changes

This commit extends the slice folder to fold constant slice operations consisting of all constant inputs where the number of output values does not exceed 6. Keeping the folder restricted to small inputs avoids a large folder runtime or increased memory requirements.

This folder is useful in the context of legalizing dynamic models where the input shapes are resolved to static directly before legalization. In this context, constant shape operations are used over tensors of num elements <= 6 (tosa_level_8k MAX_RANK).


Full diff: https://github.com/llvm/llvm-project/pull/128193.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+30-7)
  • (modified) mlir/test/Dialect/Tosa/constant_folding.mlir (+13)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9bfc2aae1d6a5..f5a21689c0af3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1115,18 +1115,41 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
     return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
   }
 
-  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
-      outputTy.getNumElements() == 1) {
-    DenseElementsAttr startElems;
-    if (!matchPattern(getStart(), m_Constant(&startElems)))
-      return {};
+  if (!inputTy.hasStaticShape() || !outputTy.hasStaticShape())
+    return {};
+
+  DenseElementsAttr startElems;
+  if (!matchPattern(getStart(), m_Constant(&startElems)))
+    return {};
 
-    llvm::SmallVector<uint64_t> indices =
-        llvm::to_vector(startElems.getValues<uint64_t>());
+  llvm::SmallVector<uint64_t> indices =
+      llvm::to_vector(startElems.getValues<uint64_t>());
+
+  if (outputTy.getNumElements() == 1) {
     auto value = operand.getValues<Attribute>()[indices];
     return SplatElementsAttr::get(outputTy, value);
   }
 
+  DenseElementsAttr size_elems;
+  if (!matchPattern(getSize(), m_Constant(&size_elems)))
+    return {};
+  const llvm::SmallVector<uint64_t> sizes =
+      llvm::to_vector(size_elems.getValues<uint64_t>());
+
+  // Fold slice when all operands are constant and the output is 'small'
+  // A 'small' output is currently defined as 1D and <= 6 elements
+  // (tosa_level_8k MAX_RANK)
+  if (inputTy.getRank() == 1 && outputTy.getRank() == 1 &&
+      outputTy.getNumElements() <= 6 && indices.size() == 1 &&
+      sizes.size() == 1) {
+    const auto begin = operand.value_begin<Attribute>();
+    const uint64_t offset = indices[0];
+    const uint64_t size = sizes[0];
+    const SmallVector<Attribute> slicedValues(begin + offset,
+                                              begin + offset + size);
+    return DenseElementsAttr::get(outputTy, slicedValues);
+  }
+
   return {};
 }
 
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 3ff3121348fca..c82f522432295 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -21,3 +21,16 @@ func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tens
   %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
   return
 }
+
+// -----
+
+// CHECK-LABEL: test_1d_slice
+func.func @test_1d_slice() -> tensor<6xi32> {
+  // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>}> : () -> tensor<6xi32>
+  // CHECK: return %[[VAL_0]] : tensor<6xi32>
+  %0 = "tosa.const"() <{value = dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32>}> : () -> tensor<10xi32>
+  %1 = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %2 = tosa.const_shape {value = dense<6> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %3 = tosa.slice %0, %1, %2 : (tensor<10xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<6xi32>
+  return %3 : tensor<6xi32>
+}

Copy link
Contributor

@FranklandJack FranklandJack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we remove the change ID and the TF reference from the commit message?


llvm::SmallVector<uint64_t> indices =
llvm::to_vector(startElems.getValues<uint64_t>());
llvm::SmallVector<uint64_t> indices =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know we tend to eschew auto in favour of explict types (grumble, grumble). But I think here auto is probably justified since we already have the type name on the RHS, that is llvm::to_vector.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


DenseElementsAttr size_elems;
if (!matchPattern(getSize(), m_Constant(&size_elems)))
return {};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we put a newline after this return, it makes it a wee bit easier to read.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

DenseElementsAttr size_elems;
if (!matchPattern(getSize(), m_Constant(&size_elems)))
return {};
const llvm::SmallVector<uint64_t> sizes =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be auto'd , see my above comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

%0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
return
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add some negative tests here as well?

The condition for this transformation is:


  // Fold slice when all operands are constant and the output is 'small'
  // A 'small' output is currently defined as 1D and <= 6 elements
  // (tosa_level_8k MAX_RANK)

So this suggests the following negative cases:

  1. 1 or more non-const operands
  2. Tensor of rank > 1
  3. > 6 element in extracted tensor.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@Tai78641 Tai78641 requested a review from FranklandJack March 3, 2025 21:01
@Tai78641 Tai78641 force-pushed the pr_fold_slice branch 2 times, most recently from 7fd38bf to d6d3d7c Compare March 7, 2025 19:46
This commit extends the slice folder to fold constant slice operations
consisting of all constant inputs where the number of output values
does not exceed 6. Keeping the folder restricted to small inputs avoids
a large folder runtime or increased memory requirements.

This folder is useful in the context of legalizing dynamic models where
the input shapes are resolved to static directly before legalization.
In this context, constant shape operations are used over tensors of
num elements <= 6 (tosa_level_8k MAX_RANK).

Change-Id: I1e59e5919f8c2936e98788c5a9b44a691940b28a
Signed-off-by: Luke Hutton <[email protected]>
@lhutton1
Copy link
Contributor

No longer needed

@lhutton1 lhutton1 closed this May 30, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants