Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 5542 files
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
SmallVector<int64_t> indiceOneDimShape, int32_t dim,
ArrayRef<int64_t> indexShape);

// Default function to create TOSA op with shift value
mlir::tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
TensorType outType, Value lhs, Value rhs,
int32_t shift);
Expand All @@ -32,8 +33,8 @@ mlir::tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
template <typename TosaOpT>
TosaOpT createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op,
TensorType outType, Value lhs, Value rhs) {
lhs = promoteType(rewriter, lhs, outType);
rhs = promoteType(rewriter, rhs, outType);
lhs = tosa::tosaCastTensorToType(rewriter, lhs, outType).value();
rhs = tosa::tosaCastTensorToType(rewriter, rhs, outType).value();
return CreateOpAndInfer<TosaOpT>(rewriter, op->getLoc(), outType, lhs, rhs);
}

Expand Down
74 changes: 24 additions & 50 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H

#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace tosa {
Expand Down Expand Up @@ -45,6 +46,10 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val);

// Create an int8_t const tosa.mul shift tensor from an int
Value getTosaMulShiftConstTensor(PatternRewriter &rewriter, Operation *op,
int32_t shift);

// Create a zero constant tensor of the desired type and shape.
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
Operation *op, Type type);
Expand All @@ -58,55 +63,24 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
ArrayRef<T> vec, ArrayRef<int64_t> shape,
std::optional<Type> dtype = {});

LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
Value src, Type destType, Value &result);

Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
// Default function to create tosa.cast op. This should be called instead of
// directly calling rewriter.create<tosa::CastOp>.
std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
TensorType destType);

// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference during the framework to TOSA lowering.
template <typename TosaOp, typename... Args>
TosaOp CreateOpAndInfer(ImplicitLocOpBuilder &builder, Type result_ty,
Args &&...args) {
return CreateOpAndInferShape<TosaOp>(builder, result_ty, args...);
}

template <typename TosaOp, typename... Args>
TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
Args &&...args) {
auto op = rewriter.create<TosaOp>(loc, result_ty, args...);

InferShapedTypeOpInterface shapeInterface =
dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
if (!shapeInterface)
return op;

SmallVector<ShapedTypeComponents> returnedShapes;
if (shapeInterface
.inferReturnTypeComponents(op.getContext(), op.getLoc(),
op->getOperands(), op->getAttrDictionary(),
op->getPropertiesStorage(),
op->getRegions(), returnedShapes)
.failed())
return op;

// We need to use the element type of the existing result type to generate
// the new result shaped type. This is because rescale can include a cast to
// different bit-width types and does not have a TypeAttr to define the
// target type.
auto result = op->getResult(0);
auto predictedShape = returnedShapes[0];
auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(result_ty);

// Compute the knowledge based on the inferred type.
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
inferredKnowledge.dtype = cast<ShapedType>(result_ty).getElementType();
inferredKnowledge.hasRank = predictedShape.hasRank();
if (predictedShape.hasRank()) {
for (auto dim : predictedShape.getDims()) {
inferredKnowledge.sizes.push_back(dim);
}
}

// Compute the new type based on the joined version.
auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
auto new_ty = newKnowledge.getType();
result.setType(new_ty);
return op;
ImplicitLocOpBuilder builder(loc, rewriter);
return CreateOpAndInfer<TosaOp>(builder, result_ty, args...);
}

template <typename TosaOp, typename... Args>
Expand Down
Loading
Loading