Skip to content

Commit 6ad1623

Browse files
authored
[tosa] : Enhance EqualizeRanks to handle dynamic dimensions. (#168564)
Legalizing following IR to `tosa` using `tf-tosa-opt` from `tensorflow` repo: ``` func.func @main(%arg0: tensor<?x?x?x?xf32>) -> tensor<?x?x?x5xf32> { %0 = "tfl.pseudo_const"() <{value = dense<0.000000e+00> : tensor<5xf32>}> : () -> tensor<5xf32> %1 = tfl.add(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<?x?x?x5xf32> return %1 : tensor<?x?x?x5xf32> } ``` fails with ``` error: 'tosa.add' op operands don't have matching ranks %1 = tfl.add(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<?x?x?x5xf32> ^ tfl.mlir:3:10: note: see current operation: %1 = "tosa.add"(%arg0, %0) : (tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<?x?x?x5xf32> // -----// IR Dump After TosaLegalizeTFLPass Failed (tosa-legalize-tfl) //----- // "func.func"() <{function_type = (tensor<?x?x?x?xf32>) -> tensor<?x?x?x5xf32>, sym_name = "main"}> ({ ^bb0(%arg0: tensor<?x?x?x?xf32>): %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<5xf32>}> : () -> tensor<5xf32> %1 = "tosa.add"(%arg0, %0) : (tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<?x?x?x5xf32> "func.return"(%1) : (tensor<?x?x?x5xf32>) -> () }) : () -> () ``` This is because of the following check in `computeReshapeOutput` called from `EqualizeRanks` function: ``` if (lowerRankDim != 1 && higherRankDim != 1 && lowerRankDim != higherRankDim) return failure(); ``` Based on the broadcast semantics defined in https://mlir.llvm.org/docs/Traits/Broadcastable/#dimension-inference I think it's legal to allow `lowerRankDim != higherRankDim` if one of them is dynamic. At runtime verifier should enforce that 1. if lowerRankDim is dynamic and higherRankDim is static then the dynamic dim matches the static dim and vice-versa 2. if both are dynamic, they should match It's not necessary to error out during the op construction time.
1 parent 87a1fd1 commit 6ad1623

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ namespace {
7070
// If lower=[a], higher=[a, a], [a] reshaped into [1, a].
7171
// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
7272
// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
73+
// If lower=[c], higher=[?, ?, c], [c] reshaped into [1, 1, c].
74+
// If lower=[?], higher=[?, ?, ?], [?] reshaped into [1, 1, ?].
7375
LogicalResult
7476
computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
7577
ArrayRef<int64_t> lowerRankShape,
@@ -87,7 +89,12 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
8789
higherRankDim = higherRankShape[i + rankDiff];
8890
lowerRankDim = lowerRankShape[i];
8991

90-
if (lowerRankDim != 1 && higherRankDim != 1 &&
92+
auto isStaticDimAndNotEqualToOne = [](int64_t dim) {
93+
return dim != 1 && dim != ShapedType::kDynamic;
94+
};
95+
96+
if (isStaticDimAndNotEqualToOne(lowerRankDim) &&
97+
isStaticDimAndNotEqualToOne(higherRankDim) &&
9198
lowerRankDim != higherRankDim)
9299
return failure();
93100

0 commit comments

Comments
 (0)