[mlir][tosa] Fold 'small' constant 1D slice operations #128193
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir
Author: Tai Ly (Tai78641)
Changes
This commit extends the slice folder to fold constant slice operations consisting of all constant inputs where the number of output values does not exceed 6. Keeping the folder restricted to small inputs avoids a large folder runtime or increased memory requirements.
This folder is useful in the context of legalizing dynamic models where the input shapes are resolved to static directly before legalization. In this context, constant shape operations are used over tensors of num elements <= 6 (tosa_level_8k MAX_RANK).
Full diff: https://github.com/llvm/llvm-project/pull/128193.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9bfc2aae1d6a5..f5a21689c0af3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1115,18 +1115,41 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
}
- if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
- outputTy.getNumElements() == 1) {
- DenseElementsAttr startElems;
- if (!matchPattern(getStart(), m_Constant(&startElems)))
- return {};
+ if (!inputTy.hasStaticShape() || !outputTy.hasStaticShape())
+ return {};
+
+ DenseElementsAttr startElems;
+ if (!matchPattern(getStart(), m_Constant(&startElems)))
+ return {};
- llvm::SmallVector<uint64_t> indices =
- llvm::to_vector(startElems.getValues<uint64_t>());
+ llvm::SmallVector<uint64_t> indices =
+ llvm::to_vector(startElems.getValues<uint64_t>());
+
+ if (outputTy.getNumElements() == 1) {
auto value = operand.getValues<Attribute>()[indices];
return SplatElementsAttr::get(outputTy, value);
}
+ DenseElementsAttr size_elems;
+ if (!matchPattern(getSize(), m_Constant(&size_elems)))
+ return {};
+ const llvm::SmallVector<uint64_t> sizes =
+ llvm::to_vector(size_elems.getValues<uint64_t>());
+
+ // Fold slice when all operands are constant and the output is 'small'
+ // A 'small' output is currently defined as 1D and <= 6 elements
+ // (tosa_level_8k MAX_RANK)
+ if (inputTy.getRank() == 1 && outputTy.getRank() == 1 &&
+ outputTy.getNumElements() <= 6 && indices.size() == 1 &&
+ sizes.size() == 1) {
+ const auto begin = operand.value_begin<Attribute>();
+ const uint64_t offset = indices[0];
+ const uint64_t size = sizes[0];
+ const SmallVector<Attribute> slicedValues(begin + offset,
+ begin + offset + size);
+ return DenseElementsAttr::get(outputTy, slicedValues);
+ }
+
return {};
}
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 3ff3121348fca..c82f522432295 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -21,3 +21,16 @@ func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tens
%0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
return
}
+
+// -----
+
+// CHECK-LABEL: test_1d_slice
+func.func @test_1d_slice() -> tensor<6xi32> {
+ // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>}> : () -> tensor<6xi32>
+ // CHECK: return %[[VAL_0]] : tensor<6xi32>
+ %0 = "tosa.const"() <{value = dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32>}> : () -> tensor<10xi32>
+ %1 = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %2 = tosa.const_shape {value = dense<6> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %3 = tosa.slice %0, %1, %2 : (tensor<10xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<6xi32>
+ return %3 : tensor<6xi32>
+}
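The new branch in `SliceOp::fold` can be modelled outside MLIR to make the iterator arithmetic easier to follow. The following is a minimal standalone sketch, not the patch itself: `foldSmall1DSlice` and `kMaxFoldedElements` are hypothetical names, a `std::vector` stands in for the constant operand, and `std::nullopt` stands in for the empty `OpFoldResult`. The bounds check is an addition made for this sketch; the folder as shown in the diff does not perform one itself.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical name for the 'small' limit used by the patch (6 elements).
constexpr std::uint64_t kMaxFoldedElements = 6;

// Models only the 1D case: returns the sliced values, or std::nullopt when
// the fold is declined.
std::optional<std::vector<std::int32_t>>
foldSmall1DSlice(const std::vector<std::int32_t> &input, std::uint64_t offset,
                 std::uint64_t size) {
  // Decline outputs that are not 'small'.
  if (size > kMaxFoldedElements)
    return std::nullopt;

  // Assumption of this sketch: reject out-of-range slices explicitly.
  if (offset > input.size() || size > input.size() - offset)
    return std::nullopt;

  // Same shape as the patch: copy the range [begin + offset, begin + offset + size).
  const auto begin = input.begin() + static_cast<std::ptrdiff_t>(offset);
  const auto end = begin + static_cast<std::ptrdiff_t>(size);
  return std::vector<std::int32_t>(begin, end);
}
```

Called with the values from `test_1d_slice` (a 10-element input holding 0 through 9, start 1, size 6), the sketch yields the same six values the CHECK line expects, 1 through 6.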
FranklandJack
left a comment
Should we remove the change ID and the TF reference from the commit message?
- llvm::SmallVector<uint64_t> indices =
- llvm::to_vector(startElems.getValues<uint64_t>());
+ llvm::SmallVector<uint64_t> indices =
I know we tend to eschew auto in favour of explicit types (grumble, grumble). But I think auto is probably justified here, since the type name already appears on the RHS via llvm::to_vector.
done
+ DenseElementsAttr size_elems;
+ if (!matchPattern(getSize(), m_Constant(&size_elems)))
+ return {};
Could we put a newline after this return? It makes it a wee bit easier to read.
done
+ DenseElementsAttr size_elems;
+ if (!matchPattern(getSize(), m_Constant(&size_elems)))
+ return {};
+ const llvm::SmallVector<uint64_t> sizes =
I think this can be auto'd; see my comment above.
done
%0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
return
}
Could we add some negative tests here as well?
The condition for this transformation is:
// Fold slice when all operands are constant and the output is 'small'
// A 'small' output is currently defined as 1D and <= 6 elements
// (tosa_level_8k MAX_RANK)
So this suggests the following negative cases:
- 1 or more non-const operands
- Tensor of rank > 1
- > 6 elements in the extracted tensor.
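The three negative cases can be read off the fold condition directly. As a rough sketch only, the condition for the new branch can be written as a plain predicate; `SliceFoldQuery`, `isSmallConstant1DSlice`, and `kSmallSliceLimit` are hypothetical names invented here, and the struct merely summarises the properties the folder inspects. It models the new 1D branch only, not the existing splat or single-element folds.

```cpp
#include <cstdint>

// Hypothetical summary of what the folder checks before folding.
struct SliceFoldQuery {
  bool inputIsConstant;
  bool startIsConstant;
  bool sizeIsConstant;
  std::int64_t inputRank;
  std::int64_t outputRank;
  std::int64_t outputNumElements;
};

// Hypothetical name for the 'small' limit (6 in the patch).
constexpr std::int64_t kSmallSliceLimit = 6;

// True only when every operand is constant, both tensors are 1D, and the
// output has at most kSmallSliceLimit elements.
bool isSmallConstant1DSlice(const SliceFoldQuery &q) {
  return q.inputIsConstant && q.startIsConstant && q.sizeIsConstant &&
         q.inputRank == 1 && q.outputRank == 1 &&
         q.outputNumElements <= kSmallSliceLimit;
}
```

Each suggested negative test flips exactly one of these conditions: a non-constant operand, a rank greater than 1, or an output of 7 or more elements.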
done
7fd38bf to d6d3d7c
This commit extends the slice folder to fold constant slice operations consisting of all constant inputs where the number of output values does not exceed 6. Keeping the folder restricted to small inputs avoids a large folder runtime or increased memory requirements.
This folder is useful in the context of legalizing dynamic models where the input shapes are resolved to static directly before legalization. In this context, constant shape operations are used over tensors of num elements <= 6 (tosa_level_8k MAX_RANK).
Change-Id: I1e59e5919f8c2936e98788c5a9b44a691940b28a
Signed-off-by: Luke Hutton <[email protected]>
No longer needed