Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,9 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
Tosa_Op<mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
ResultsBroadcastableShape,
TosaElementwiseOperator,
SameOperandsAndResultRank,
Pure])> {
let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
Expand All @@ -248,8 +250,6 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
class Tosa_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
Tosa_ElementwiseOp<mnemonic, !listconcat(traits, [
SameOperandsAndResultShape,
ResultsBroadcastableShape,
SameOperandsAndResultRank,
SameOperandsAndResultElementType])> {}

class Tosa_InferTensorTypeOp<string mnemonic, list<Trait> traits = []>
Expand Down
98 changes: 30 additions & 68 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -482,9 +482,7 @@ def Tosa_ErfOp : Tosa_ElementwiseUnaryOp<"erf"> {
//===----------------------------------------------------------------------===//
def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
Commutative,
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
SameOperandsAndResultElementType]> {
let summary = "Elementwise addition operator";

let description = [{
Expand Down Expand Up @@ -517,10 +515,8 @@ def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
//===----------------------------------------------------------------------===//
// Operator: arithmetic_right_shift
//===----------------------------------------------------------------------===//
def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift", [
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift",
[SameOperandsAndResultElementType]> {
let summary = "Elementwise Arithmetic Right Shift";

let description = [{
Expand All @@ -544,9 +540,7 @@ def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift", [
//===----------------------------------------------------------------------===//
def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
Commutative,
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
SameOperandsAndResultElementType]> {
let summary = "Bitwise AND operator";

let description = [{
Expand All @@ -569,9 +563,7 @@ def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
//===----------------------------------------------------------------------===//
def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
Commutative,
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
SameOperandsAndResultElementType]> {
let summary = "Bitwise OR operator";

let description = [{
Expand All @@ -594,9 +586,7 @@ def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
//===----------------------------------------------------------------------===//
def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
Commutative,
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
SameOperandsAndResultElementType]> {
let summary = "Bitwise XOR operator";

let description = [{
Expand All @@ -617,10 +607,7 @@ def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
//===----------------------------------------------------------------------===//
// Operator: int_div
//===----------------------------------------------------------------------===//
def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [
ResultsBroadcastableShape,
SameOperandsAndResultRank,
SameOperandsAndResultElementType]> {
def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [SameOperandsAndResultElementType]> {
let summary = "Integer divide operator";

let description = [{
Expand All @@ -645,9 +632,7 @@ def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [
//===----------------------------------------------------------------------===//
def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
Commutative,
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
SameOperandsAndResultElementType]> {
let summary = "Returns the truth value of x AND y element-wise.";

let description = [{
Expand All @@ -668,10 +653,8 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
//===----------------------------------------------------------------------===//
// Operator: logical_left_shift
//===----------------------------------------------------------------------===//
def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift", [
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift",
[SameOperandsAndResultElementType]> {
let summary = "Elementwise Logical Left Shift";

let description = [{
Expand All @@ -692,10 +675,8 @@ def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift", [
//===----------------------------------------------------------------------===//
// Operator: logical_right_shift
//===----------------------------------------------------------------------===//
def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift", [
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift",
[SameOperandsAndResultElementType]> {
let summary = "Elementwise Logical Right Shift";

let description = [{
Expand All @@ -718,9 +699,7 @@ def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift", [
//===----------------------------------------------------------------------===//
def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
Commutative,
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
SameOperandsAndResultElementType]> {
let summary = "Returns the truth value of x OR y element-wise.";

let description = [{
Expand All @@ -743,9 +722,7 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
//===----------------------------------------------------------------------===//
def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
Commutative,
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
SameOperandsAndResultElementType]> {
let summary = "Returns the truth value of x XOR y element-wise.";

let description = [{
Expand All @@ -768,9 +745,7 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
//===----------------------------------------------------------------------===//
def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
Commutative,
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
SameOperandsAndResultElementType]> {
let summary = "Elementwise Maximum";

let description = [{
Expand All @@ -794,9 +769,7 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
//===----------------------------------------------------------------------===//
def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
Commutative,
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
SameOperandsAndResultElementType]> {
let summary = "Elementwise Minimum";

let description = [{
Expand All @@ -823,9 +796,11 @@ def MulOperandsAndResultElementType :
//===----------------------------------------------------------------------===//
// Operator: mul
//===----------------------------------------------------------------------===//
def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
def Tosa_MulOp : Tosa_Op<"mul", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Commutative,
MulOperandsAndResultElementType]> {
Pure]> {
let summary = "Multiplication operator";

let description = [{
Expand All @@ -846,15 +821,15 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [

let hasFolder = 1;
let hasVerifier = 1;

let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
}

//===----------------------------------------------------------------------===//
// Operator: pow
//===----------------------------------------------------------------------===//
def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [SameOperandsAndResultElementType]> {
let summary = "Computes the power of one value to another.";

let description = [{
Expand All @@ -875,10 +850,7 @@ def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [
//===----------------------------------------------------------------------===//
// Operator: sub
//===----------------------------------------------------------------------===//
def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [SameOperandsAndResultElementType]> {
let summary = "Elementwise subtraction operator";

let description = [{
Expand Down Expand Up @@ -1229,9 +1201,7 @@ def Tosa_SinOp : Tosa_ElementwiseUnaryOp<"sin"> {
//===----------------------------------------------------------------------===//
// Operator: select
//===----------------------------------------------------------------------===//
def Tosa_SelectOp : Tosa_ElementwiseOp<"select", [
ResultsBroadcastableShape,
SameOperandsAndResultRank]> {
def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
let summary = "Elementwise select operator";

let description = [{
Expand Down Expand Up @@ -1267,9 +1237,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select", [
def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
InferTensorType,
Commutative,
ResultsBroadcastableShape,
SameOperandsElementType,
SameOperandsAndResultRank]> {
SameOperandsElementType]> {
let summary = "Returns the truth value of (x == y) element-wise.";

let description = [{
Expand Down Expand Up @@ -1297,10 +1265,7 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
//===----------------------------------------------------------------------===//
// Operator: greater
//===----------------------------------------------------------------------===//
def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [
ResultsBroadcastableShape,
SameOperandsElementType,
SameOperandsAndResultRank]> {
def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
let summary = "Returns the truth value of (x > y) element-wise.";

let description = [{
Expand All @@ -1322,11 +1287,8 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [
//===----------------------------------------------------------------------===//
// Operator: greater_equal
//===----------------------------------------------------------------------===//
def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal", [
ResultsBroadcastableShape,
SameOperandsElementType,
SameOperandsAndResultRank
]> {
def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
[SameOperandsElementType]> {
let summary = "Returns the truth value of (x >= y) element-wise.";

let description = [{
Expand Down
37 changes: 36 additions & 1 deletion mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,28 @@ LogicalResult tosa::SliceOp::verify() {
return success();
}

LogicalResult tosa::MulOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
LogicalResult status = success();
llvm::SmallVector<int64_t> outShape;
if (operands.size() == 2) {
status = resolveBroadcastShape(operands, outShape);
} else {
// mul op's output shape only depend on input1 and input2, not on shift
ValueShapeRange two_inputs = operands.drop_back();
status = resolveBroadcastShape(two_inputs, outShape);
}
if (status.failed()) {
inferredReturnShapes.push_back(ShapedTypeComponents());
} else {
inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
}
return success();
}

LogicalResult tosa::MulOp::verify() {
auto resElemType = getElementTypeOrSelf(getOutput());

Expand Down Expand Up @@ -1030,6 +1052,20 @@ LogicalResult tosa::MulOp::verify() {
}
}

// check for broadcast compatible shapes in first two operands (ignoring
// shift)

// delegate function that returns shape of shaped type
auto getShape = [](const Type type) {
return mlir::cast<ShapedType>(type).getShape();
};
SmallVector<int64_t> resultShape;
if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]),
getShape(rankedOperandTypes[1]),
resultShape)) {
return emitOpError("operands don't have broadcast-compatible shapes");
}

return success();
}

Expand Down Expand Up @@ -1670,7 +1706,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
NARY_SHAPE_INFER(tosa::LogicalXorOp)
NARY_SHAPE_INFER(tosa::MaximumOp)
NARY_SHAPE_INFER(tosa::MinimumOp)
NARY_SHAPE_INFER(tosa::MulOp)
NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
Expand Down
18 changes: 15 additions & 3 deletions mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,13 +281,20 @@ bool TosaReduceTransposes::collectFanIn(Operation *op,
if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
!llvm::isa<tosa::ConstOp>(op)) {

if (!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
if (!llvm::isa<tosa::MulOp>(op) &&
!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
return false;

for (Value operand : op->getOperands())
for (Value operand : op->getOperands()) {
// If this is a problem in future, think about alternatives to recursion.
if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
operand == op->getOperand(2)) {
// do not recurse into MulOp's shift operand
continue;
}
if (!collectFanIn(operand.getDefiningOp(), collected))
return false;
}
}

// Insert in topological order.
Expand Down Expand Up @@ -316,14 +323,19 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
Operation *op, const DenseMap<Value, Value> &valuesMap,
IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
if (op->getNumResults() != 1 ||
!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
(!llvm::isa<tosa::MulOp>(op) &&
!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>()))
return std::nullopt;

auto resultType = llvm::cast<RankedTensorType>(op->getResult(0).getType());
SmallVector<Value, 3> operands;
for (Value v : op->getOperands()) {
if (valuesMap.contains(v)) {
operands.push_back(valuesMap.at(v));
} else if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
v == op->getOperand(2)) {
// special case for MulOp's shift operand
operands.push_back(v);
} else {
return std::nullopt;
}
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> t
// -----

func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
%padding = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
%padding = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank.}}
%1 = tosa.pad %arg0, %padding : (tensor<13x21xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32>
return
Expand Down Expand Up @@ -211,7 +211,7 @@ func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tenso
// -----

func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
%0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}}
%1 = tosa.pad %arg0, %0 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32>
return %1 : tensor<13x21x3xf32>
Expand Down Expand Up @@ -749,7 +749,7 @@ func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1

// CHECK-LABEL: test_mul_missing_shift
func.func @test_mul_missing_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
// expected-error@+1 {{'tosa.mul' op expected 3 operands, but found 2}}
// this is ok because mul's shift operand is optional for now
%0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
Expand Down
Loading