Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
Tosa_Op<mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
ResultsBroadcastableShape,
TosaElementwiseOperator,
SameOperandsAndResultRank,
Pure])> {
let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
Expand All @@ -250,6 +248,8 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
class Tosa_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
Tosa_ElementwiseOp<mnemonic, !listconcat(traits, [
SameOperandsAndResultShape,
ResultsBroadcastableShape,
SameOperandsAndResultRank,
SameOperandsAndResultElementType])> {}

class Tosa_InferTensorTypeOp<string mnemonic, list<Trait> traits = []>
Expand Down
56 changes: 33 additions & 23 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define MLIR_DIALECT_TOSA_IR_TOSAOPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
Expand Down Expand Up @@ -53,34 +54,43 @@ class MulOperandsAndResultElementType
: public TraitBase<ConcreteType, MulOperandsAndResultElementType> {
public:
static LogicalResult verifyTrait(Operation *op) {
auto resElemType = getElementTypeOrSelf(op->getResult(0));

// In cases of floating point type, op requires the same element
// type for all operands and result.
if (llvm::isa<FloatType>(resElemType))
return impl::verifySameOperandsAndResultElementType(op);

// Check we have a single result.
if (failed(impl::verifyOneResult(op)))
return failure();
Type resElemType = getElementTypeOrSelf(op->getResult(0));

// Check we have lhs and rhs.
if (failed(impl::verifyAtLeastNOperands(op, 2)))
return failure();

Type lhsElemType = getElementTypeOrSelf(op->getOperand(0));
Type rhsElemType = getElementTypeOrSelf(op->getOperand(1));

// Check that for i32 a shift has been explicitly provided.
if (lhsElemType.isInteger(32) && failed(impl::verifyNOperands(op, 3)))
return failure();

// Verify operands type match (ignoring the shift parameter which will
// always be i8).
if (lhsElemType != rhsElemType)
return op->emitOpError("requires the same element type for all operands");

// Though the spec requires the element type of result to be i32, a more
// relaxed way is provided at dialect level for easier cooperating with
// other dialects.
if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
IntegerType lhsIntType =
cast<IntegerType>(getElementTypeOrSelf(op->getOperand(0)));
IntegerType rhsIntType =
cast<IntegerType>(getElementTypeOrSelf(op->getOperand(1)));
if (lhsIntType != rhsIntType)
return op->emitOpError(
"requires the same element type for all operands");

// Though the spec requires the element type of result to be i32, a more
// relaxed way is provided at dialect level for easier cooperating with
// other dialects.
auto lhsIntType = cast<IntegerType>(lhsElemType);
if (lhsIntType.getWidth() > resIntType.getWidth())
return op->emitOpError("invalid data type size for operands or result");

return success();
} else {
// In cases of floating point type or quant types, op requires the same
// element type for all operands and result (excluding shift).
if (resElemType != lhsElemType)
return op->emitOpError(
"requires the same element type for all operands and results");
}

// In cases of all other types, op requires the same element
// type for all operands and result.
return impl::verifySameOperandsAndResultElementType(op);
return llvm::success();
}
};

Expand Down
91 changes: 67 additions & 24 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,9 @@ def Tosa_ErfOp : Tosa_ElementwiseUnaryOp<"erf"> {
//===----------------------------------------------------------------------===//
def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
Commutative,
SameOperandsAndResultElementType]> {
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
let summary = "Elementwise addition operator";

let description = [{
Expand Down Expand Up @@ -515,8 +517,10 @@ def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
//===----------------------------------------------------------------------===//
// Operator: arithmetic_right_shift
//===----------------------------------------------------------------------===//
def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift",
[SameOperandsAndResultElementType]> {
def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift", [
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
let summary = "Elementwise Arithmetic Right Shift";

let description = [{
Expand All @@ -540,7 +544,9 @@ def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift",
//===----------------------------------------------------------------------===//
def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
Commutative,
SameOperandsAndResultElementType]> {
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
let summary = "Bitwise AND operator";

let description = [{
Expand All @@ -563,7 +569,9 @@ def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
//===----------------------------------------------------------------------===//
def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
Commutative,
SameOperandsAndResultElementType]> {
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
let summary = "Bitwise OR operator";

let description = [{
Expand All @@ -586,7 +594,9 @@ def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
//===----------------------------------------------------------------------===//
def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
Commutative,
SameOperandsAndResultElementType]> {
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
let summary = "Bitwise XOR operator";

let description = [{
Expand All @@ -607,7 +617,10 @@ def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
//===----------------------------------------------------------------------===//
// Operator: int_div
//===----------------------------------------------------------------------===//
def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [SameOperandsAndResultElementType]> {
def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [
ResultsBroadcastableShape,
SameOperandsAndResultRank,
SameOperandsAndResultElementType]> {
let summary = "Integer divide operator";

let description = [{
Expand All @@ -632,7 +645,9 @@ def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [SameOperandsAndResultElementT
//===----------------------------------------------------------------------===//
def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
Commutative,
SameOperandsAndResultElementType]> {
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
let summary = "Returns the truth value of x AND y element-wise.";

let description = [{
Expand All @@ -653,8 +668,10 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
//===----------------------------------------------------------------------===//
// Operator: logical_left_shift
//===----------------------------------------------------------------------===//
def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift",
[SameOperandsAndResultElementType]> {
def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift", [
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
let summary = "Elementwise Logical Left Shift";

let description = [{
Expand All @@ -675,8 +692,10 @@ def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift",
//===----------------------------------------------------------------------===//
// Operator: logical_right_shift
//===----------------------------------------------------------------------===//
def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift",
[SameOperandsAndResultElementType]> {
def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift", [
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
let summary = "Elementwise Logical Right Shift";

let description = [{
Expand All @@ -699,7 +718,9 @@ def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift",
//===----------------------------------------------------------------------===//
def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
Commutative,
SameOperandsAndResultElementType]> {
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
let summary = "Returns the truth value of x OR y element-wise.";

let description = [{
Expand All @@ -722,7 +743,9 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
//===----------------------------------------------------------------------===//
def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
Commutative,
SameOperandsAndResultElementType]> {
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
let summary = "Returns the truth value of x XOR y element-wise.";

let description = [{
Expand All @@ -745,7 +768,9 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
//===----------------------------------------------------------------------===//
def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
Commutative,
SameOperandsAndResultElementType]> {
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
let summary = "Elementwise Maximum";

let description = [{
Expand All @@ -769,7 +794,9 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
//===----------------------------------------------------------------------===//
def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
Commutative,
SameOperandsAndResultElementType]> {
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
let summary = "Elementwise Minimum";

let description = [{
Expand Down Expand Up @@ -810,7 +837,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2,
I8Attr:$shift
Optional<TosaTensorRankOf<[Tosa_Int8], [1]>>:$shift
);

let results = (outs
Expand All @@ -824,7 +851,10 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
//===----------------------------------------------------------------------===//
// Operator: pow
//===----------------------------------------------------------------------===//
def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [SameOperandsAndResultElementType]> {
def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
let summary = "Computes the power of one value to another.";

let description = [{
Expand All @@ -845,7 +875,10 @@ def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [SameOperandsAndResultElementType]> {
//===----------------------------------------------------------------------===//
// Operator: sub
//===----------------------------------------------------------------------===//
def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [SameOperandsAndResultElementType]> {
def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
let summary = "Elementwise subtraction operator";

let description = [{
Expand Down Expand Up @@ -1196,7 +1229,9 @@ def Tosa_SinOp : Tosa_ElementwiseUnaryOp<"sin"> {
//===----------------------------------------------------------------------===//
// Operator: select
//===----------------------------------------------------------------------===//
def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
def Tosa_SelectOp : Tosa_ElementwiseOp<"select", [
ResultsBroadcastableShape,
SameOperandsAndResultRank]> {
let summary = "Elementwise select operator";

let description = [{
Expand Down Expand Up @@ -1232,7 +1267,9 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
InferTensorType,
Commutative,
SameOperandsElementType]> {
ResultsBroadcastableShape,
SameOperandsElementType,
SameOperandsAndResultRank]> {
let summary = "Returns the truth value of (x == y) element-wise.";

let description = [{
Expand Down Expand Up @@ -1260,7 +1297,10 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
//===----------------------------------------------------------------------===//
// Operator: greater
//===----------------------------------------------------------------------===//
def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [
ResultsBroadcastableShape,
SameOperandsElementType,
SameOperandsAndResultRank]> {
let summary = "Returns the truth value of (x > y) element-wise.";

let description = [{
Expand All @@ -1282,8 +1322,11 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
//===----------------------------------------------------------------------===//
// Operator: greater_equal
//===----------------------------------------------------------------------===//
def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
[SameOperandsElementType]> {
def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal", [
ResultsBroadcastableShape,
SameOperandsElementType,
SameOperandsAndResultRank
]> {
let summary = "Returns the truth value of (x >= y) element-wise.";

let description = [{
Expand Down
Loading
Loading