diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp index d1a8732dac212..3a51939e07b5b 100644 --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -77,24 +77,21 @@ computeReshapeOutput(ArrayRef higherRankShape, // Initialize new shapes with [1] * higherRank. int64_t higherRank = higherRankShape.size(); int64_t lowerRank = lowerRankShape.size(); - reshapeOutputShape.assign(higherRank, 1); int64_t higherRankDim; int64_t lowerRankDim; + const int64_t rankDiff = higherRank - lowerRank; + + for (int64_t i = lowerRank - 1; i >= 0; i--) { + higherRankDim = higherRankShape[i + rankDiff]; + lowerRankDim = lowerRankShape[i]; - for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0; - i--, j--) { - higherRankDim = higherRankShape[i]; - lowerRankDim = lowerRankShape[j]; - - if (lowerRankDim == 1 && higherRankDim > 1) - reshapeOutputShape[i] = 1; - else if ((lowerRankDim > 1 && higherRankDim == 1) || - (lowerRankDim == higherRankDim)) - reshapeOutputShape[i] = lowerRankDim; - else if (higherRankDim != lowerRankDim) + if (lowerRankDim != 1 && higherRankDim != 1 && + lowerRankDim != higherRankDim) return failure(); + + reshapeOutputShape[i + rankDiff] = lowerRankDim == 1 ? 1 : lowerRankDim; } return success(); }