diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 29afd6c27302c..4975530a9588c 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -239,7 +239,9 @@ class Tosa_ElementwiseOp traits = []> : Tosa_Op, + ResultsBroadcastableShape, TosaElementwiseOperator, + SameOperandsAndResultRank, Pure])> { let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; @@ -248,8 +250,6 @@ class Tosa_ElementwiseOp traits = []> : class Tosa_ElementwiseUnaryOp traits = []> : Tosa_ElementwiseOp {} class Tosa_InferTensorTypeOp traits = []> diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 9e3e41d288e4a..c59c582a1f522 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -482,9 +482,7 @@ def Tosa_ErfOp : Tosa_ElementwiseUnaryOp<"erf"> { //===----------------------------------------------------------------------===// def Tosa_AddOp : Tosa_ElementwiseOp<"add", [ Commutative, - ResultsBroadcastableShape, - SameOperandsAndResultElementType, - SameOperandsAndResultRank]> { + SameOperandsAndResultElementType]> { let summary = "Elementwise addition operator"; let description = [{ @@ -517,10 +515,8 @@ def Tosa_AddOp : Tosa_ElementwiseOp<"add", [ //===----------------------------------------------------------------------===// // Operator: arithmetic_right_shift //===----------------------------------------------------------------------===// -def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift", [ - ResultsBroadcastableShape, - SameOperandsAndResultElementType, - SameOperandsAndResultRank]> { +def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift", + [SameOperandsAndResultElementType]> { let summary = "Elementwise Arithmetic Right Shift"; let description = [{ @@ -544,9 +540,7 @@ def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift", [ //===----------------------------------------------------------------------===// def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [ Commutative, - ResultsBroadcastableShape, - SameOperandsAndResultElementType, - SameOperandsAndResultRank]> { + SameOperandsAndResultElementType]> { let summary = "Bitwise AND operator"; let description = [{ @@ -569,9 +563,7 @@ def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [ //===----------------------------------------------------------------------===// def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [ Commutative, - ResultsBroadcastableShape, - SameOperandsAndResultElementType, - SameOperandsAndResultRank]> { + SameOperandsAndResultElementType]> { let summary = "Bitwise OR operator"; let description = [{ @@ -594,9 +586,7 @@ def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [ //===----------------------------------------------------------------------===// def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [ Commutative, - ResultsBroadcastableShape, - SameOperandsAndResultElementType, - SameOperandsAndResultRank]> { + SameOperandsAndResultElementType]> { let summary = "Bitwise XOR operator"; let description = [{ @@ -617,10 +607,7 @@ def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [ //===----------------------------------------------------------------------===// // Operator: int_div //===----------------------------------------------------------------------===// -def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [ - ResultsBroadcastableShape, - SameOperandsAndResultRank, - SameOperandsAndResultElementType]> { +def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [SameOperandsAndResultElementType]> { let summary = "Integer divide operator"; let description = [{ @@ -645,9 +632,7 @@ def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [ //===----------------------------------------------------------------------===// def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [ Commutative, - ResultsBroadcastableShape, - SameOperandsAndResultElementType, - SameOperandsAndResultRank]> { + SameOperandsAndResultElementType]> { let summary = "Returns the truth value of x AND y element-wise."; let description = [{ @@ -668,10 +653,8 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [ //===----------------------------------------------------------------------===// // Operator: logical_left_shift //===----------------------------------------------------------------------===// -def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift", [ - ResultsBroadcastableShape, - SameOperandsAndResultElementType, - SameOperandsAndResultRank]> { +def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift", + [SameOperandsAndResultElementType]> { let summary = "Elementwise Logical Left Shift"; let description = [{ @@ -692,10 +675,8 @@ def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift", [ //===----------------------------------------------------------------------===// // Operator: logical_right_shift //===----------------------------------------------------------------------===// -def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift", [ - ResultsBroadcastableShape, - SameOperandsAndResultElementType, - SameOperandsAndResultRank]> { +def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift", + [SameOperandsAndResultElementType]> { let summary = "Elementwise Logical Right Shift"; let description = [{ @@ -718,9 +699,7 @@ def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift", [ //===----------------------------------------------------------------------===// def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [ Commutative, - ResultsBroadcastableShape, - SameOperandsAndResultElementType, - SameOperandsAndResultRank]> { + SameOperandsAndResultElementType]> { let summary = "Returns the truth value of x OR y element-wise."; let description = [{ @@ -743,9 +722,7 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [ //===----------------------------------------------------------------------===// def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [ Commutative, - ResultsBroadcastableShape, - SameOperandsAndResultElementType, - SameOperandsAndResultRank]> { + SameOperandsAndResultElementType]> { let summary = "Returns the truth value of x XOR y element-wise."; let description = [{ @@ -768,9 +745,7 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [ //===----------------------------------------------------------------------===// def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [ Commutative, - ResultsBroadcastableShape, - SameOperandsAndResultElementType, - SameOperandsAndResultRank]> { + SameOperandsAndResultElementType]> { let summary = "Elementwise Maximum"; let description = [{ @@ -794,9 +769,7 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [ //===----------------------------------------------------------------------===// def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [ Commutative, - ResultsBroadcastableShape, - SameOperandsAndResultElementType, - SameOperandsAndResultRank]> { + SameOperandsAndResultElementType]> { let summary = "Elementwise Minimum"; let description = [{ @@ -823,9 +796,11 @@ def MulOperandsAndResultElementType : //===----------------------------------------------------------------------===// // Operator: mul //===----------------------------------------------------------------------===// -def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [ +def Tosa_MulOp : Tosa_Op<"mul", [ + DeclareOpInterfaceMethods, Commutative, - MulOperandsAndResultElementType]> { + Pure]> { let summary = "Multiplication operator"; let description = [{ @@ -846,15 +821,15 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [ let hasFolder = 1; let hasVerifier = 1; + + let assemblyFormat = + "operands attr-dict `:` functional-type(operands, results)"; } //===----------------------------------------------------------------------===// // Operator: pow //===----------------------------------------------------------------------===// -def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [ - ResultsBroadcastableShape, - SameOperandsAndResultElementType, - SameOperandsAndResultRank]> { +def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [SameOperandsAndResultElementType]> { let summary = "Computes the power of one value to another."; let description = [{ @@ -875,10 +850,7 @@ def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [ //===----------------------------------------------------------------------===// // Operator: sub //===----------------------------------------------------------------------===// -def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [ - ResultsBroadcastableShape, - SameOperandsAndResultElementType, - SameOperandsAndResultRank]> { +def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [SameOperandsAndResultElementType]> { let summary = "Elementwise subtraction operator"; let description = [{ @@ -1229,9 +1201,7 @@ def Tosa_SinOp : Tosa_ElementwiseUnaryOp<"sin"> { //===----------------------------------------------------------------------===// // Operator: select //===----------------------------------------------------------------------===// -def Tosa_SelectOp : Tosa_ElementwiseOp<"select", [ - ResultsBroadcastableShape, - SameOperandsAndResultRank]> { +def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> { let summary = "Elementwise select operator"; let description = [{ @@ -1267,9 +1237,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select", [ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [ InferTensorType, Commutative, - ResultsBroadcastableShape, - SameOperandsElementType, - SameOperandsAndResultRank]> { + SameOperandsElementType]> { let summary = "Returns the truth value of (x == y) element-wise."; let description = [{ @@ -1297,10 +1265,7 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [ //===----------------------------------------------------------------------===// // Operator: greater //===----------------------------------------------------------------------===// -def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [ - ResultsBroadcastableShape, - SameOperandsElementType, - SameOperandsAndResultRank]> { +def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> { let summary = "Returns the truth value of (x > y) element-wise."; let description = [{ @@ -1322,11 +1287,8 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [ //===----------------------------------------------------------------------===// // Operator: greater_equal //===----------------------------------------------------------------------===// -def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal", [ - ResultsBroadcastableShape, - SameOperandsElementType, - SameOperandsAndResultRank - ]> { +def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal", + [SameOperandsElementType]> { let summary = "Returns the truth value of (x >= y) element-wise."; let description = [{ diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index c0b419b6f473c..0a10439db4080 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -958,6 +958,28 @@ LogicalResult tosa::SliceOp::verify() { return success(); } +LogicalResult tosa::MulOp::inferReturnTypeComponents( + MLIRContext *context, ::std::optional location, + ValueShapeRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnShapes) { + LogicalResult status = success(); + llvm::SmallVector outShape; + if (operands.size() == 2) { + status = resolveBroadcastShape(operands, outShape); + } else { + // mul op's output shape only depend on input1 and input2, not on shift + ValueShapeRange two_inputs = operands.drop_back(); + status = resolveBroadcastShape(two_inputs, outShape); + } + if (status.failed()) { + inferredReturnShapes.push_back(ShapedTypeComponents()); + } else { + inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); + } + return success(); +} + LogicalResult tosa::MulOp::verify() { auto resElemType = getElementTypeOrSelf(getOutput()); @@ -1030,6 +1052,20 @@ LogicalResult tosa::MulOp::verify() { } } + // check for broadcast compatible shapes in first two operands (ignoring + // shift) + + // delegate function that returns shape of shaped type + auto getShape = [](const Type type) { + return mlir::cast(type).getShape(); + }; + SmallVector resultShape; + if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]), + getShape(rankedOperandTypes[1]), + resultShape)) { + return emitOpError("operands don't have broadcast-compatible shapes"); + } + return success(); } @@ -1670,7 +1706,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp) NARY_SHAPE_INFER(tosa::LogicalXorOp) NARY_SHAPE_INFER(tosa::MaximumOp) NARY_SHAPE_INFER(tosa::MinimumOp) -NARY_SHAPE_INFER(tosa::MulOp) NARY_SHAPE_INFER(tosa::NegateOp) NARY_SHAPE_INFER(tosa::PowOp) NARY_SHAPE_INFER(tosa::ReciprocalOp) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp index e1f0a9592e8b4..520f283a3ba88 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp @@ -281,13 +281,20 @@ bool TosaReduceTransposes::collectFanIn(Operation *op, if (!llvm::isa(op) && !llvm::isa(op) && !llvm::isa(op)) { - if (!op->hasTrait()) + if (!llvm::isa(op) && + !op->hasTrait()) return false; - for (Value operand : op->getOperands()) + for (Value operand : op->getOperands()) { // If this is a problem in future, think about alternatives to recursion. + if (llvm::isa(op) && op->getNumOperands() == 3 && + operand == op->getOperand(2)) { + // do not recurse into MulOp's shift operand + continue; + } if (!collectFanIn(operand.getDefiningOp(), collected)) return false; + } } // Insert in topological order. @@ -316,7 +323,8 @@ std::optional TosaReduceTransposes::buildMappedToValue( Operation *op, const DenseMap &valuesMap, IRRewriter &rewriter, ArrayRef hoistedPerms) { if (op->getNumResults() != 1 || - !op->hasTrait()) + (!llvm::isa(op) && + !op->hasTrait())) return std::nullopt; auto resultType = llvm::cast(op->getResult(0).getType()); @@ -324,6 +332,10 @@ std::optional TosaReduceTransposes::buildMappedToValue( for (Value v : op->getOperands()) { if (valuesMap.contains(v)) { operands.push_back(valuesMap.at(v)); + } else if (llvm::isa(op) && op->getNumOperands() == 3 && + v == op->getOperand(2)) { + // special case for MulOp's shift operand + operands.push_back(v); } else { return std::nullopt; } diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 51d7f82851061..ac4d466aef94b 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -183,7 +183,7 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor) -> t // ----- func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) { - %padding = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> + %padding = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> // expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank.}} %1 = tosa.pad %arg0, %padding : (tensor<13x21xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32> return @@ -211,7 +211,7 @@ func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tenso // ----- func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { - %0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> + %0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}} %1 = tosa.pad %arg0, %0 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32> return %1 : tensor<13x21x3xf32> @@ -749,7 +749,7 @@ func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1 // CHECK-LABEL: test_mul_missing_shift func.func @test_mul_missing_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> { - // expected-error@+1 {{'tosa.mul' op expected 3 operands, but found 2}} + // this is ok because mul's shift operand is optional for now %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32> return %0 : tensor<13x21x3xi32> } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 9eba2f7e5a06e..a4596c8f9d536 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -327,6 +327,14 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te return %0 : tensor<13x21x3xf32> } +// ----- +// CHECK-LABEL: test_mul_scalar_with_unranked_output +func.func @test_mul_scalar_with_unranked_output(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> { + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %0 = tosa.mul %arg0, %arg1, %shift : (tensor, tensor, tensor<1xi8>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + // ----- // CHECK-LABEL: mul func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {