Skip to content

Commit 79df1c3

Browse files
authored
[mlir][tosa] Fix merge problems with mul shift (llvm#125129)
This patch fixes merge issues in TosaOpBase.td and TosaOps.td wrt traits on tosa elementwise ops and multiply op which, with the optional shift operand, is no longer strictly an elementwise op. fixed up inferReturnTypeComponents to be based on only the first two operands (ie, ignoring shift, if present) also fixed up TosaReduceTransposes to special handle tosa mul op now that it is not an elementwise op. Signed-off-by: Tai Ly <[email protected]>
1 parent b873479 commit 79df1c3

File tree

6 files changed

+94
-77
lines changed

6 files changed

+94
-77
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,9 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
239239
Tosa_Op<mnemonic, !listconcat(traits, [
240240
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
241241
["inferReturnTypeComponents"]>,
242+
ResultsBroadcastableShape,
242243
TosaElementwiseOperator,
244+
SameOperandsAndResultRank,
243245
Pure])> {
244246
let assemblyFormat =
245247
"operands attr-dict `:` functional-type(operands, results)";
@@ -248,8 +250,6 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
248250
class Tosa_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
249251
Tosa_ElementwiseOp<mnemonic, !listconcat(traits, [
250252
SameOperandsAndResultShape,
251-
ResultsBroadcastableShape,
252-
SameOperandsAndResultRank,
253253
SameOperandsAndResultElementType])> {}
254254

255255
class Tosa_InferTensorTypeOp<string mnemonic, list<Trait> traits = []>

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 30 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -482,9 +482,7 @@ def Tosa_ErfOp : Tosa_ElementwiseUnaryOp<"erf"> {
482482
//===----------------------------------------------------------------------===//
483483
def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
484484
Commutative,
485-
ResultsBroadcastableShape,
486-
SameOperandsAndResultElementType,
487-
SameOperandsAndResultRank]> {
485+
SameOperandsAndResultElementType]> {
488486
let summary = "Elementwise addition operator";
489487

490488
let description = [{
@@ -517,10 +515,8 @@ def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
517515
//===----------------------------------------------------------------------===//
518516
// Operator: arithmetic_right_shift
519517
//===----------------------------------------------------------------------===//
520-
def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift", [
521-
ResultsBroadcastableShape,
522-
SameOperandsAndResultElementType,
523-
SameOperandsAndResultRank]> {
518+
def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift",
519+
[SameOperandsAndResultElementType]> {
524520
let summary = "Elementwise Arithmetic Right Shift";
525521

526522
let description = [{
@@ -544,9 +540,7 @@ def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift", [
544540
//===----------------------------------------------------------------------===//
545541
def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
546542
Commutative,
547-
ResultsBroadcastableShape,
548-
SameOperandsAndResultElementType,
549-
SameOperandsAndResultRank]> {
543+
SameOperandsAndResultElementType]> {
550544
let summary = "Bitwise AND operator";
551545

552546
let description = [{
@@ -569,9 +563,7 @@ def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
569563
//===----------------------------------------------------------------------===//
570564
def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
571565
Commutative,
572-
ResultsBroadcastableShape,
573-
SameOperandsAndResultElementType,
574-
SameOperandsAndResultRank]> {
566+
SameOperandsAndResultElementType]> {
575567
let summary = "Bitwise OR operator";
576568

577569
let description = [{
@@ -594,9 +586,7 @@ def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
594586
//===----------------------------------------------------------------------===//
595587
def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
596588
Commutative,
597-
ResultsBroadcastableShape,
598-
SameOperandsAndResultElementType,
599-
SameOperandsAndResultRank]> {
589+
SameOperandsAndResultElementType]> {
600590
let summary = "Bitwise XOR operator";
601591

602592
let description = [{
@@ -617,10 +607,7 @@ def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
617607
//===----------------------------------------------------------------------===//
618608
// Operator: int_div
619609
//===----------------------------------------------------------------------===//
620-
def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [
621-
ResultsBroadcastableShape,
622-
SameOperandsAndResultRank,
623-
SameOperandsAndResultElementType]> {
610+
def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [SameOperandsAndResultElementType]> {
624611
let summary = "Integer divide operator";
625612

626613
let description = [{
@@ -645,9 +632,7 @@ def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [
645632
//===----------------------------------------------------------------------===//
646633
def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
647634
Commutative,
648-
ResultsBroadcastableShape,
649-
SameOperandsAndResultElementType,
650-
SameOperandsAndResultRank]> {
635+
SameOperandsAndResultElementType]> {
651636
let summary = "Returns the truth value of x AND y element-wise.";
652637

653638
let description = [{
@@ -668,10 +653,8 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
668653
//===----------------------------------------------------------------------===//
669654
// Operator: logical_left_shift
670655
//===----------------------------------------------------------------------===//
671-
def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift", [
672-
ResultsBroadcastableShape,
673-
SameOperandsAndResultElementType,
674-
SameOperandsAndResultRank]> {
656+
def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift",
657+
[SameOperandsAndResultElementType]> {
675658
let summary = "Elementwise Logical Left Shift";
676659

677660
let description = [{
@@ -692,10 +675,8 @@ def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift", [
692675
//===----------------------------------------------------------------------===//
693676
// Operator: logical_right_shift
694677
//===----------------------------------------------------------------------===//
695-
def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift", [
696-
ResultsBroadcastableShape,
697-
SameOperandsAndResultElementType,
698-
SameOperandsAndResultRank]> {
678+
def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift",
679+
[SameOperandsAndResultElementType]> {
699680
let summary = "Elementwise Logical Right Shift";
700681

701682
let description = [{
@@ -718,9 +699,7 @@ def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift", [
718699
//===----------------------------------------------------------------------===//
719700
def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
720701
Commutative,
721-
ResultsBroadcastableShape,
722-
SameOperandsAndResultElementType,
723-
SameOperandsAndResultRank]> {
702+
SameOperandsAndResultElementType]> {
724703
let summary = "Returns the truth value of x OR y element-wise.";
725704

726705
let description = [{
@@ -743,9 +722,7 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
743722
//===----------------------------------------------------------------------===//
744723
def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
745724
Commutative,
746-
ResultsBroadcastableShape,
747-
SameOperandsAndResultElementType,
748-
SameOperandsAndResultRank]> {
725+
SameOperandsAndResultElementType]> {
749726
let summary = "Returns the truth value of x XOR y element-wise.";
750727

751728
let description = [{
@@ -768,9 +745,7 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
768745
//===----------------------------------------------------------------------===//
769746
def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
770747
Commutative,
771-
ResultsBroadcastableShape,
772-
SameOperandsAndResultElementType,
773-
SameOperandsAndResultRank]> {
748+
SameOperandsAndResultElementType]> {
774749
let summary = "Elementwise Maximum";
775750

776751
let description = [{
@@ -794,9 +769,7 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
794769
//===----------------------------------------------------------------------===//
795770
def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
796771
Commutative,
797-
ResultsBroadcastableShape,
798-
SameOperandsAndResultElementType,
799-
SameOperandsAndResultRank]> {
772+
SameOperandsAndResultElementType]> {
800773
let summary = "Elementwise Minimum";
801774

802775
let description = [{
@@ -823,9 +796,11 @@ def MulOperandsAndResultElementType :
823796
//===----------------------------------------------------------------------===//
824797
// Operator: mul
825798
//===----------------------------------------------------------------------===//
826-
def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
799+
def Tosa_MulOp : Tosa_Op<"mul", [
800+
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
801+
["inferReturnTypeComponents"]>,
827802
Commutative,
828-
MulOperandsAndResultElementType]> {
803+
Pure]> {
829804
let summary = "Multiplication operator";
830805

831806
let description = [{
@@ -846,15 +821,15 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
846821

847822
let hasFolder = 1;
848823
let hasVerifier = 1;
824+
825+
let assemblyFormat =
826+
"operands attr-dict `:` functional-type(operands, results)";
849827
}
850828

851829
//===----------------------------------------------------------------------===//
852830
// Operator: pow
853831
//===----------------------------------------------------------------------===//
854-
def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [
855-
ResultsBroadcastableShape,
856-
SameOperandsAndResultElementType,
857-
SameOperandsAndResultRank]> {
832+
def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [SameOperandsAndResultElementType]> {
858833
let summary = "Computes the power of one value to another.";
859834

860835
let description = [{
@@ -875,10 +850,7 @@ def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [
875850
//===----------------------------------------------------------------------===//
876851
// Operator: sub
877852
//===----------------------------------------------------------------------===//
878-
def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [
879-
ResultsBroadcastableShape,
880-
SameOperandsAndResultElementType,
881-
SameOperandsAndResultRank]> {
853+
def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [SameOperandsAndResultElementType]> {
882854
let summary = "Elementwise subtraction operator";
883855

884856
let description = [{
@@ -1229,9 +1201,7 @@ def Tosa_SinOp : Tosa_ElementwiseUnaryOp<"sin"> {
12291201
//===----------------------------------------------------------------------===//
12301202
// Operator: select
12311203
//===----------------------------------------------------------------------===//
1232-
def Tosa_SelectOp : Tosa_ElementwiseOp<"select", [
1233-
ResultsBroadcastableShape,
1234-
SameOperandsAndResultRank]> {
1204+
def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
12351205
let summary = "Elementwise select operator";
12361206

12371207
let description = [{
@@ -1267,9 +1237,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select", [
12671237
def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
12681238
InferTensorType,
12691239
Commutative,
1270-
ResultsBroadcastableShape,
1271-
SameOperandsElementType,
1272-
SameOperandsAndResultRank]> {
1240+
SameOperandsElementType]> {
12731241
let summary = "Returns the truth value of (x == y) element-wise.";
12741242

12751243
let description = [{
@@ -1297,10 +1265,7 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
12971265
//===----------------------------------------------------------------------===//
12981266
// Operator: greater
12991267
//===----------------------------------------------------------------------===//
1300-
def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [
1301-
ResultsBroadcastableShape,
1302-
SameOperandsElementType,
1303-
SameOperandsAndResultRank]> {
1268+
def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
13041269
let summary = "Returns the truth value of (x > y) element-wise.";
13051270

13061271
let description = [{
@@ -1322,11 +1287,8 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [
13221287
//===----------------------------------------------------------------------===//
13231288
// Operator: greater_equal
13241289
//===----------------------------------------------------------------------===//
1325-
def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal", [
1326-
ResultsBroadcastableShape,
1327-
SameOperandsElementType,
1328-
SameOperandsAndResultRank
1329-
]> {
1290+
def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
1291+
[SameOperandsElementType]> {
13301292
let summary = "Returns the truth value of (x >= y) element-wise.";
13311293

13321294
let description = [{

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,28 @@ LogicalResult tosa::SliceOp::verify() {
958958
return success();
959959
}
960960

961+
LogicalResult tosa::MulOp::inferReturnTypeComponents(
962+
MLIRContext *context, ::std::optional<Location> location,
963+
ValueShapeRange operands, DictionaryAttr attributes,
964+
OpaqueProperties properties, RegionRange regions,
965+
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
966+
LogicalResult status = success();
967+
llvm::SmallVector<int64_t> outShape;
968+
if (operands.size() == 2) {
969+
status = resolveBroadcastShape(operands, outShape);
970+
} else {
971+
// mul op's output shape only depend on input1 and input2, not on shift
972+
ValueShapeRange two_inputs = operands.drop_back();
973+
status = resolveBroadcastShape(two_inputs, outShape);
974+
}
975+
if (status.failed()) {
976+
inferredReturnShapes.push_back(ShapedTypeComponents());
977+
} else {
978+
inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
979+
}
980+
return success();
981+
}
982+
961983
LogicalResult tosa::MulOp::verify() {
962984
auto resElemType = getElementTypeOrSelf(getOutput());
963985

@@ -1030,6 +1052,20 @@ LogicalResult tosa::MulOp::verify() {
10301052
}
10311053
}
10321054

1055+
// check for broadcast compatible shapes in first two operands (ignoring
1056+
// shift)
1057+
1058+
// delegate function that returns shape of shaped type
1059+
auto getShape = [](const Type type) {
1060+
return mlir::cast<ShapedType>(type).getShape();
1061+
};
1062+
SmallVector<int64_t> resultShape;
1063+
if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]),
1064+
getShape(rankedOperandTypes[1]),
1065+
resultShape)) {
1066+
return emitOpError("operands don't have broadcast-compatible shapes");
1067+
}
1068+
10331069
return success();
10341070
}
10351071

@@ -1670,7 +1706,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
16701706
NARY_SHAPE_INFER(tosa::LogicalXorOp)
16711707
NARY_SHAPE_INFER(tosa::MaximumOp)
16721708
NARY_SHAPE_INFER(tosa::MinimumOp)
1673-
NARY_SHAPE_INFER(tosa::MulOp)
16741709
NARY_SHAPE_INFER(tosa::NegateOp)
16751710
NARY_SHAPE_INFER(tosa::PowOp)
16761711
NARY_SHAPE_INFER(tosa::ReciprocalOp)

mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,13 +281,20 @@ bool TosaReduceTransposes::collectFanIn(Operation *op,
281281
if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
282282
!llvm::isa<tosa::ConstOp>(op)) {
283283

284-
if (!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
284+
if (!llvm::isa<tosa::MulOp>(op) &&
285+
!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
285286
return false;
286287

287-
for (Value operand : op->getOperands())
288+
for (Value operand : op->getOperands()) {
288289
// If this is a problem in future, think about alternatives to recursion.
290+
if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
291+
operand == op->getOperand(2)) {
292+
// do not recurse into MulOp's shift operand
293+
continue;
294+
}
289295
if (!collectFanIn(operand.getDefiningOp(), collected))
290296
return false;
297+
}
291298
}
292299

293300
// Insert in topological order.
@@ -316,14 +323,19 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
316323
Operation *op, const DenseMap<Value, Value> &valuesMap,
317324
IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
318325
if (op->getNumResults() != 1 ||
319-
!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
326+
(!llvm::isa<tosa::MulOp>(op) &&
327+
!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>()))
320328
return std::nullopt;
321329

322330
auto resultType = llvm::cast<RankedTensorType>(op->getResult(0).getType());
323331
SmallVector<Value, 3> operands;
324332
for (Value v : op->getOperands()) {
325333
if (valuesMap.contains(v)) {
326334
operands.push_back(valuesMap.at(v));
335+
} else if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
336+
v == op->getOperand(2)) {
337+
// special case for MulOp's shift operand
338+
operands.push_back(v);
327339
} else {
328340
return std::nullopt;
329341
}

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> t
183183
// -----
184184

185185
func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
186-
%padding = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
186+
%padding = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
187187
// expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank.}}
188188
%1 = tosa.pad %arg0, %padding : (tensor<13x21xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32>
189189
return
@@ -211,7 +211,7 @@ func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tenso
211211
// -----
212212

213213
func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
214-
%0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
214+
%0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
215215
// expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}}
216216
%1 = tosa.pad %arg0, %0 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32>
217217
return %1 : tensor<13x21x3xf32>
@@ -749,7 +749,7 @@ func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1
749749

750750
// CHECK-LABEL: test_mul_missing_shift
751751
func.func @test_mul_missing_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
752-
// expected-error@+1 {{'tosa.mul' op expected 3 operands, but found 2}}
752+
// this is ok because mul's shift operand is optional for now
753753
%0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32>
754754
return %0 : tensor<13x21x3xi32>
755755
}

0 commit comments

Comments
 (0)