Skip to content

Commit fe83a72

Browse files
authored
[TOSA] Introduce Tosa_ElementwiseUnaryOp with Type and Shape Enforcement (#115784)
* Enforce that Tosa_ElementwiseUnaryOp requires output tensors to match the input tensor's type and shape. * Update the following ops to conform to Tosa_ElementwiseUnaryOp: clamp, erf, sigmoid, tanh, cos, sin, abs, bitwise_not, ceil, clz, exp, floor, log, logical_not, negate, reciprocal, rsqrt. * Add invalid tests for each operator to ensure compliance with TOSA v1.0 Specification. Signed-off-by: Peng Sun <[email protected]>
1 parent d2db9bd commit fe83a72

File tree

4 files changed

+275
-27
lines changed

4 files changed

+275
-27
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,11 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
234234
"operands attr-dict `:` functional-type(operands, results)";
235235
}
236236

237+
class Tosa_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
238+
Tosa_ElementwiseOp<mnemonic, !listconcat(traits, [
239+
SameOperandsAndResultShape,
240+
SameOperandsAndResultElementType])> {}
241+
237242
class Tosa_InferTensorTypeOp<string mnemonic, list<Trait> traits = []>
238243
: Tosa_Op<mnemonic, !listconcat(traits, [InferTensorTypeAdaptor, Pure])> {
239244
let assemblyFormat =

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
367367
//===----------------------------------------------------------------------===//
368368
// Operator: clamp
369369
//===----------------------------------------------------------------------===//
370-
def Tosa_ClampOp : Tosa_ElementwiseOp<"clamp"> {
370+
def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
371371
let summary = "Computes clamp(features, min, max).";
372372

373373
let description = [{
@@ -397,7 +397,7 @@ def Tosa_ClampOp : Tosa_ElementwiseOp<"clamp"> {
397397
//===----------------------------------------------------------------------===//
398398
// Operator: sigmoid
399399
//===----------------------------------------------------------------------===//
400-
def Tosa_SigmoidOp : Tosa_ElementwiseOp<"sigmoid"> {
400+
def Tosa_SigmoidOp : Tosa_ElementwiseUnaryOp<"sigmoid"> {
401401
let summary = "Computes elementwise sigmoid of input.";
402402

403403
let description = [{
@@ -420,7 +420,7 @@ def Tosa_SigmoidOp : Tosa_ElementwiseOp<"sigmoid"> {
420420
//===----------------------------------------------------------------------===//
421421
// Operator: tanh
422422
//===----------------------------------------------------------------------===//
423-
def Tosa_TanhOp : Tosa_ElementwiseOp<"tanh", [SameOperandsAndResultElementType]> {
423+
def Tosa_TanhOp : Tosa_ElementwiseUnaryOp<"tanh"> {
424424
let summary = "Computes elementwise hyperbolic tangent of input";
425425

426426
let description = [{
@@ -442,10 +442,7 @@ def Tosa_TanhOp : Tosa_ElementwiseOp<"tanh", [SameOperandsAndResultElementType]>
442442
//===----------------------------------------------------------------------===//
443443
// Operator: erf
444444
//===----------------------------------------------------------------------===//
445-
def Tosa_ErfOp : Tosa_Op<"erf", [
446-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
447-
["inferReturnTypeComponents"]>,
448-
Pure]> {
445+
def Tosa_ErfOp : Tosa_ElementwiseUnaryOp<"erf"> {
449446
let summary = "Computes gauss error function of input";
450447

451448
let description = [{
@@ -906,7 +903,7 @@ def Tosa_TableOp : Tosa_InferShapedTypeOp<"table"> {
906903
//===----------------------------------------------------------------------===//
907904
// Operator: abs
908905
//===----------------------------------------------------------------------===//
909-
def Tosa_AbsOp : Tosa_ElementwiseOp<"abs", [SameOperandsAndResultElementType]> {
906+
def Tosa_AbsOp : Tosa_ElementwiseUnaryOp<"abs"> {
910907
let summary = "Elementwise abs op";
911908

912909
let description = [{
@@ -933,8 +930,7 @@ def Tosa_AbsOp : Tosa_ElementwiseOp<"abs", [SameOperandsAndResultElementType]> {
933930
//===----------------------------------------------------------------------===//
934931
// Operator: bitwise_not
935932
//===----------------------------------------------------------------------===//
936-
def Tosa_BitwiseNotOp : Tosa_ElementwiseOp<"bitwise_not",
937-
[SameOperandsAndResultElementType]> {
933+
def Tosa_BitwiseNotOp : Tosa_ElementwiseUnaryOp<"bitwise_not"> {
938934
let summary = "Bitwise NOT operator";
939935

940936
let description = [{
@@ -953,7 +949,7 @@ def Tosa_BitwiseNotOp : Tosa_ElementwiseOp<"bitwise_not",
953949
//===----------------------------------------------------------------------===//
954950
// Operator: ceil
955951
//===----------------------------------------------------------------------===//
956-
def Tosa_CeilOp : Tosa_ElementwiseOp<"ceil", [SameOperandsAndResultElementType]> {
952+
def Tosa_CeilOp : Tosa_ElementwiseUnaryOp<"ceil"> {
957953
let summary = "Elementwise ceil op";
958954

959955
let description = [{
@@ -972,7 +968,7 @@ def Tosa_CeilOp : Tosa_ElementwiseOp<"ceil", [SameOperandsAndResultElementType]>
972968
//===----------------------------------------------------------------------===//
973969
// Operator: clz
974970
//===----------------------------------------------------------------------===//
975-
def Tosa_ClzOp : Tosa_ElementwiseOp<"clz", [SameOperandsAndResultElementType]> {
971+
def Tosa_ClzOp : Tosa_ElementwiseUnaryOp<"clz"> {
976972
let summary = "Elementwise count leading zero op";
977973

978974
let description = [{
@@ -991,8 +987,7 @@ def Tosa_ClzOp : Tosa_ElementwiseOp<"clz", [SameOperandsAndResultElementType]> {
991987
//===----------------------------------------------------------------------===//
992988
// Operator: cos
993989
//===----------------------------------------------------------------------===//
994-
def Tosa_CosOp : Tosa_ElementwiseOp<"cos",
995-
[SameOperandsAndResultElementType]> {
990+
def Tosa_CosOp : Tosa_ElementwiseUnaryOp<"cos"> {
996991
let summary = "Elementwise cos op";
997992

998993
let description = [{
@@ -1011,7 +1006,7 @@ def Tosa_CosOp : Tosa_ElementwiseOp<"cos",
10111006
//===----------------------------------------------------------------------===//
10121007
// Operator: exp
10131008
//===----------------------------------------------------------------------===//
1014-
def Tosa_ExpOp : Tosa_ElementwiseOp<"exp", [SameOperandsAndResultElementType]> {
1009+
def Tosa_ExpOp : Tosa_ElementwiseUnaryOp<"exp"> {
10151010
let summary = "Elementwise exp op";
10161011

10171012
let description = [{
@@ -1032,7 +1027,7 @@ def Tosa_ExpOp : Tosa_ElementwiseOp<"exp", [SameOperandsAndResultElementType]> {
10321027
//===----------------------------------------------------------------------===//
10331028
// Operator: floor
10341029
//===----------------------------------------------------------------------===//
1035-
def Tosa_FloorOp : Tosa_ElementwiseOp<"floor", [SameOperandsAndResultElementType]> {
1030+
def Tosa_FloorOp : Tosa_ElementwiseUnaryOp<"floor"> {
10361031
let summary = "Elementwise floor op";
10371032

10381033
let description = [{
@@ -1051,7 +1046,7 @@ def Tosa_FloorOp : Tosa_ElementwiseOp<"floor", [SameOperandsAndResultElementType
10511046
//===----------------------------------------------------------------------===//
10521047
// Operator: log
10531048
//===----------------------------------------------------------------------===//
1054-
def Tosa_LogOp : Tosa_ElementwiseOp<"log", [SameOperandsAndResultElementType]> {
1049+
def Tosa_LogOp : Tosa_ElementwiseUnaryOp<"log"> {
10551050
let summary = "Elementwise log op";
10561051

10571052
let description = [{
@@ -1072,8 +1067,7 @@ def Tosa_LogOp : Tosa_ElementwiseOp<"log", [SameOperandsAndResultElementType]> {
10721067
//===----------------------------------------------------------------------===//
10731068
// Operator: logical_not
10741069
//===----------------------------------------------------------------------===//
1075-
def Tosa_LogicalNotOp : Tosa_ElementwiseOp<"logical_not",
1076-
[SameOperandsAndResultElementType]> {
1070+
def Tosa_LogicalNotOp : Tosa_ElementwiseUnaryOp<"logical_not"> {
10771071
let summary = "Returns the truth value of NOT x element-wise.";
10781072

10791073
let description = [{
@@ -1092,8 +1086,7 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseOp<"logical_not",
10921086
//===----------------------------------------------------------------------===//
10931087
// Operator: negate
10941088
//===----------------------------------------------------------------------===//
1095-
def Tosa_NegateOp : Tosa_ElementwiseOp<"negate",
1096-
[SameOperandsAndResultElementType]> {
1089+
def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
10971090
let summary = "Elementwise negate op";
10981091

10991092
let description = [{
@@ -1117,8 +1110,7 @@ def Tosa_NegateOp : Tosa_ElementwiseOp<"negate",
11171110
//===----------------------------------------------------------------------===//
11181111
// Operator: reciprocal
11191112
//===----------------------------------------------------------------------===//
1120-
def Tosa_ReciprocalOp : Tosa_ElementwiseOp<"reciprocal",
1121-
[SameOperandsAndResultElementType]> {
1113+
def Tosa_ReciprocalOp : Tosa_ElementwiseUnaryOp<"reciprocal"> {
11221114
let summary = "Elementwise reciprocal op";
11231115

11241116
let description = [{
@@ -1149,8 +1141,7 @@ def Tosa_ReciprocalOp : Tosa_ElementwiseOp<"reciprocal",
11491141
//===----------------------------------------------------------------------===//
11501142
// Operator: rsqrt
11511143
//===----------------------------------------------------------------------===//
1152-
def Tosa_RsqrtOp : Tosa_ElementwiseOp<"rsqrt",
1153-
[SameOperandsAndResultElementType]> {
1144+
def Tosa_RsqrtOp : Tosa_ElementwiseUnaryOp<"rsqrt"> {
11541145
let summary = "Elementwise 1/sqrt op";
11551146

11561147
let description = [{
@@ -1170,8 +1161,7 @@ def Tosa_RsqrtOp : Tosa_ElementwiseOp<"rsqrt",
11701161
//===----------------------------------------------------------------------===//
11711162
// Operator: sin
11721163
//===----------------------------------------------------------------------===//
1173-
def Tosa_SinOp : Tosa_ElementwiseOp<"sin",
1174-
[SameOperandsAndResultElementType]> {
1164+
def Tosa_SinOp : Tosa_ElementwiseUnaryOp<"sin"> {
11751165
let summary = "Elementwise sin op";
11761166

11771167
let description = [{

0 commit comments

Comments
 (0)