Skip to content

Commit a2f9c89

Browse files
committed
Solve merge conflicts and fixed tests after bump
Update TOSA mul legalization to always provide the new shift operand and refresh TorchToTosa tests for the updated signature.
1 parent 7e906ab commit a2f9c89

File tree

4 files changed

+47
-58
lines changed

4 files changed

+47
-58
lines changed

lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,8 @@ tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
121121
rhs = promoteType(rewriter, rhs, outType);
122122
auto constShift =
123123
tosa::getTosaMulShiftConstTensor(rewriter, op, outType, shift);
124-
if (constShift)
125-
return tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), outType,
126-
lhs, rhs, constShift);
127124
return tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), outType,
128-
lhs, rhs, Value());
125+
lhs, rhs, constShift);
129126
}
130127

131128
template <>

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -164,15 +164,7 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
164164
}
165165

166166
Value getTosaMulShiftConstTensor(PatternRewriter &rewriter, Operation *op,
167-
Type resultType, int32_t shift) {
168-
auto tensorType = dyn_cast_or_null<TensorType>(resultType);
169-
if (!tensorType)
170-
return Value();
171-
172-
auto elementType = tensorType.getElementType();
173-
if (!isa<IntegerType>(elementType))
174-
return Value();
175-
167+
Type /*resultType*/, int32_t shift) {
176168
auto shiftType = RankedTensorType::get({1}, rewriter.getI8Type());
177169
auto shiftAttr = DenseElementsAttr::get<int8_t>(
178170
shiftType, llvm::ArrayRef<int8_t>{static_cast<int8_t>(shift)});

0 commit comments

Comments
 (0)