Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 6585 files
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
SmallVector<int64_t> indiceOneDimShape, int32_t dim,
ArrayRef<int64_t> indexShape);

// Default function to create TOSA op with shift value
mlir::tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
TensorType outType, Value lhs, Value rhs,
int32_t shift);
Expand All @@ -32,8 +33,8 @@ mlir::tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
template <typename TosaOpT>
TosaOpT createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op,
TensorType outType, Value lhs, Value rhs) {
lhs = promoteType(rewriter, lhs, outType);
rhs = promoteType(rewriter, rhs, outType);
lhs = tosa::tosaCastTensorToType(rewriter, lhs, outType).value();
rhs = tosa::tosaCastTensorToType(rewriter, rhs, outType).value();
return CreateOpAndInfer<TosaOpT>(rewriter, op->getLoc(), outType, lhs, rhs);
}

Expand Down
74 changes: 24 additions & 50 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H

#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace tosa {
Expand Down Expand Up @@ -49,6 +50,10 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val);

// Create an int8_t const tosa.mul shift tensor from an int
Value getTosaMulShiftConstTensor(PatternRewriter &rewriter, Operation *op,
int32_t shift);

// Create a zero constant tensor of the desired type and shape.
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
Operation *op, Type type);
Expand All @@ -62,55 +67,24 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
ArrayRef<T> vec, ArrayRef<int64_t> shape,
std::optional<Type> dtype = {});

LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
Value src, Type destType, Value &result);

Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
// Default function to create tosa.cast op. This should be called instead of
// directly calling rewriter.create<tosa::CastOp>.
std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
TensorType destType);

// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference during the framework to TOSA lowering.
template <typename TosaOp, typename... Args>
TosaOp CreateOpAndInfer(ImplicitLocOpBuilder &builder, Type result_ty,
Args &&...args) {
return CreateOpAndInferShape<TosaOp>(builder, result_ty, args...);
}

template <typename TosaOp, typename... Args>
TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
Args &&...args) {
auto op = rewriter.create<TosaOp>(loc, result_ty, args...);

InferShapedTypeOpInterface shapeInterface =
dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
if (!shapeInterface)
return op;

SmallVector<ShapedTypeComponents> returnedShapes;
if (shapeInterface
.inferReturnTypeComponents(op.getContext(), op.getLoc(),
op->getOperands(), op->getAttrDictionary(),
op->getPropertiesStorage(),
op->getRegions(), returnedShapes)
.failed())
return op;

// We need to use the element type of the existing result type to generate
// the new result shaped type. This is because rescale can include a cast to
// different bit-width types and does not have a TypeAttr to define the
// target type.
auto result = op->getResult(0);
auto predictedShape = returnedShapes[0];
auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(result_ty);

// Compute the knowledge based on the inferred type.
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
inferredKnowledge.dtype = cast<ShapedType>(result_ty).getElementType();
inferredKnowledge.hasRank = predictedShape.hasRank();
if (predictedShape.hasRank()) {
for (auto dim : predictedShape.getDims()) {
inferredKnowledge.sizes.push_back(dim);
}
}

// Compute the new type based on the joined version.
auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
auto new_ty = newKnowledge.getType();
result.setType(new_ty);
return op;
ImplicitLocOpBuilder builder(loc, rewriter);
return CreateOpAndInfer<TosaOp>(builder, result_ty, args...);
}

template <typename TosaOp, typename... Args>
Expand Down
Loading
Loading