From 56aec5788295f83816e3ce5643f7cf7bb7b724bb Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 27 Nov 2024 10:50:11 +0000 Subject: [PATCH] [mlir][tosa] Switch zero point of negate to input variable type This commit changes the zero point attribute to an input to align with the 1.0 spec. Change-Id: Ibc9e5959b36c182a9e0c5c23a2f9d42a572a1184 Signed-off-by: Tai Ly --- .../Dialect/Tosa/IR/TosaComplianceData.h.inc | 10 +- .../mlir/Dialect/Tosa/IR/TosaOpBase.td | 6 +- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 21 +++- .../Conversion/TosaToLinalg/TosaToLinalg.cpp | 29 +++-- .../Dialect/Tosa/IR/ShardingInterfaceImpl.cpp | 42 +++++++- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 31 +++++- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 101 +++++++++++++++--- .../TosaToLinalg/tosa-to-linalg.mlir | 50 ++++----- .../Dialect/Mesh/sharding-propagation.mlir | 13 ++- mlir/test/Dialect/Mesh/spmdization.mlir | 15 ++- mlir/test/Dialect/Tosa/availability.mlir | 4 +- mlir/test/Dialect/Tosa/canonicalize.mlir | 45 +++++++- mlir/test/Dialect/Tosa/invalid.mlir | 64 +++++++++++ mlir/test/Dialect/Tosa/level_check.mlir | 4 +- mlir/test/Dialect/Tosa/ops.mlir | 4 +- mlir/test/Dialect/Tosa/quant-test.mlir | 4 +- mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 12 ++- 17 files changed, 364 insertions(+), 91 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc index d3fd4c3d1d3e1..efc329ee48849 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc @@ -114,8 +114,12 @@ profileComplianceMap = { {"tosa.logical_not", {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}}, {"tosa.negate", - {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{i8T, i8T, i8T, i8T}, + {i16T, i16T, i16T, i16T}, + {i32T, i32T, i32T, i32T}}}, + 
{{Profile::pro_fp}, + {{fp16T, fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T, fp32T}}}}}, {"tosa.reciprocal", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, {"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, @@ -310,7 +314,7 @@ extensionComplianceMap = { {"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, {"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, {"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}}, {"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, {"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, {"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index f2328003e49c5..da4daa03aa652 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -178,13 +178,13 @@ def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilder< input, kernel, stride, pad, acc_type); }]>; -// This builder is called on single-parameter unary operators that have a scale +// This builder is called on single-parameter negate operators that have a scale // relationship between their input and output, expressed by the // UnaryOpQuantizationAttr. 
-def Tosa_UnaryOpQuantInfoBuilder : OpBuilder< +def Tosa_NegateOpQuantInfoBuilder : OpBuilder< (ins "Type":$outputType, "Value":$input), [{ - buildUnaryOpWithQuantInfo($_builder, $_state, outputType, input); + buildNegateOpWithQuantInfo($_builder, $_state, outputType, input); }]>; // These builders are called on the TOSA pad operator that needs to create its diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index ecddc9fe9a13d..52c80e975f290 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1356,7 +1356,9 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseUnaryOp<"logical_not"> { //===----------------------------------------------------------------------===// // Operator: negate //===----------------------------------------------------------------------===// -def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> { +def Tosa_NegateOp : Tosa_InferShapedTypeOp<"negate", [ + TosaElementwiseOperator, + Pure]> { let summary = "Elementwise negate op"; let description = [{ @@ -1365,8 +1367,8 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> { let arguments = (ins Tosa_Tensor:$input1, - OptionalAttr:$input1_zp, - OptionalAttr:$output_zp + Tosa_ScalarIntOrFloatTensor:$input1_zp, + Tosa_ScalarIntOrFloatTensor:$output_zp ); let results = (outs @@ -1378,9 +1380,20 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> { Extension<[Tosa_EXT_BF16]>, ]; - let builders = [Tosa_UnaryOpQuantInfoBuilder]; + let builders = [Tosa_NegateOpQuantInfoBuilder]; + + let extraClassDeclaration = [{ + FailureOr getInput1ZeroPoint(); + FailureOr getOutputZeroPoint(); + LogicalResult verifyInput1ZeroPoint(int64_t zp); + LogicalResult verifyOutputZeroPoint(int64_t zp); + }]; let hasFolder = 1; + let hasVerifier = 1; + + let assemblyFormat = + "operands attr-dict `:` functional-type(operands, results)"; } 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index f7dd33c7e8b53..7772c186b526a 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -193,18 +193,29 @@ static Value createLinalgBodyCalculationForElementwiseOp( // tosa::NegateOp if (isa(op)) { - if (isa(elementTy)) - return rewriter.create(loc, resultTypes, args); + auto negate = cast(op); - if (isa(elementTy)) { - auto inputZpAttr = cast(op).getInput1ZpAttr(); - auto outputZpAttr = cast(op).getOutputZpAttr(); + FailureOr maybeInZp = negate.getInput1ZeroPoint(); + if (failed(maybeInZp)) { + (void)rewriter.notifyMatchFailure( + op, "input1 zero point cannot be statically determined"); + return nullptr; + } + + FailureOr maybeOutZp = negate.getOutputZeroPoint(); + if (failed(maybeOutZp)) { + (void)rewriter.notifyMatchFailure( + op, "output zero point cannot be statically determined"); + return nullptr; + } - const int64_t inZp = - inputZpAttr ? inputZpAttr.getValue().getSExtValue() : 0; - const int64_t outZp = - outputZpAttr ? 
outputZpAttr.getValue().getSExtValue() : 0; + int64_t inZp = *maybeInZp; + int64_t outZp = *maybeOutZp; + if (isa(elementTy)) + return rewriter.create(loc, resultTypes, args[0]); + + if (isa(elementTy)) { if (!inZp && !outZp) { auto constant = rewriter.create( loc, IntegerAttr::get(elementTy, 0)); diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp index 6dcb7c845b21f..be29298a35aeb 100644 --- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp @@ -62,6 +62,45 @@ struct MatMulOpSharding } }; +struct NegateOpSharding + : public ShardingInterface::ExternalModel { + SmallVector getLoopIteratorTypes(Operation *op) const { + Value val = op->getOperand(0); + auto type = dyn_cast(val.getType()); + if (!type) + return {}; + SmallVector types(type.getRank(), + utils::IteratorType::parallel); + return types; + } + + SmallVector getIndexingMaps(Operation *op) const { + MLIRContext *ctx = op->getContext(); + Value val = op->getOperand(0); + auto type = dyn_cast(val.getType()); + if (!type) + return {}; + int64_t rank = type.getRank(); + SmallVector maps = { + AffineMap::getMultiDimIdentityMap(rank, ctx), + AffineMap::get(0, 0, {}, ctx), AffineMap::get(0, 0, {}, ctx), + AffineMap::getMultiDimIdentityMap(rank, ctx)}; + return maps; + } + + LogicalResult spmdize(Operation *op, ArrayRef spmdizedOperands, + ArrayRef operandShardings, + ArrayRef resultShardings, + IRMapping &spmdizationMap, + SymbolTableCollection &symbolTable, + OpBuilder &builder) const { + spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings, + resultShardings, spmdizationMap, + symbolTable, builder); + return success(); + } +}; + template static void registerElemwiseOne(MLIRContext *ctx) { OpType::template attachInterface>(*ctx); @@ -84,9 +123,10 @@ void mlir::tosa::registerShardingInterfaceExternalModels( BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, 
LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, - LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, + LogOp, LogicalNotOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp, GreaterEqualOp>(ctx); MatMulOp::attachInterface(*ctx); + NegateOp::attachInterface(*ctx); }); } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 3e99c1f717d09..09d2c5d35263c 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1143,13 +1143,36 @@ OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) { } OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) { - auto input = getInput1(); // Element-wise negate(negate(x)) = x - if (auto op = input.getDefiningOp()) { - return op.getInput1(); + // iff all zero points are constant 0 + auto definingOp = getInput1().getDefiningOp(); + if (!definingOp) { + // defining op of input1 is not a negate, cannot fold + return {}; } - return {}; + if (FailureOr maybeIZp = getInput1ZeroPoint(); + failed(maybeIZp) || *maybeIZp != 0) { + // input1 zero point is not constant 0, cannot fold + return {}; + } + if (FailureOr maybeOZp = getOutputZeroPoint(); + failed(maybeOZp) || *maybeOZp != 0) { + // output zero point is not constant 0, cannot fold + return {}; + } + if (FailureOr maybeIZp = definingOp.getInput1ZeroPoint(); + failed(maybeIZp) || *maybeIZp != 0) { + // definingOp's input1 zero point is not constant 0, cannot fold + return {}; + } + if (FailureOr maybeOZp = definingOp.getOutputZeroPoint(); + failed(maybeOZp) || *maybeOZp != 0) { + // definingOp's output zero point is not constant 0, cannot fold + return {}; + } + + return definingOp.getInput1(); } OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) { diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp 
b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 7a991b3876f69..219775c31bd56 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -697,23 +697,43 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, result.types.push_back(outputType); } -/// This builder is called on single-parameter unary operators that have scale -/// relationship between their input and output, expressed by the -/// UnaryOpQuantizationAttr. -static void buildUnaryOpWithQuantInfo(OpBuilder &builder, - OperationState &result, Type outputType, - Value input) { - result.addOperands(input); +/// This builder is called on single-parameter negate operator +/// to construct input and output zero points based on their +/// types. +static void buildNegateOpWithQuantInfo(OpBuilder &builder, + OperationState &result, Type outputType, + Value input) { + const Location loc{result.location}; + int64_t input1Zp{0}; + int64_t outputZp{0}; auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType); if (quantAttr) { - // note: negateOp has attributes input1_zp and output_zp - result.addAttribute("input1_zp", - builder.getI32IntegerAttr( - static_cast(quantAttr.getInputZp()))); - result.addAttribute("output_zp", - builder.getI32IntegerAttr( - static_cast(quantAttr.getOutputZp()))); + input1Zp = quantAttr.getInputZp(); + outputZp = quantAttr.getOutputZp(); + } + const std::optional input1ZpOp = + createZeroPointTensor(builder, loc, input.getType(), input1Zp); + if (!input1ZpOp) { + (void)emitError( + loc, "Failed to create input1 zero point for quantized NEGATE op"); + } + + const std::optional outputZpOp = + createZeroPointTensor(builder, loc, input.getType(), outputZp); + if (!outputZpOp) { + (void)emitError( + loc, "Failed to create output zero point for quantized NEGATE op"); } + + if (input1ZpOp && outputZpOp) { + result.addOperands({input, input1ZpOp.value(), outputZpOp.value()}); + } else { + // failed to create one or more zero 
points above: just add input as + // operands. This will trigger error in building the op because of + // missing zero points + result.addOperands({input}); + } + result.types.push_back(outputType); } @@ -1728,6 +1748,9 @@ ZERO_POINT_HELPER(AvgPool2dOp, Input) ZERO_POINT_HELPER(AvgPool2dOp, Output) ZERO_POINT_HELPER(MatMulOp, A) ZERO_POINT_HELPER(MatMulOp, B) +ZERO_POINT_HELPER(NegateOp, Input1) +ZERO_POINT_HELPER(NegateOp, Output) + #undef ZERO_POINT_HELPER LogicalResult tosa::TransposeOp::inferReturnTypeComponents( @@ -2230,7 +2253,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp) NARY_SHAPE_INFER(tosa::LogicalXorOp) NARY_SHAPE_INFER(tosa::MaximumOp) NARY_SHAPE_INFER(tosa::MinimumOp) -NARY_SHAPE_INFER(tosa::NegateOp) NARY_SHAPE_INFER(tosa::PowOp) NARY_SHAPE_INFER(tosa::ReciprocalOp) NARY_SHAPE_INFER(tosa::ReverseOp) @@ -2243,6 +2265,55 @@ NARY_SHAPE_INFER(tosa::ErfOp) NARY_SHAPE_INFER(tosa::SigmoidOp) #undef PRED_SHAPE_INFER +LogicalResult tosa::NegateOp::inferReturnTypeComponents( + MLIRContext *context, ::std::optional location, + NegateOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnShapes) { + ShapeAdaptor inputShape(adaptor.getInput1().getType()); + inferredReturnShapes.push_back(ShapedTypeComponents(inputShape)); + return success(); +} + +LogicalResult tosa::NegateOp::verify() { + // Verify same element type + const Type input1Type = getInput1().getType(); + const Type outputType = getOutput().getType(); + if (verifySameElementTypes(*this, input1Type, outputType).failed()) + return failure(); + + // Verify same shape + const SmallVector types = {input1Type, outputType}; + if (failed(verifyCompatibleShapes(types))) + return emitOpError() << "requires the same shape for input1 and output"; + + const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType()); + const Type input1ZpEType = + getStorageElementTypeOrSelf(getInput1Zp().getType()); + if (input1EType != input1ZpEType) { + return emitOpError("expect both input1 and its zero point are 
the same " + "element type, got ") + << input1EType << " and " << input1ZpEType; + } + const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType()); + const Type outputZpEType = + getStorageElementTypeOrSelf(getOutputZp().getType()); + if (outputEType != outputZpEType) { + return emitOpError("expect both output and its zero point are the same " + "element type, got ") + << outputEType << " and " << outputZpEType; + } + + FailureOr maybeIZp = getInput1ZeroPoint(); + if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed()) + return failure(); + + FailureOr maybeOZp = getOutputZeroPoint(); + if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed()) + return failure(); + + return success(); +} + static LogicalResult poolingInferReturnTypes( ShapeAdaptor inputShape, ArrayRef kernel, ArrayRef stride, ArrayRef pad, diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index a3ed8c2805282..1c7be0ed6f107 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -477,7 +477,9 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () { // CHECK: linalg.generic // CHECK: arith.negf - %5 = tosa.negate %0 : (tensor<1xf32>) -> tensor<1xf32> + %in_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %out_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %5 = tosa.negate %0, %in_zp, %out_zp : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: pow @@ -662,10 +664,12 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns %40 = tosa.int_div %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic - // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32): + // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: 
i32): // CHECK: [[ZERO:%.+]] = arith.constant 0 // CHECK: arith.subi [[ZERO]], %[[ARG1]] - %5 = tosa.negate %arg0 : (tensor<1xi32>) -> tensor<1xi32> + %in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> + %out_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> + %5 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: and @@ -852,40 +856,22 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () { // CHECK-LABEL: @test_negate_quantized func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () { // CHECK: linalg.generic - // CHECK: ^bb0(%[[BBARG0:.+]]: i8, - // CHECK: [[ZERO:%.+]] = arith.constant 0 - // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], %[[BBARG0]] - // CHECK: linalg.yield [[SUB]] - %0 = tosa.negate %arg0 {input_zp1 = 0 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8> - - // CHECK: linalg.generic - // CHECK: ^bb0(%[[BBARG0:.+]]: i8, - // CHECK: [[C32639:%.+]] = arith.constant 32639 + // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8 + // CHECK: [[CNST:%.+]] = arith.constant 7 // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16 - // CHECK: [[SUB:%.+]] = arith.subi [[C32639]], [[EXT]] - // CHECK: [[MIN:%.+]] = arith.constant -128 - // CHECK: [[MAX:%.+]] = arith.constant 127 - // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]] - // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]] - // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]] - // CHECK: linalg.yield [[TRUNC]] - %1 = tosa.negate %arg0 {input1_zp = 32639 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8> - - // CHECK: linalg.generic - // CHECK: ^bb0(%[[BBARG0:.+]]: i8, - // CHECK: [[C32640:%.+]] = arith.constant 32640 - // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i32 - // CHECK: [[SUB:%.+]] = arith.subi [[C32640]], [[EXT]] + // CHECK: [[SUB:%.+]] = 
arith.subi [[CNST]], [[EXT]] // CHECK: [[MIN:%.+]] = arith.constant -128 // CHECK: [[MAX:%.+]] = arith.constant 127 // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]] // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]] // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]] // CHECK: linalg.yield [[TRUNC]] - %2 = tosa.negate %arg0 {input1_zp = 32640 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8> + %in_zp0 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %out_zp0 = "tosa.const"() <{values = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8> + %0 = tosa.negate %arg0, %in_zp0, %out_zp0 : (tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8> // CHECK: linalg.generic - // CHECK: ^bb0(%[[BBARG0:.+]]: i8, + // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8 // CHECK: [[C_128:%.+]] = arith.constant -128 // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16 // CHECK: [[SUB:%.+]] = arith.subi [[C_128]], [[EXT]] @@ -895,14 +881,18 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () { // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]] // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]] // CHECK: linalg.yield [[TRUNC]] - %3 = tosa.negate %arg0 {input1_zp = -128 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8> + %in_zp3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> + %out_zp3 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %3 = tosa.negate %arg0, %in_zp3, %out_zp3 : (tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8> // CHECK: linalg.generic // CHECK: ^bb0(%[[BBARG0:.+]]: i8, // CHECK: [[ZERO:%.+]] = arith.constant 0 // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], // CHECK: linalg.yield [[SUB]] - %4 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant} : (tensor<1xi8>) -> tensor<1xi8> + %in_zp4 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %out_zp4 = 
"tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %4 = tosa.negate %arg0, %in_zp4, %out_zp4 : (tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8> return } diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir index 14c67e670e921..aa5fa00488f08 100644 --- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir +++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir @@ -77,7 +77,7 @@ func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf // CHECK-LABEL: func.func @arrow_structure // CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32> -func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) { +func.func @arrow_structure(%arg0: tensor<8x16xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) { // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG]] to %[[S1]] annotate_for_users : tensor<8x16xf32> // CHECK-NEXT: %[[V2:.*]] = tosa.tanh %[[V1]] @@ -85,12 +85,15 @@ func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to %[[S1]] annotate_for_users : tensor<8x16xf32> // CHECK-NEXT: %[[V5:.*]] = tosa.abs %[[V4]] - // CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[V5]] to %[[S1]] : tensor<8x16xf32> - %1 = tosa.abs %0 : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[V7:.*]] = tosa.negate %[[V4]] + // CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[V5]] to %[[S1]] : tensor<8x16xf32> + %1 = tosa.abs %0: (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding + // CHECK-NEXT: %[[ZP1:.*]] = mesh.shard %arg1 to %[[S3]] annotate_for_users : tensor<1xf32> + // CHECK-NEXT: %[[ZP2:.*]] = mesh.shard %arg2 to %[[S3]] 
annotate_for_users : tensor<1xf32> + // CHECK-NEXT: %[[V7:.*]] = tosa.negate %[[V4]], %[[ZP1]], %[[ZP2]] // CHECK-NEXT: %[[S8:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to %[[S8]] : tensor<8x16xf32> - %2 = tosa.negate %0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + %2 = tosa.negate %0, %arg1, %arg2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32> %s3 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding %3 = mesh.shard %2 to %s3 : tensor<8x16xf32> // CHECK-NEXT: return %[[V6]], %[[V8]] diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir index 59f7162e21013..5c9fd29444f04 100644 --- a/mlir/test/Dialect/Mesh/spmdization.mlir +++ b/mlir/test/Dialect/Mesh/spmdization.mlir @@ -176,7 +176,7 @@ func.func @multiple_chained_ops( %4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8> // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8> %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8> - // CHECK: %[[RESHARD3:.*]] = mesh.all_slice %[[ABS2]] on @mesh_1d mesh_axes = [0] slice_axis = 0 : + // CHECK: %[[RESHARD3:.*]] = mesh.all_slice %[[ABS2]] on @mesh_1d mesh_axes = [0] slice_axis = 0 : // CHECK-SAME: tensor<2xi8> -> tensor<1xi8> %s6 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding %6 = mesh.shard %5 to %s6 : tensor<2xi8> @@ -207,7 +207,11 @@ mesh.mesh @mesh_1d_4(shape = 4) // CHECK-LABEL: func @ew_chain_with_halo func.func @ew_chain_with_halo( // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32> - %arg0: tensor<8x16xf32>) + %arg0: tensor<8x16xf32>, + // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xf32> + %arg1: tensor<1xf32>, + // CHECK-SAME: %[[IN3:[A-Za-z0-9_]+]]: tensor<1xf32> + %arg2: tensor<1xf32>) // CHECK-SAME: -> tensor<5x16xf32> -> tensor<8x16xf32> { %ssharding_annotated = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding @@ -224,8 +228,11 
@@ func.func @ew_chain_with_halo( %sharding_annotated_2 = mesh.shard %1 to %ssharding_annotated_2 : tensor<8x16xf32> %ssharding_annotated_4 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding %sharding_annotated_4 = mesh.shard %sharding_annotated_2 to %ssharding_annotated_4 annotate_for_users : tensor<8x16xf32> - // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]] : (tensor<5x16xf32>) -> tensor<5x16xf32> - %2 = tosa.negate %sharding_annotated_4 : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]], %[[IN2]], %[[IN3]] : (tensor<5x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x16xf32> + %sharding_1 = mesh.sharding @mesh_1d_4 split_axes = [[]] : !mesh.sharding + %zero_point_1 = mesh.shard %arg1 to %sharding_1 annotate_for_users : tensor<1xf32> + %zero_point_2 = mesh.shard %arg2 to %sharding_1 annotate_for_users : tensor<1xf32> + %2 = tosa.negate %sharding_annotated_4, %zero_point_1, %zero_point_2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32> %ssharding_annotated_5 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding %sharding_annotated_5 = mesh.shard %2 to %ssharding_annotated_5 : tensor<8x16xf32> %ssharding_annotated_6 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir index b786264d84106..820ea6559b848 100644 --- a/mlir/test/Dialect/Tosa/availability.mlir +++ b/mlir/test/Dialect/Tosa/availability.mlir @@ -380,7 +380,9 @@ func.func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> { func.func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // CHECK: profiles: [ [pro_int, pro_fp] ] // CHECK: extensions: [ [bf16] ] - %0 = tosa.negate %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + 
%output_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = tosa.negate %arg0, %input_zp, %output_zp : (tensor<13x21x3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 4242f68609634..e2575c764fdfe 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -856,13 +856,54 @@ func.func @fold_exp_log(%arg0: tensor) -> tensor { // CHECK-LABEL: @fold_negate_negate func.func @fold_negate_negate(%arg0: tensor) -> tensor { // CHECK: return %arg{{.*}} : tensor - %0 = tosa.negate %arg0 : (tensor) -> tensor - %1 = tosa.negate %0 : (tensor) -> tensor + %in_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %out_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = tosa.negate %arg0, %in_zp, %out_zp : (tensor, tensor<1xf32>, tensor<1xf32>) -> tensor + %1 = tosa.negate %0, %in_zp, %out_zp : (tensor, tensor<1xf32>, tensor<1xf32>) -> tensor return %1 : tensor } // ----- +// CHECK-LABEL: @no_fold_negate_negate_non_const_zp +func.func @no_fold_negate_negate_non_const_zp(%arg0: tensor, %in_zp: tensor<1xf32>) -> tensor { + // cannot fold if any zp is not constant + // CHECK: tosa.negate + // CHECK: tosa.negate + // CHECK: tosa.negate + // CHECK: tosa.negate + // CHECK: tosa.negate + %zero = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = tosa.negate %arg0, %in_zp, %zero : (tensor, tensor<1xf32>, tensor<1xf32>) -> tensor + %1 = tosa.negate %0, %zero, %zero : (tensor, tensor<1xf32>, tensor<1xf32>) -> tensor + %2 = tosa.negate %1, %zero, %in_zp : (tensor, tensor<1xf32>, tensor<1xf32>) -> tensor + %3 = tosa.negate %2, %zero, %zero : (tensor, tensor<1xf32>, tensor<1xf32>) -> tensor + %4 = tosa.negate %3, %in_zp, %zero : (tensor, tensor<1xf32>, tensor<1xf32>) -> 
tensor + return %4 : tensor +} + +// ----- + +// CHECK-LABEL: @no_fold_negate_negate_non_zero_zp +func.func @no_fold_negate_negate_non_zero_zp(%arg0: tensor) -> tensor { + // cannot fold if any zp is not constant 0 + // CHECK: tosa.negate + // CHECK: tosa.negate + // CHECK: tosa.negate + // CHECK: tosa.negate + // CHECK: tosa.negate + %zero = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %one = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8> + %0 = tosa.negate %arg0, %zero, %one : (tensor, tensor<1xi8>, tensor<1xi8>) -> tensor + %1 = tosa.negate %0, %zero, %zero : (tensor, tensor<1xi8>, tensor<1xi8>) -> tensor + %2 = tosa.negate %1, %one, %zero : (tensor, tensor<1xi8>, tensor<1xi8>) -> tensor + %3 = tosa.negate %2, %zero, %zero : (tensor, tensor<1xi8>, tensor<1xi8>) -> tensor + %4 = tosa.negate %3, %zero, %one : (tensor, tensor<1xi8>, tensor<1xi8>) -> tensor + return %4 : tensor +} + +// ----- + // CHECK-LABEL: @fold_abs_abs func.func @fold_abs_abs(%arg0: tensor) -> tensor { // CHECK: %[[ABS:.*]] = tosa.abs %arg{{.*}} : (tensor) -> tensor diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index f536444f6379e..c4fe9c1a6cabc 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -1652,3 +1652,67 @@ func.func @test_matmul_b_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1 %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32> return %0 : tensor<1x14x28xf32> } + +// ----- + +// CHECK-LABEL: test_negate_same_element_type +func.func @test_negate_same_element_type(%arg0: tensor<1x16x16x8xf16>, %arg1: tensor<1xf16>, %arg2: tensor<1xf16>) -> tensor<1x16x16x8xf32> { + // expected-error@+1 {{'tosa.negate' op expect input and output to have same element type, got 'f16' and 'f32'}} + %0 = tosa.negate %arg0, %arg1, %arg2 + : (tensor<1x16x16x8xf16>, 
tensor<1xf16>, tensor<1xf16>) -> tensor<1x16x16x8xf32> + return %0 : tensor<1x16x16x8xf32> +} + +// ----- + +// CHECK-LABEL: test_negate_same_shape +func.func @test_negate_same_shape(%arg0: tensor<1x16x16x16xf16>, %arg1: tensor<1xf16>, %arg2: tensor<1xf16>) -> tensor<1x16x16x8xf16> { + // expected-error@+1 {{'tosa.negate' op requires the same shape for input1 and output}} + %0 = tosa.negate %arg0, %arg1, %arg2 + : (tensor<1x16x16x16xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x16x16x8xf16> + return %0 : tensor<1x16x16x8xf16> +} + +// ----- + +// CHECK-LABEL: test_negate_input_zp_same_element_type +func.func @test_negate_input_zp_same_element_type(%arg0: tensor<1x16x16x8xf16>, %arg1: tensor<1xi8>, %arg2: tensor<1xf16>) -> tensor<1x16x16x8xf16> { + // expected-error@+1 {{'tosa.negate' op expect both input1 and its zero point are the same element type, got 'f16' and 'i8'}} + %0 = tosa.negate %arg0, %arg1, %arg2 + : (tensor<1x16x16x8xf16>, tensor<1xi8>, tensor<1xf16>) -> tensor<1x16x16x8xf16> + return %0 : tensor<1x16x16x8xf16> +} + +// ----- + +// CHECK-LABEL: test_negate_output_zp_same_element_type +func.func @test_negate_output_zp_same_element_type(%arg0: tensor<1x16x16x8xi8>, %arg1: tensor<1xi8>, %arg2: tensor<1xf16>) -> tensor<1x16x16x8xi8> { + // expected-error@+1 {{'tosa.negate' op expect both output and its zero point are the same element type, got 'i8' and 'f16'}} + %0 = tosa.negate %arg0, %arg1, %arg2 + : (tensor<1x16x16x8xi8>, tensor<1xi8>, tensor<1xf16>) -> tensor<1x16x16x8xi8> + return %0 : tensor<1x16x16x8xi8> +} + +// ----- + +// CHECK-LABEL: test_negate_input_zp_non_zero +func.func @test_negate_input_zp_non_zero(%arg0: tensor<1x16x16x8xf32>) -> tensor<1x16x16x8xf32> { + %input_zp = "tosa.const"() {values = dense<-1.0> : tensor<1xf32>} : () -> tensor<1xf32> + %output_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + // expected-error@+1 {{'tosa.negate' op input1 zero point must be zero for non-int8 integer types}} + %0 
= tosa.negate %arg0, %input_zp, %output_zp + : (tensor<1x16x16x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x16x16x8xf32> + return %0 : tensor<1x16x16x8xf32> +} + +// ----- + +// CHECK-LABEL: test_negate_output_zp_non_zero +func.func @test_negate_output_zp_non_zero(%arg0: tensor<1x16x16x8xf32>) -> tensor<1x16x16x8xf32> { + %input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + %output_zp = "tosa.const"() {values = dense<-1.0> : tensor<1xf32>} : () -> tensor<1xf32> + // expected-error@+1 {{'tosa.negate' op output zero point must be zero for non-int8 integer types}} + %0 = tosa.negate %arg0, %input_zp, %output_zp + : (tensor<1x16x16x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x16x16x8xf32> + return %0 : tensor<1x16x16x8xf32> +} diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 6d8237635d0ec..2f7250dabe162 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -249,9 +249,9 @@ func.func @test_logical_not_rank_invalid(%arg0: tensor<1x1x1x1x1x21x3xi1>) -> te // ----- -func.func @test_negate_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> { +func.func @test_negate_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x1x1x13x21x3xf32> { // expected-error@+1 {{'tosa.negate' op failed level check: operand rank(shape) <= MAX_RANK}} - %0 = tosa.negate %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> + %0 = tosa.negate %arg0, %arg1, %arg1 : (tensor<1x1x1x1x13x21x3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x1x1x13x21x3xf32> return %0 : tensor<1x1x1x1x13x21x3xf32> } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 480d8c327ab83..916886025cb0e 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -487,8 +487,8 @@ func.func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> 
{ // ----- // CHECK-LABEL: negate -func.func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { - %0 = tosa.negate %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +func.func @test_negate(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir index 447a6ef7f9e05..f0ad4eb4fdb0b 100644 --- a/mlir/test/Dialect/Tosa/quant-test.mlir +++ b/mlir/test/Dialect/Tosa/quant-test.mlir @@ -2,9 +2,9 @@ // ----- // CHECK-LABEL: test_build_qtype -func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> { +func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> { // CHECK: tosa.negate - %0 = "tosa.negate"(%arg0) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> + %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> return %0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> } diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index deede4b0afadc..3d0ded8c58ac5 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -45,8 +45,10 @@ func.func @test_unary_f32(%arg0 : tensor<4xf32>) -> () { // CHECK: tosa.log %arg0 : (tensor<4xf32>) -> tensor<4xf32> %5 = tosa.log %arg0 : 
(tensor<4xf32>) -> tensor<*xf32> - // CHECK: tosa.negate %arg0 : (tensor<4xf32>) -> tensor<4xf32> - %6 = tosa.negate %arg0 : (tensor<4xf32>) -> tensor<*xf32> + %in_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %out_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + // CHECK: tosa.negate %arg0, {{.+}} : (tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<4xf32> + %6 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32> // CHECK: tosa.reciprocal %arg0 : (tensor<4xf32>) -> tensor<4xf32> %7 = tosa.reciprocal %arg0 : (tensor<4xf32>) -> tensor<*xf32> @@ -87,8 +89,10 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<2xi8>) -> () { // CHECK: tosa.clz %arg0 : (tensor<4xi32>) -> tensor<4xi32> %3 = tosa.clz %arg0 : (tensor<4xi32>) -> tensor<*xi32> - // CHECK: tosa.negate %arg0 : (tensor<4xi32>) -> tensor<4xi32> - %4 = tosa.negate %arg0 : (tensor<4xi32>) -> tensor<*xi32> + %in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> + %out_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> + // CHECK: tosa.negate %arg0, {{.+}} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + %4 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32> // CHECK: tosa.reverse %arg0 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<4xi32> %5 = tosa.reverse %arg0 { axis = 0 : i32 } : (tensor<4xi32>) -> tensor