Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 33 additions & 23 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define MLIR_DIALECT_TOSA_IR_TOSAOPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
Expand Down Expand Up @@ -53,34 +54,43 @@ class MulOperandsAndResultElementType
: public TraitBase<ConcreteType, MulOperandsAndResultElementType> {
public:
static LogicalResult verifyTrait(Operation *op) {
auto resElemType = getElementTypeOrSelf(op->getResult(0));

// In cases of floating point type, op requires the same element
// type for all operands and result.
if (llvm::isa<FloatType>(resElemType))
return impl::verifySameOperandsAndResultElementType(op);

// Check we have a single result.
if (failed(impl::verifyOneResult(op)))
return failure();
Type resElemType = getElementTypeOrSelf(op->getResult(0));

// Check we have lhs and rhs.
if (failed(impl::verifyAtLeastNOperands(op, 2)))
return failure();

Type lhsElemType = getElementTypeOrSelf(op->getOperand(0));
Type rhsElemType = getElementTypeOrSelf(op->getOperand(1));

// Check that for i32 a shift has been explicitly provided.
if (lhsElemType.isInteger(32) && failed(impl::verifyNOperands(op, 3)))
return failure();

// Verify operands type match (ignoring the shift parameter which will
// always be i8).
if (lhsElemType != rhsElemType)
return op->emitOpError("requires the same element type for all operands");

// Though the spec requires the element type of result to be i32, a more
// relaxed way is provided at dialect level for easier cooperating with
// other dialects.
if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
IntegerType lhsIntType =
cast<IntegerType>(getElementTypeOrSelf(op->getOperand(0)));
IntegerType rhsIntType =
cast<IntegerType>(getElementTypeOrSelf(op->getOperand(1)));
if (lhsIntType != rhsIntType)
return op->emitOpError(
"requires the same element type for all operands");

// Though the spec requires the element type of result to be i32, a more
// relaxed way is provided at dialect level for easier cooperating with
// other dialects.
auto lhsIntType = cast<IntegerType>(lhsElemType);
if (lhsIntType.getWidth() > resIntType.getWidth())
return op->emitOpError("invalid data type size for operands or result");

return success();
} else {
// In cases of floating point type or quant types, op requires the same
// element type for all operands and result (excluding shift).
if (resElemType != lhsElemType)
return op->emitOpError(
"requires the same element type for all operands and results");
}

// In cases of all other types, op requires the same element
// type for all operands and result.
return impl::verifySameOperandsAndResultElementType(op);
return llvm::success();
}
};

Expand Down
12 changes: 9 additions & 3 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -800,9 +800,11 @@ def MulOperandsAndResultElementType :
//===----------------------------------------------------------------------===//
// Operator: mul
//===----------------------------------------------------------------------===//
def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
def Tosa_MulOp : Tosa_Op<"mul", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Commutative,
MulOperandsAndResultElementType]> {
Pure]> {
let summary = "Multiplication operator";

let description = [{
Expand All @@ -814,7 +816,8 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2,
I8Attr:$shift
// Apply right shift on i32_t input data only
Tosa_ScalarInt8Tensor:$shift
);

let results = (outs
Expand All @@ -823,6 +826,9 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [

let hasFolder = 1;
let hasVerifier = 1;

let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
}

//===----------------------------------------------------------------------===//
Expand Down
11 changes: 10 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def HasNo0Dimensions : And<[
IsRankedTensorTypePred,
CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;

def AllDimensionsAreSizeOne : And<[
IsRankedTensorTypePred,
CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v == 1; })">]>;

// AMD: removed HasNo0Dimensions constraint below to allow lowerings
// in onnx-mlir like onnx.Split.
class TosaTensorOf<
Expand All @@ -111,6 +115,11 @@ class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
[HasAnyRankOfPred<ranks>],
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;

class TosaScalarTensorOf<list<Type> allowedTypes, list<int> ranks>
: TosaRankedTensorOf<allowedTypes,
[HasAnyRankOfPred<ranks>, AllDimensionsAreSizeOne],
"tosa-conformant scalar tensor">;

//===----------------------------------------------------------------------===//
// Tensor types
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -139,8 +148,8 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
// Tensor types with constrained ranks.
//===----------------------------------------------------------------------===//

// Rank-0 (scalar) tensor
def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;

// We include unranked tensors as a supported type for all possible tosa
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
Expand Down
90 changes: 56 additions & 34 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,43 +100,59 @@ static Value createLinalgBodyCalculationForElementwiseOp(
}

// tosa::MulOp
if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<arith::MulFOp>(loc, resultTypes, args);

if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
Value a = args[0];
Value b = args[1];
auto shift =
cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
if (shift > 0) {
auto shiftConst =
rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
if (!a.getType().isInteger(32))
a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);

if (!b.getType().isInteger(32))
b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

auto result = rewriter.create<tosa::ApplyScaleOp>(
loc, rewriter.getI32Type(), a, b, shiftConst,
rewriter.getBoolAttr(false));

if (elementTy.isInteger(32))
return result;

return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
if (isa<tosa::MulOp>(op)) {
auto shift_val = cast<tosa::MulOp>(op).getShift();
ElementsAttr shift_elem;
if (!shift_val.getImpl() ||
!matchPattern(shift_val, m_Constant(&shift_elem))) {
(void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
}

int aWidth = a.getType().getIntOrFloatBitWidth();
int bWidth = b.getType().getIntOrFloatBitWidth();
int cWidth = resultTypes[0].getIntOrFloatBitWidth();
int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();

if (aWidth < cWidth)
a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
if (bWidth < cWidth)
b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
if (isa<FloatType>(elementTy)) {
if (shift != 0) {
(void)rewriter.notifyMatchFailure(op,
"Cannot have shift value for float");
return nullptr;
}
return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
}

if (isa<IntegerType>(elementTy)) {
Value a = args[0];
Value b = args[1];

if (shift > 0) {
auto shiftConst =
rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
if (!a.getType().isInteger(32))
a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);

if (!b.getType().isInteger(32))
b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

auto result = rewriter.create<tosa::ApplyScaleOp>(
loc, rewriter.getI32Type(), a, b, shiftConst,
rewriter.getBoolAttr(false));

if (elementTy.isInteger(32))
return result;

return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
}

int aWidth = a.getType().getIntOrFloatBitWidth();
int bWidth = b.getType().getIntOrFloatBitWidth();
int cWidth = resultTypes[0].getIntOrFloatBitWidth();

if (aWidth < cWidth)
a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
if (bWidth < cWidth)
b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);

return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
}
}

// tosa::NegateOp
Expand Down Expand Up @@ -990,7 +1006,13 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
auto loc = operation->getLoc();
auto rank =
cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
// For the mul op we need to avoid expanding the rank of the optional shift
// input.
auto operandsToExpand =
isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;

auto expandedOperands =
expandInputRanks(rewriter, loc, operandsToExpand, rank);
auto [targetShape, masterOperands] =
computeTargetShape(rewriter, loc, indexPool, expandedOperands);
auto broadcastOperands = broadcastDynamicDimensions(
Expand Down
15 changes: 13 additions & 2 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,18 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
auto rhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
// Result right shift on i32_t data type only. For simplification, synthesize
// a zero shift for other data type.
int32_t shift = 0;
if (resultETy.isInteger(32)) {
ElementsAttr shift_elem;
if (getShift().getImpl()) {
if (!matchPattern(getShift(), m_Constant(&shift_elem)))
// cannot be folded when the shift value is unknown.
return {};
shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
}
}

if (rhsTy == resultTy) {
if (isSplatZero(resultETy, lhsAttr))
Expand All @@ -1245,7 +1256,7 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
return lhs;
}

return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
}

OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
Expand Down
Loading