Skip to content

Commit dc1a1db

Browse files
committed
[AutoBump] Merge with fixes of 519eef3 (Oct 22)
2 parents 1b28c53 + 519eef3 commit dc1a1db

File tree

5 files changed

+27
-11
lines changed

5 files changed

+27
-11
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -815,6 +815,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
815815
);
816816

817817
let hasFolder = 1;
818+
let hasVerifier = 1;
818819
}
819820

820821
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,23 +80,17 @@ static Value createLinalgBodyCalculationForElementwiseOp(
8080
if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
8181
return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
8282

83-
// tosa::MulOp
84-
if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
85-
if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
86-
(void)rewriter.notifyMatchFailure(op,
87-
"Cannot have shift value for float");
88-
return nullptr;
89-
}
90-
return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
91-
}
92-
9383
// tosa::DivOp
9484
if (isa<tosa::IntDivOp>(op)) {
9585
if (elementTy.isSignlessInteger())
9686
return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
9787
else
9888
return rewriter.create<arith::DivUIOp>(loc, resultTypes, args);
9989
}
90+
91+
// tosa::IntDivOp
92+
if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
93+
return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
10094

10195
// tosa::ReciprocalOp
10296
if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
@@ -105,6 +99,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(
10599
return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
106100
}
107101

102+
// tosa::MulOp
103+
if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
104+
return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
105+
108106
if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
109107
Value a = args[0];
110108
Value b = args[1];

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -927,6 +927,14 @@ LogicalResult tosa::SliceOp::verify() {
927927
return success();
928928
}
929929

930+
LogicalResult tosa::MulOp::verify() {
931+
Type elementTy = getInput1().getType().getElementType();
932+
if (isa<FloatType>(elementTy) && getShift() != 0)
933+
return emitOpError() << "require shift to be 0 for float type";
934+
935+
return success();
936+
}
937+
930938
LogicalResult tosa::TableOp::inferReturnTypeComponents(
931939
MLIRContext *context, ::std::optional<Location> location,
932940
TableOp::Adaptor adaptor,

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,3 +657,12 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
657657
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
658658
return %0 : tensor<1x32x32x16xf32>
659659
}
660+
661+
// -----
662+
663+
// CHECK-LABEL: test_mul_invalid_shift
664+
func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
665+
// expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type}}
666+
%0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
667+
return %0 : tensor<13x21x3xf32>
668+
}

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te
315315
// -----
316316
// CHECK-LABEL: mul
317317
func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
318-
%0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
318+
%0 = tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
319319
return %0 : tensor<13x21x3xf32>
320320
}
321321

0 commit comments

Comments
 (0)