Skip to content

Commit a55906e

Browse files
justin-ngo-armvivekkhandelwal1
authored andcommitted
[TOSA] Refactor: Use tosaCastTensorToType() function to create tosa.cast
* Previously, in Torch to TOSA, there are 3 ways to create tosa.cast op: - `rewriter.create<tosa::CastOp>()` - `tosa::promoteType()` - `tosa::tosaCastTensorToType()` * This commit combines the three APIs above into `tosa::tosaCastTensorToType()` with the following features: - Checking whether source and destination element types are the same before casting. If they are same, skip the cast. - Custom float to integer cast behavior added from this PR: #3946 TLDR: PyTorch's and TOSA's float to integer casting behaviors are different (round to zero vs round to nearest, respectively), which requires a custom casting here. - Future `TODO`: add a --strict mode which includes `checkValidityOfCast()` to ensure that the casting pairs follow TOSA specifications. * Update LIT tests. Signed-off-by: Justin Ngo <[email protected]> Change-Id: I2aef3c79d8f2d98b93e671d5b815b8eab33e697e
1 parent d70e7b1 commit a55906e

File tree

6 files changed

+263
-267
lines changed

6 files changed

+263
-267
lines changed

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
2424
SmallVector<int64_t> indiceOneDimShape, int32_t dim,
2525
ArrayRef<int64_t> indexShape);
2626

27+
// Default function to create TOSA op with shift value
2728
mlir::tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
2829
TensorType outType, Value lhs, Value rhs,
2930
int32_t shift);
@@ -32,8 +33,8 @@ mlir::tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
3233
template <typename TosaOpT>
3334
TosaOpT createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op,
3435
TensorType outType, Value lhs, Value rhs) {
35-
lhs = promoteType(rewriter, lhs, outType);
36-
rhs = promoteType(rewriter, rhs, outType);
36+
lhs = tosa::tosaCastTensorToType(rewriter, lhs, outType).value();
37+
rhs = tosa::tosaCastTensorToType(rewriter, rhs, outType).value();
3738
return CreateOpAndInfer<TosaOpT>(rewriter, op->getLoc(), outType, lhs, rhs);
3839
}
3940

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
6363
ArrayRef<T> vec, ArrayRef<int64_t> shape,
6464
std::optional<Type> dtype = {});
6565

66-
LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
67-
Value src, Type destType, Value &result);
68-
69-
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
66+
// Default function to create tosa.cast op. This should be called instead of
67+
// directly calling rewriter.create<tosa::CastOp>.
68+
std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
69+
TensorType destType);
7070

7171
// Creates a TOSA operation and performs shape inference on the individual
7272
// op. This allows shape inference during the framework to TOSA lowering.

0 commit comments

Comments
 (0)