Skip to content

Commit 49a149a

Browse files
committed
Adjust for LLVM bump_to_a58e774f
1 parent 89e85e1 commit 49a149a

File tree

7 files changed

+195
-133
lines changed

7 files changed

+195
-133
lines changed

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
4949
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
5050
float val);
5151

52+
// Create an int8_t const tosa.mul shift tensor from an int when required for
53+
// the given result type. Returns a null Value when no shift operand is needed.
54+
Value getTosaMulShiftConstTensor(PatternRewriter &rewriter, Operation *op,
55+
Type resultType, int32_t shift);
56+
5257
// Create a zero constant tensor of the desired type and shape.
5358
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
5459
Operation *op, Type type);

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 87 additions & 63 deletions
Large diffs are not rendered by default.

lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,13 @@ tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
119119
int32_t shift) {
120120
lhs = promoteType(rewriter, lhs, outType);
121121
rhs = promoteType(rewriter, rhs, outType);
122+
auto constShift =
123+
tosa::getTosaMulShiftConstTensor(rewriter, op, outType, shift);
124+
if (constShift)
125+
return tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), outType,
126+
lhs, rhs, constShift);
122127
return tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), outType,
123-
lhs, rhs, shift);
128+
lhs, rhs, Value());
124129
}
125130

126131
template <>
@@ -386,10 +391,13 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
386391
// Multiply the coefficients by the coordinates
387392
// %5 = "tosa.mul"(%3, %4) {shift = 0 : i32} : (tensor<8x3xi32>,
388393
// tensor<3xi32>) -> tensor<8x3xi32>
394+
auto flattenedMulType =
395+
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType());
396+
auto mulShift =
397+
tosa::getTosaMulShiftConstTensor(rewriter, op, flattenedMulType, 0);
389398
auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
390-
rewriter, op->getLoc(),
391-
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
392-
indicesMatrixReshapeOp, flattenedCoeffValue.value(), 0);
399+
rewriter, op->getLoc(), flattenedMulType, indicesMatrixReshapeOp,
400+
flattenedCoeffValue.value(), mulShift);
393401

394402
// Sum up the products of the coefficients and coordinates
395403
// %6 = "tosa.reduce_sum"(%5) {axis = 1 : i64} : (tensor<8x3xi32>) ->
@@ -657,10 +665,13 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
657665
// [[0, 1], [0, 2], [0, 3]] X [4, 1] -> [[4*0, 1*1], [4*0, 1*2], [4*0, 1*3]]
658666
// %13 = "tosa.mul"(%11, %12) {shift = 0 : i32} : (tensor<3x2xi32>,
659667
// tensor<2xi32>) -> tensor<3x2xi32>
668+
auto flattenedMulType =
669+
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType());
670+
auto mulShift =
671+
tosa::getTosaMulShiftConstTensor(rewriter, op, flattenedMulType, 0);
660672
auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
661-
rewriter, op->getLoc(),
662-
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
663-
indicesMatrixReshapeOp, flattenedCoeffValue.value(), 0);
673+
rewriter, op->getLoc(), flattenedMulType, indicesMatrixReshapeOp,
674+
flattenedCoeffValue.value(), mulShift);
664675

665676
// Sum up the products of the coefficients and coordinates
666677
// [[4*0 + 1*1], [4*0 + 1*2], [4*0 + 1*3]] = [[1],[2],[3]]
@@ -1006,8 +1017,10 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
10061017
.failed())
10071018
return std::nullopt;
10081019

1020+
auto mulShift =
1021+
tosa::getTosaMulShiftConstTensor(rewriter, op, output_type, 0);
10091022
return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
1010-
val.value(), div_const, 0)
1023+
val.value(), div_const, mulShift)
10111024
.getResult();
10121025
}
10131026

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,26 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
163163
return const_op.getResult();
164164
}
165165

166+
Value getTosaMulShiftConstTensor(PatternRewriter &rewriter, Operation *op,
167+
Type resultType, int32_t shift) {
168+
auto tensorType = dyn_cast_or_null<TensorType>(resultType);
169+
if (!tensorType)
170+
return Value();
171+
172+
auto elementType = tensorType.getElementType();
173+
if (!isa<IntegerType>(elementType))
174+
return Value();
175+
176+
auto shiftType = RankedTensorType::get({1}, rewriter.getI8Type());
177+
auto shiftAttr = DenseElementsAttr::get<int8_t>(
178+
shiftType, llvm::ArrayRef<int8_t>{static_cast<int8_t>(shift)});
179+
180+
auto constShift =
181+
rewriter.create<tosa::ConstOp>(op->getLoc(), shiftType, shiftAttr);
182+
183+
return constShift.getResult();
184+
}
185+
166186
// Create a zero constant tensor of the desired type and shape.
167187
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
168188
Operation *op, Type type) {

0 commit comments

Comments
 (0)