Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val);

// Create an int8_t const tosa.mul shift tensor from an int when required for
// the given result type. Returns a null Value when no shift operand is needed.
Value getTosaMulShiftConstTensor(PatternRewriter &rewriter, Operation *op,
Type resultType, int32_t shift);

// Create a zero constant tensor of the desired type and shape.
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
Operation *op, Type type);
Expand Down
150 changes: 87 additions & 63 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp

Large diffs are not rendered by default.

26 changes: 18 additions & 8 deletions lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,10 @@ tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
int32_t shift) {
lhs = promoteType(rewriter, lhs, outType);
rhs = promoteType(rewriter, rhs, outType);
auto constShift =
tosa::getTosaMulShiftConstTensor(rewriter, op, outType, shift);
return tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), outType,
lhs, rhs, shift);
lhs, rhs, constShift);
}

template <>
Expand Down Expand Up @@ -386,10 +388,13 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
// Multiply the coefficients by the coordinates
// %5 = "tosa.mul"(%3, %4) {shift = 0 : i32} : (tensor<8x3xi32>,
// tensor<3xi32>) -> tensor<8x3xi32>
auto flattenedMulType =
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType());
auto mulShift =
tosa::getTosaMulShiftConstTensor(rewriter, op, flattenedMulType, 0);
auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
indicesMatrixReshapeOp, flattenedCoeffValue.value(), 0);
rewriter, op->getLoc(), flattenedMulType, indicesMatrixReshapeOp,
flattenedCoeffValue.value(), mulShift);

// Sum up the products of the coefficients and coordinates
// %6 = "tosa.reduce_sum"(%5) {axis = 1 : i64} : (tensor<8x3xi32>) ->
Expand Down Expand Up @@ -657,10 +662,13 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
// [[0, 1], [0, 2], [0, 3]] X [4, 1] -> [[4*0, 1*1], [4*0, 1*2], [4*0, 1*3]]
// %13 = "tosa.mul"(%11, %12) {shift = 0 : i32} : (tensor<3x2xi32>,
// tensor<2xi32>) -> tensor<3x2xi32>
auto flattenedMulType =
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType());
auto mulShift =
tosa::getTosaMulShiftConstTensor(rewriter, op, flattenedMulType, 0);
auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
indicesMatrixReshapeOp, flattenedCoeffValue.value(), 0);
rewriter, op->getLoc(), flattenedMulType, indicesMatrixReshapeOp,
flattenedCoeffValue.value(), mulShift);

// Sum up the products of the coefficients and coordinates
// [[4*0 + 1*1], [4*0 + 1*2], [4*0 + 1*3]] = [[1],[2],[3]]
Expand Down Expand Up @@ -1006,8 +1014,10 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
.failed())
return std::nullopt;

auto mulShift =
tosa::getTosaMulShiftConstTensor(rewriter, op, output_type, 0);
return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
val.value(), div_const, 0)
val.value(), div_const, mulShift)
.getResult();
}

Expand Down
12 changes: 12 additions & 0 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,18 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
return const_op.getResult();
}

Value getTosaMulShiftConstTensor(PatternRewriter &rewriter, Operation *op,
Type /*resultType*/, int32_t shift) {
auto shiftType = RankedTensorType::get({1}, rewriter.getI8Type());
auto shiftAttr = DenseElementsAttr::get<int8_t>(
shiftType, llvm::ArrayRef<int8_t>{static_cast<int8_t>(shift)});

auto constShift =
rewriter.create<tosa::ConstOp>(op->getLoc(), shiftType, shiftAttr);

return constShift.getResult();
}

// Create a zero constant tensor of the desired type and shape.
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
Operation *op, Type type) {
Expand Down
Loading
Loading