diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index e4f5d09064cd..27061002b029 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_TOSA_IR_TOSAOPS_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
@@ -53,34 +54,43 @@ class MulOperandsAndResultElementType
     : public TraitBase<ConcreteType, MulOperandsAndResultElementType> {
 public:
   static LogicalResult verifyTrait(Operation *op) {
-    auto resElemType = getElementTypeOrSelf(op->getResult(0));
-
-    // In cases of floating point type, op requires the same element
-    // type for all operands and result.
-    if (llvm::isa<FloatType>(resElemType))
-      return impl::verifySameOperandsAndResultElementType(op);
-
+    // Check we have a single result.
+    if (failed(impl::verifyOneResult(op)))
+      return failure();
+    Type resElemType = getElementTypeOrSelf(op->getResult(0));
+
+    // Check we have lhs and rhs.
+    if (failed(impl::verifyAtLeastNOperands(op, 2)))
+      return failure();
+
+    Type lhsElemType = getElementTypeOrSelf(op->getOperand(0));
+    Type rhsElemType = getElementTypeOrSelf(op->getOperand(1));
+
+    // Check that for i32 a shift has been explicitly provided.
+    if (lhsElemType.isInteger(32) && failed(impl::verifyNOperands(op, 3)))
+      return failure();
+
+    // Verify that the operand types match (ignoring the shift parameter,
+    // which is always i8).
+    if (lhsElemType != rhsElemType)
+      return op->emitOpError("requires the same element type for all operands");
+
+    // Though the spec requires the element type of the result to be i32, a
+    // more relaxed rule is provided at the dialect level to ease
+    // interoperation with other dialects.
     if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
-      IntegerType lhsIntType =
-          cast<IntegerType>(getElementTypeOrSelf(op->getOperand(0)));
-      IntegerType rhsIntType =
-          cast<IntegerType>(getElementTypeOrSelf(op->getOperand(1)));
-      if (lhsIntType != rhsIntType)
-        return op->emitOpError(
-            "requires the same element type for all operands");
-
-      // Though the spec requires the element type of result to be i32, a more
-      // relaxed way is provided at dialect level for easier cooperating with
-      // other dialects.
+      auto lhsIntType = cast<IntegerType>(lhsElemType);
       if (lhsIntType.getWidth() > resIntType.getWidth())
        return op->emitOpError("invalid data type size for operands or result");
-
-      return success();
+    } else {
+      // In cases of floating-point or quantized types, the op requires the
+      // same element type for all operands and result (excluding shift).
+      if (resElemType != lhsElemType)
+        return op->emitOpError(
+            "requires the same element type for all operands and results");
     }
 
-    // In cases of all other types, op requires the same element
-    // type for all operands and result.
-    return impl::verifySameOperandsAndResultElementType(op);
+    return llvm::success();
   }
 };
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index fed20da33afd..4701301b7098 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -800,9 +800,11 @@ def MulOperandsAndResultElementType :
 //===----------------------------------------------------------------------===//
 // Operator: mul
 //===----------------------------------------------------------------------===//
-def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
+def Tosa_MulOp : Tosa_Op<"mul", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
     Commutative,
-    MulOperandsAndResultElementType]> {
+    Pure]> {
   let summary = "Multiplication operator";
 
   let description = [{
@@ -814,7 +816,8 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
   let arguments = (ins
     Tosa_Tensor:$input1,
     Tosa_Tensor:$input2,
-    I8Attr:$shift
+    // Apply right shift on i32_t input data only
+    Tosa_ScalarInt8Tensor:$shift
   );
 
   let results = (outs
@@ -823,6 +826,9 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
 
   let hasFolder = 1;
   let hasVerifier = 1;
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
 }
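With the TableGen change above, `shift` moves from an `I8Attr` to a third tensor operand, and the declared assembly format prints a functional type over all three operands. A minimal before/after sketch of the textual IR (illustrative values and names only):

    // Before this patch: shift carried as an attribute.
    %r = tosa.mul %a, %b {shift = 0 : i8} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

    // After: shift is a rank-1, size-1 i8 tensor operand, typically a tosa.const.
    %s = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
    %r = tosa.mul %a, %b, %s : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32>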
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 98f15fec3a95..318f0ecaf989 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -93,6 +93,10 @@ def HasNo0Dimensions : And<[
   IsRankedTensorTypePred,
   CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;
 
+def AllDimensionsAreSizeOne : And<[
+  IsRankedTensorTypePred,
+  CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v == 1; })">]>;
+
 // AMD: removed HasNo0Dimensions constraint below to allow lowerings
 // in onnx-mlir like onnx.Split.
 class TosaTensorOf<
@@ -111,6 +115,11 @@ class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
     [HasAnyRankOfPred<ranks>],
     !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
 
+class TosaScalarTensorOf<list<Type> allowedTypes, list<int> ranks>
+    : TosaRankedTensorOf<allowedTypes,
+                         [HasAnyRankOfPred<ranks>, AllDimensionsAreSizeOne],
+                         "tosa-conformant scalar tensor">;
+
 //===----------------------------------------------------------------------===//
 // Tensor types
 //===----------------------------------------------------------------------===//
@@ -139,8 +148,8 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
 // Tensor types with constrained ranks.
 //===----------------------------------------------------------------------===//
 
-// Rank-0 (scalar) tensor
 def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
+def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
 
 // We include unranked tensors as a supported type for all possible tosa
 // Tensors as unranked does not guarantee invalid. If unranked tensors exist
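The new `Tosa_ScalarInt8Tensor` constraint combines the rank-1 requirement with `AllDimensionsAreSizeOne`, so the only type it accepts for the shift operand is `tensor<1xi8>`. A small sketch of what passes and fails (the failing cases are exercised in the invalid.mlir tests further down):

    // Accepted: ranked, rank 1, every dimension == 1.
    %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
    // Rejected: tensor<2xi8> (a dimension != 1) and tensor<1x1xi8> (rank != 1)
    // both fail the "tosa-conformant scalar tensor" constraint.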
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index aaf743ab2800..8e6252597a91 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -100,43 +100,59 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   }
 
   // tosa::MulOp
-  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
-    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
-
-  if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
-    Value a = args[0];
-    Value b = args[1];
-    auto shift =
-        cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
-    if (shift > 0) {
-      auto shiftConst =
-          rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
-      if (!a.getType().isInteger(32))
-        a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
-
-      if (!b.getType().isInteger(32))
-        b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
-
-      auto result = rewriter.create<tosa::ApplyScaleOp>(
-          loc, rewriter.getI32Type(), a, b, shiftConst,
-          rewriter.getBoolAttr(false));
-
-      if (elementTy.isInteger(32))
-        return result;
-
-      return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
-    }
-
-    int aWidth = a.getType().getIntOrFloatBitWidth();
-    int bWidth = b.getType().getIntOrFloatBitWidth();
-    int cWidth = resultTypes[0].getIntOrFloatBitWidth();
-
-    if (aWidth < cWidth)
-      a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
-    if (bWidth < cWidth)
-      b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
-
-    return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
+  if (isa<tosa::MulOp>(op)) {
+    auto shift_val = cast<tosa::MulOp>(op).getShift();
+    ElementsAttr shift_elem;
+    if (!shift_val.getImpl() ||
+        !matchPattern(shift_val, m_Constant(&shift_elem))) {
+      (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
+      return nullptr;
+    }
+
+    int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+
+    if (isa<FloatType>(elementTy)) {
+      if (shift != 0) {
+        (void)rewriter.notifyMatchFailure(op,
+                                          "Cannot have shift value for float");
+        return nullptr;
+      }
+      return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
+    }
+
+    if (isa<IntegerType>(elementTy)) {
+      Value a = args[0];
+      Value b = args[1];
+
+      if (shift > 0) {
+        auto shiftConst =
+            rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
+        if (!a.getType().isInteger(32))
+          a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
+
+        if (!b.getType().isInteger(32))
+          b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
+
+        auto result = rewriter.create<tosa::ApplyScaleOp>(
+            loc, rewriter.getI32Type(), a, b, shiftConst,
+            rewriter.getBoolAttr(false));
+
+        if (elementTy.isInteger(32))
+          return result;
+
+        return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
+      }
+
+      int aWidth = a.getType().getIntOrFloatBitWidth();
+      int bWidth = b.getType().getIntOrFloatBitWidth();
+      int cWidth = resultTypes[0].getIntOrFloatBitWidth();
+
+      if (aWidth < cWidth)
+        a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
+      if (bWidth < cWidth)
+        b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
+
+      return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
+    }
   }
 
   // tosa::NegateOp
@@ -990,7 +1006,13 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
 
   auto loc = operation->getLoc();
   auto rank =
       cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
-  auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
+  // For the mul op we need to avoid expanding the rank of the optional shift
+  // input.
+  auto operandsToExpand =
+      isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;
+
+  auto expandedOperands =
+      expandInputRanks(rewriter, loc, operandsToExpand, rank);
   auto [targetShape, masterOperands] =
       computeTargetShape(rewriter, loc, indexPool, expandedOperands);
   auto broadcastOperands = broadcastDynamicDimensions(
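For i32 element types with a constant non-zero shift, the conversion above still routes the multiply through `tosa.apply_scale` inside the generated `linalg.generic` body; only the way the shift is obtained changes (constant-matched from the operand instead of read from an attribute). A rough sketch of the scalar body for shift = 2, with illustrative value names:

    // Inside the linalg.generic region (%a, %b : i32):
    %c2_i8 = arith.constant 2 : i8
    %prod  = tosa.apply_scale %a, %b, %c2_i8 {double_round = false} : (i32, i32, i8) -> i32

Because the shift now arrives as an SSA value, the lowering bails out with `notifyMatchFailure` whenever the shift operand is not a compile-time constant.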
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index ddc02993892b..07962ca0467e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1230,7 +1230,18 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
   auto rhsAttr =
       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
-  const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
+  // Result right shift applies to the i32_t data type only. For
+  // simplification, synthesize a zero shift for other data types.
+  int32_t shift = 0;
+  if (resultETy.isInteger(32)) {
+    ElementsAttr shift_elem;
+    if (getShift().getImpl()) {
+      if (!matchPattern(getShift(), m_Constant(&shift_elem)))
+        // Cannot be folded when the shift value is unknown.
+        return {};
+      shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+    }
+  }
 
   if (rhsTy == resultTy) {
     if (isSplatZero(resultETy, lhsAttr))
@@ -1245,7 +1256,7 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
       return lhs;
   }
 
-  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
+  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
 }
 
 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
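A consequence of the fold change: for i32 results the folder first needs to see a constant shift, so a shift that is only known at runtime blocks `mulBinaryFolder` entirely. A hedged example (hypothetical function, not part of the test suite):

    func.func @no_fold_runtime_shift(%a: tensor<4xi32>, %s: tensor<1xi8>) -> tensor<4xi32> {
      %ones = "tosa.const"() {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
      // %s is a block argument, not a constant: fold() returns {} and the
      // multiply-by-one is left in place.
      %r = tosa.mul %a, %ones, %s : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
      return %r : tensor<4xi32>
    }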
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 74c3fc990920..1f741dd37b00 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -974,10 +974,116 @@ LogicalResult tosa::SliceOp::verify() {
   return success();
 }
 
+LogicalResult tosa::MulOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes,
+    OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  // The mul op's output shape depends only on input1 and input2, not on shift.
+  ValueShapeRange twoInputs = operands.drop_back();
+  llvm::SmallVector<int64_t> outShape;
+  if (resolveBroadcastShape(twoInputs, outShape).failed()) {
+    inferredReturnShapes.push_back(ShapedTypeComponents());
+  } else {
+    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+  }
+  return success();
+}
+
 LogicalResult tosa::MulOp::verify() {
-  Type elementTy = getInput1().getType().getElementType();
-  if (isa<FloatType>(elementTy) && getShift() != 0)
-    return emitOpError() << "require shift to be 0 for float type";
+  auto resElemType = getElementTypeOrSelf(getOutput());
+
+  // Verify that the element types of operands and result match the tosa
+  // specification.
+  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
+    IntegerType lhsIntType =
+        cast<IntegerType>(getElementTypeOrSelf(getInput1()));
+    IntegerType rhsIntType =
+        cast<IntegerType>(getElementTypeOrSelf(getInput2()));
+    if (lhsIntType != rhsIntType)
+      return emitOpError("requires the same element type for all operands");
+
+    // Though the spec requires the element type of result to be i32, a more
+    // relaxed rule is provided at the dialect level to ease interoperation
+    // with other dialects.
+    if (lhsIntType.getWidth() > resIntType.getWidth())
+      return emitOpError("invalid data type size for operands or result");
+
+  } else {
+    // For the other supported types, the spec requires the same element
+    // type for all operands (excluding the `shift` operand) and results.
+    for (int i = 0; i < 2; ++i) {
+      if (getElementTypeOrSelf(getOperand(i)) != resElemType)
+        return emitOpError(
+            "requires the same element type for all operands and results");
+    }
+
+    // Verify shift has value 0 for non-integer types.
+    ElementsAttr shift_elem;
+    if (matchPattern(getShift(), m_Constant(&shift_elem))) {
+      int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+      if (shift != 0) {
+        return emitOpError() << "require shift to be 0 for float type";
+      }
+    }
+  }
+
+  // Verify the op has the same rank for all main operands (excluding extra
+  // operands such as mul's shift; this is the only difference from the
+  // built-in `SameOperandsAndResultRank` trait) and result types, if known.
+
+  // Delegate function that returns true if type is a shaped type with known
+  // rank.
+  auto hasRank = [](const Type type) {
+    if (auto shaped_type = dyn_cast<ShapedType>(type))
+      return shaped_type.hasRank();
+
+    return false;
+  };
+
+  auto rankedOperandTypes =
+      llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
+
+  auto rankedResultTypes =
+      llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
+
+  // If all operands and results are unranked, then no further verification.
+  if (rankedOperandTypes.empty() && rankedResultTypes.empty())
+    return success();
+
+  // Delegate function that returns the rank of a shaped type with known rank.
+  auto getRank = [](const Type type) {
+    return cast<ShapedType>(type).getRank();
+  };
+
+  auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
+                                          : getRank(*rankedResultTypes.begin());
+
+  for (size_t i = 0; i < 2; ++i) {
+    if (rank != getRank(rankedOperandTypes[i])) {
+      return emitOpError("operands don't have matching ranks");
+    }
+  }
+
+  for (const auto type : rankedResultTypes) {
+    if (rank != getRank(type)) {
+      return emitOpError("result type has different rank than operands");
+    }
+  }
+
+  // Check for broadcast-compatible shapes in the first two operands (ignoring
+  // shift).
+
+  // Delegate function that returns the shape of a shaped type.
+  auto getShape = [](const Type type) {
+    return mlir::cast<ShapedType>(type).getShape();
+  };
+  SmallVector<int64_t> resultShape;
+  if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]),
+                                                getShape(rankedOperandTypes[1]),
+                                                resultShape)) {
+    return emitOpError("operands don't have broadcast-compatible shapes");
+  }
 
   return success();
 }
@@ -1636,7 +1742,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
 NARY_SHAPE_INFER(tosa::LogicalXorOp)
 NARY_SHAPE_INFER(tosa::MaximumOp)
 NARY_SHAPE_INFER(tosa::MinimumOp)
-NARY_SHAPE_INFER(tosa::MulOp)
 NARY_SHAPE_INFER(tosa::NegateOp)
 NARY_SHAPE_INFER(tosa::PowOp)
 NARY_SHAPE_INFER(tosa::ReciprocalOp)
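Since `inferReturnTypeComponents` drops the trailing shift operand before resolving the broadcast, inference is driven purely by input1/input2 and `tensor<1xi8>` never leaks into the result shape. A sketch of the refinement the shape-inference pass now performs (illustrative shapes):

    // Before inference:
    %r = tosa.mul %x, %y, %s : (tensor<2x1xf32>, tensor<1x3xf32>, tensor<1xi8>) -> tensor<*xf32>
    // After inference (broadcast of 2x1 and 1x3):
    %r = tosa.mul %x, %y, %s : (tensor<2x1xf32>, tensor<1x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>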
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 45f4419875b4..181aff3a9ce0 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Pass/Pass.h"
 
 using namespace mlir;
@@ -131,9 +132,15 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
       return failure();
     }
 
+    auto shiftElementType = IntegerType::get(rewriter.getContext(), 8);
+    auto shiftType = RankedTensorType::get({1}, shiftElementType);
+    auto shiftZeroAttr = DenseElementsAttr::get(
+        shiftType, rewriter.getIntegerAttr(shiftElementType, 0));
+    Value constZero =
+        rewriter.create<tosa::ConstOp>(op.getLoc(), shiftType, shiftZeroAttr);
     Value mulValue = rewriter
                          .create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
-                                              weight, /*shift=*/0)
+                                              weight, constZero)
                          .getResult();
 
     // Reshape output to [N, H, W, C * M].
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 962c89902136..e17f55761c55 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -355,7 +355,8 @@ struct TosaFoldConstantBase : public OpRewritePattern<BaseType> {
                     DenseElementsAttr valuesSecond) const {
     if (!foldSplatOrSingleUseOnly)
       return true;
-    assert(binaryOp->getNumOperands() == 2);
+    assert(binaryOp->getNumOperands() >= 2 &&
+           "binary folding expects at least two operands");
    auto firstOp = binaryOp->getOperand(0);
    auto secondOp = binaryOp->getOperand(1);
 
@@ -750,10 +751,19 @@ struct TosaFoldConstantMul
   DenseElementsAttr computeInteger(DenseElementsAttr lhsValues,
                                    DenseElementsAttr rhsValues,
                                    PatternRewriter &rewriter, MulOp op) const {
-    if (op.getShift() > 0) {
-      (void)rewriter.notifyMatchFailure(
-          op, "Non-zero shift folding is currently not implemented.");
-      return {};
+    if (Value shiftVal = op.getShift()) {
+      ElementsAttr shiftAttr;
+      if (!matchPattern(shiftVal, m_Constant(&shiftAttr))) {
+        (void)rewriter.notifyMatchFailure(
+            op, "shift must be a constant for folding.");
+        return {};
+      }
+      if (llvm::any_of(shiftAttr.getValues<IntegerAttr>(),
+                       [](IntegerAttr attr) { return attr.getInt() != 0; })) {
+        (void)rewriter.notifyMatchFailure(
+            op, "Non-zero shift folding is currently not implemented.");
+        return {};
+      }
     }
 
     auto resultElementWidth =
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index 2a990eed3f68..79afc75fd6c8 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -113,7 +113,7 @@ struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
 
     Value input1 = tosaBinaryOp.getInput1();
     Value input2 = tosaBinaryOp.getInput2();
-    int32_t shift = tosaBinaryOp.getShift();
+    Value shift = tosaBinaryOp.getShift();
     Value output = tosaBinaryOp.getResult();
 
     auto outputType = dyn_cast<RankedTensorType>(output.getType());
     if (!outputType)
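With `DepthwiseConv2DIsMul` updated above, the decomposition must materialize its own zero-shift constant before building the `tosa.mul`. Roughly, using the shapes from the padded case in the tosa-decompose-depthwise.mlir tests further down (value names hypothetical):

    %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
    %mul = tosa.mul %input_reshaped, %weight_reshaped, %shift
        : (tensor<4x12x12x2x1xf32>, tensor<1x1x1x2x3xf32>, tensor<1xi8>) -> tensor<4x12x12x2x3xf32>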
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index 539f3e833b12..4c312ffd124e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -187,7 +187,7 @@ TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
 
   // Asserted by TransposeOp verifier and TOSA disallowing tensor with dimension
   // 0. If not in place, something is very wrong.
-  if (rank <= 0 || oldType.getNumElements() <= 0 || perms.size() != rank) {
+  if (rank <= 0 || oldType.getNumElements() <= 0) {
     signalPassFailure();
     return std::nullopt;
   }
@@ -281,13 +281,19 @@ bool TosaReduceTransposes::collectFanIn(Operation *op,
   if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
       !llvm::isa<tosa::ConstOp>(op)) {
 
-    if (!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
+    if (!llvm::isa<tosa::MulOp>(op) &&
+        !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
       return false;
 
-    for (Value operand : op->getOperands())
+    for (Value operand : op->getOperands()) {
       // If this is a problem in future, think about alternatives to recursion.
+      if (llvm::isa<tosa::MulOp>(op) && operand == op->getOperand(2)) {
+        // Do not recurse into MulOp's shift operand.
+        continue;
+      }
       if (!collectFanIn(operand.getDefiningOp(), collected))
         return false;
+    }
   }
 
   // Insert in topological order.
@@ -316,7 +322,8 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
     Operation *op, const DenseMap<Value, Value> &valuesMap,
     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
   if (op->getNumResults() != 1 ||
-      !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
+      (!llvm::isa<tosa::MulOp>(op) &&
+       !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>()))
     return std::nullopt;
 
   auto resultType = llvm::cast<RankedTensorType>(op->getResult(0).getType());
   SmallVector<Value, 3> operands;
   for (Value v : op->getOperands()) {
     if (valuesMap.contains(v)) {
       operands.push_back(valuesMap.at(v));
+    } else if (llvm::isa<tosa::MulOp>(op) && v == op->getOperand(2)) {
+      // Special case for MulOp's shift operand.
+      operands.push_back(v);
     } else {
       return std::nullopt;
     }
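In the transpose-hoisting pass, `tosa.mul` now needs special-casing precisely because its shift operand is not elementwise data: fan-in collection skips operand #2 and `buildMappedToValue` forwards it unchanged, so hoisted permutations are applied to input1/input2 only. The effect is visible in @test_mulop_conversion below, which reduces to a sketch like:

    // Both transposes cancel; %shift flows through untouched.
    %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
    %res = tosa.mul %arg0, %arg1, %shift
        : (tensor<1x2x3x4xi32>, tensor<1x2x3x4xi32>, tensor<1xi8>) -> tensor<1x2x3x4xi32>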
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index aaa4c0430fba..98ce3ade5978 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -492,7 +492,8 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
 
   // CHECK: linalg.generic
   // CHECK: arith.mulf
-  %4 = tosa.mul %0, %1 {shift = 0 : i8} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %4 = tosa.mul %0, %1, %shift : (tensor<1xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: arith.negf
@@ -658,7 +659,8 @@ func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
   // CHECK: arith.extsi
   // CHECK: arith.extsi
   // CHECK: arith.muli
-  %0 = tosa.mul %arg0, %arg0 {shift = 0 : i8} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.mul %arg0, %arg0, %shift : (tensor<1xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<1xi32>
 
   return
 }
@@ -691,12 +693,14 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
 
   // CHECK: linalg.generic
   // CHECK: arith.muli
-  %2 = tosa.mul %arg0, %arg0 {shift = 0 : i8} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %shift1 = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %2 = tosa.mul %arg0, %arg0, %shift1 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: arith.constant 2
   // CHECK: apply_scale
-  %3 = tosa.mul %arg0, %arg0 {shift = 2 : i8} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %shift2 = "tosa.const"() <{value = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %3 = tosa.mul %arg0, %arg0, %shift2 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: arith.divsi
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 66c5904004c1..f8a09409226f 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -543,8 +543,9 @@ func.func @no_pad_pad_optimization_different_value(%arg0: tensor<1x478x640x32xbf
 func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: return %arg0
   // CHECK-NOT: tosa.mul
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
   %ones = "tosa.const"() {value = dense<1.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
-  %1 = tosa.mul %arg0, %ones {shift = 0 : i8} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
   return %1 : tensor<2x3xf32>
 }
 
@@ -555,7 +556,8 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: return %arg0
   // CHECK-NOT: tosa.mul
   %ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
-  %1 = tosa.mul %ones, %arg0 {shift = 0 : i8} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = tosa.mul %ones, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
   return %1 : tensor<2x3xf32>
 }
 
@@ -565,8 +567,22 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
 func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   // CHECK: return %arg0
   // CHECK-NOT: tosa.mul
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
   %ones = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
-  %1 = tosa.mul %arg0, %ones {shift = 0 : i8} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  %1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
+  return %1 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @mul_one_int_and_shift
+func.func @mul_one_int_and_shift(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x3xi32>}>
+  // CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<31> : tensor<1xi8>}>
+  // CHECK: %[[VAL_3:.*]] = tosa.mul %arg0, %[[VAL_1]], %[[VAL_2]] : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>)
+  %ones = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+  %shift = "tosa.const"() <{value = dense<31> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
   return %1 : tensor<2x3xi32>
 }
 
@@ -577,11 +593,12 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
   // CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}
   // CHECK-NOT: tosa.mul
   %zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
-  %1 = tosa.mul %arg0, %zeros {shift = 0 : i8} : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = tosa.mul %arg0, %zeros, %shift : (tensor<2x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x3xf32>
 
   // CHECK-NOT: tosa.mul
   // CHECK: return %[[ZERO]], %[[ZERO]]
-  %2 = tosa.mul %zeros, %arg0 {shift = 0 : i8} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %2 = tosa.mul %zeros, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
 
   return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
 }
 
 func.func @mul_zero_broadcast_dynamic_result(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
   // CHECK: tosa.mul
   // CHECK: tosa.mul
   %zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
-  %1 = tosa.mul %arg0, %zeros {shift = 0 : i8} : (tensor<?x?xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
-  %2 = tosa.mul %zeros, %arg0 {shift = 0 : i8} : (tensor<1x1xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = tosa.mul %arg0, %zeros, %shift : (tensor<?x?xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x?xf32>
+  %2 = tosa.mul %zeros, %arg0, %shift : (tensor<1x1xf32>, tensor<?x?xf32>, tensor<1xi8>) -> tensor<?x?xf32>
   return %1, %2 : tensor<?x?xf32>, tensor<?x?xf32>
 }
 
@@ -1437,7 +1455,8 @@ func.func @mul_quant_nofold() -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899
   // CHECK: tosa.mul
   %0 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5>>
   %1 = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5>>
-  %2 = tosa.mul %0, %1 { shift = 0 : i8} : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5>>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5>>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %2 = tosa.mul %0, %1, %shift : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5>>, tensor<1xi8>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5>>
   return %2 : tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5>>
 }
 
@@ -1563,7 +1582,8 @@ func.func @canonicalize_select_lrelu_zero_pattern(%arg0: tensor<13x21x3xf32>) ->
   // CHECK: %[[VAL_1:.*]] = tosa.clamp %arg{{.*}} {max_fp = 0x7F800000 : f32, max_int = 9223372036854775807 : i64, min_fp = 0.000000e+00 : f32, min_int = -9223372036854775808 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
   // CHECK: return %[[VAL_1]] : tensor<13x21x3xf32>
   %0 = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}>: () -> tensor<1x1x1xf32>
-  %1 = tosa.mul %arg0, %0 {shift = 0 : i8}: (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = tosa.mul %arg0, %0, %shift : (tensor<13x21x3xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
   %2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xi1>
   %3 = tosa.select %2, %arg0, %1: ( tensor<13x21x3xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
   return %3 : tensor<13x21x3xf32>
@@ -1624,4 +1644,3 @@ func.func @canonicalize_select_to_clamp_i8_and_i64_pat2(%arg0: tensor<13x21x3xi8
   %3 = tosa.select %2, %1, %arg1: ( tensor<13x21x3xi1>, tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
   return %3 : tensor<13x21x3xi64>
 }
-
diff --git a/mlir/test/Dialect/Tosa/constant-mul-opt.mlir b/mlir/test/Dialect/Tosa/constant-mul-opt.mlir
index d7a1f0379bab..21e6849a9fd3 100644
--- a/mlir/test/Dialect/Tosa/constant-mul-opt.mlir
+++ b/mlir/test/Dialect/Tosa/constant-mul-opt.mlir
@@ -1,3 +1,5 @@
+// Modifications (c) Copyright 2023-2025 Advanced Micro Devices, Inc. or its
+// affiliates
 // RUN: mlir-opt --split-input-file -verify-diagnostics --tosa-layerwise-constant-fold %s | FileCheck %s
 
 // Float multiplications
@@ -15,7 +17,8 @@ func.func @mul_fold_float() -> tensor<4xf16> {
     dense<[-132.7, -3.0, -0.0, 5.0]> : tensor<4xf16>
   } : () -> tensor<4xf16>
-  %2 = "tosa.mul"(%0, %1) {shift = 0 : i8} : (tensor<4xf16>, tensor<4xf16>) -> tensor<4xf16>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %2 = "tosa.mul"(%0, %1, %shift) : (tensor<4xf16>, tensor<4xf16>, tensor<1xi8>) -> tensor<4xf16>
   return %2 : tensor<4xf16>
 }
 
@@ -32,7 +35,8 @@ func.func @mul_fold_float_infinity_nan() -> tensor<7xf32> {
     dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0xFF800000, 0.0]> : tensor<7xf32>
   } : () -> tensor<7xf32>
-  %2 = "tosa.mul"(%0, %1) {shift = 0 : i8} : (tensor<7xf32>, tensor<7xf32>) -> tensor<7xf32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %2 = "tosa.mul"(%0, %1, %shift) : (tensor<7xf32>, tensor<7xf32>, tensor<1xi8>) -> tensor<7xf32>
   return %2 : tensor<7xf32>
 }
 
@@ -49,7 +53,8 @@ func.func @add_fold_float_overflow() -> tensor<2xf32> {
     dense<[2.1e+38, 1.1e+38]> : tensor<2xf32>
   } : () -> tensor<2xf32>
-  %2 = "tosa.mul"(%0, %1) {shift = 0 : i8} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %2 = "tosa.mul"(%0, %1, %shift) : (tensor<2xf32>, tensor<2xf32>, tensor<1xi8>) -> tensor<2xf32>
   return %2 : tensor<2xf32>
 }
 
@@ -69,7 +74,8 @@ func.func @mul_fold_int() -> tensor<4xi32> {
     dense<[-132, -3, 0, 5]> : tensor<4xi32>
   } : () -> tensor<4xi32>
-  %2 = "tosa.mul"(%0, %1) {shift = 0 : i8} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %2 = "tosa.mul"(%0, %1, %shift) : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
   return %2 : tensor<4xi32>
 }
 
@@ -87,10 +93,12 @@ func.func @mul_fold_i8() -> tensor<4xi32> {
     tensor<4xi8>
   } : () -> tensor<4xi8>
   // TODO: This is wrongly rejected as illegal, see https://reviews.llvm.org/D150472#4484478
-  // %2 = "tosa.mul"(%0, %1) {shift = 0 : i8} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi32>
+  // %zero_shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // %2 = "tosa.mul"(%0, %1, %zero_shift) : (tensor<4xi8>, tensor<4xi8>, tensor<1xi8>) -> tensor<4xi32>
   %a = "tosa.cast"(%0) : (tensor<4xi8>) -> tensor<4xi32>
   %b = "tosa.cast"(%1) : (tensor<4xi8>) -> tensor<4xi32>
-  %2 = "tosa.mul"(%a, %b) {shift = 0 : i8} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %2 = "tosa.mul"(%a, %b, %shift) : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
   return %2 : tensor<4xi32>
 }
 
@@ -110,8 +118,9 @@ func.func @mul_fold_int_overflow() -> tensor<4xi32> {
     dense<[1, 10, 1, 30]> : tensor<4xi32>
   } : () -> tensor<4xi32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
   // expected-warning@below {{Multiplication did overflow. The results are unspecified.}}
-  %2 = "tosa.mul"(%0, %1) {shift = 0 : i8} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %2 = "tosa.mul"(%0, %1, %shift) : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
   return %2 : tensor<4xi32>
 }
 
@@ -127,7 +136,8 @@ func.func @mul_fold_equal_args() -> tensor<3xi32> {
     dense<[-17, 4, 0]> : tensor<3xi32>
   } : () -> tensor<3xi32>
-  %2 = "tosa.mul"(%0, %0) {shift = 0 : i8} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %2 = "tosa.mul"(%0, %0, %shift) : (tensor<3xi32>, tensor<3xi32>, tensor<1xi8>) -> tensor<3xi32>
   return %2 : tensor<3xi32>
 }
 
@@ -147,7 +157,8 @@ func.func @mul_fold_int_broadcast_simple() -> tensor<3xi32> {
     dense<-12> : tensor<1xi32>
   } : () -> tensor<1xi32>
-  %2 = "tosa.mul"(%0, %1) {shift = 0 : i8} : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %2 = "tosa.mul"(%0, %1, %shift) : (tensor<3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<3xi32>
   return %2 : tensor<3xi32>
 }
 
@@ -167,15 +178,17 @@ func.func @mul_fold_int_broadcast_complex() -> tensor<3x3xi32> {
     dense<[[-12, 7, 4]]> : tensor<1x3xi32>
   } : () -> tensor<1x3xi32>
-  %2 = "tosa.mul"(%0, %1) {shift = 0 : i8} : (tensor<3x1xi32>, tensor<1x3xi32>) -> tensor<3x3xi32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %2 = "tosa.mul"(%0, %1, %shift) : (tensor<3x1xi32>, tensor<1x3xi32>, tensor<1xi8>) -> tensor<3x3xi32>
   return %2 : tensor<3x3xi32>
 }
 
 // CHECK-LABEL: @mul_fold_int_non_zero_shift
 func.func @mul_fold_int_non_zero_shift() -> tensor<4xi32> {
-  // CHECK: [[FIRST:]] ={{.*}}tosa.const
-  // CHECK-NEXT: [[SECOND:]] ={{.*}}tosa.const
-  // CHECK-NEXT: [[MUL:]] ={{.*}}tosa.mul{{.*}}[[FIRST]], [[SECOND]]
+  // CHECK: [[FIRST:%.*]] ={{.*}}tosa.const
+  // CHECK-NEXT: [[SECOND:%.*]] ={{.*}}tosa.const
+  // CHECK-NEXT: [[SHIFT:%.*]] ={{.*}}tosa.const
+  // CHECK-NEXT: [[MUL:%.*]] ={{.*}}tosa.mul{{.*}}[[FIRST]], [[SECOND]], [[SHIFT]]
   // CHECK-NEXT: return [[MUL]]
 
   %0 = "tosa.const"() {value = dense<[-17, 4, 0, 0]> :
     tensor<4xi32>
   } : () -> tensor<4xi32>
   %1 = "tosa.const"() {value =
     dense<[-132, -3, 0, 5]> : tensor<4xi32>
   } : () -> tensor<4xi32>
-  %2 = "tosa.mul"(%0, %1) {shift = 1 : i8} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %shift = "tosa.const"() <{value = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %2 = "tosa.mul"(%0, %1, %shift) : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
   return %2 : tensor<4xi32>
 }
"tosa.const"() {value = dense<0.0> : tensor} : () -> tensor // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> - %mul = tosa.mul %zero, %arg0 {shift = 0 : i8} : (tensor, tensor) -> tensor + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %zero, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor // CHECK: return %[[ZERO]] return %mul : tensor } @@ -293,8 +295,9 @@ func.func @fold_mul_zero_lhs_f32(%arg0: tensor) -> tensor { // CHECK-LABEL: @fold_mul_zero_rhs_i32 func.func @fold_mul_zero_rhs_i32(%arg0: tensor) -> tensor { %zero = "tosa.const"() {value = dense<0> : tensor} : () -> tensor + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> - %mul = tosa.mul %arg0, %zero {shift = 0 : i8} : (tensor, tensor) -> tensor + %mul = tosa.mul %arg0, %zero, %shift : (tensor, tensor, tensor<1xi8>) -> tensor // CHECK: return %[[ZERO]] return %mul : tensor } @@ -304,8 +307,9 @@ func.func @fold_mul_zero_rhs_i32(%arg0: tensor) -> tensor { // CHECK-LABEL: @fold_mul_zero_lhs_i32 func.func @fold_mul_zero_lhs_i32(%arg0: tensor) -> tensor { %zero = "tosa.const"() {value = dense<0> : tensor} : () -> tensor + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> - %mul = tosa.mul %zero, %arg0 {shift = 0 : i8} : (tensor, tensor) -> tensor + %mul = tosa.mul %zero, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor // CHECK: return %[[ZERO]] return %mul : tensor } @@ -315,7 +319,8 @@ func.func @fold_mul_zero_lhs_i32(%arg0: tensor) -> tensor { // CHECK-LABEL: @fold_mul_one_rhs_f32 func.func @fold_mul_one_rhs_f32(%arg0: tensor) -> tensor { %one = "tosa.const"() {value = dense<1.0> : tensor} : () -> tensor - %mul = tosa.mul %arg0, %one {shift = 0 : i8} : (tensor, tensor) -> tensor + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %arg0, %one, %shift : (tensor, tensor, tensor<1xi8>) -> tensor // CHECK: return %arg0 return %mul : tensor } @@ -325,7 +330,8 @@ func.func @fold_mul_one_rhs_f32(%arg0: tensor) -> tensor { // CHECK-LABEL: @fold_mul_one_lhs_f32 func.func @fold_mul_one_lhs_f32(%arg0: tensor) -> tensor { %one = "tosa.const"() {value = dense<1.0> : tensor} : () -> tensor - %mul = tosa.mul %one, %arg0 {shift = 0 : i8} : (tensor, tensor) -> tensor + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %one, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor // CHECK: return %arg0 return %mul : tensor } @@ -335,7 +341,8 @@ func.func @fold_mul_one_lhs_f32(%arg0: tensor) -> tensor { // CHECK-LABEL: @fold_mul_one_rhs_i32 func.func @fold_mul_one_rhs_i32(%arg0: tensor) -> tensor { %one = "tosa.const"() {value = dense<64> : tensor} : () -> tensor - %mul = tosa.mul %arg0, %one {shift = 6 : i8} : (tensor, tensor) -> tensor + %shift = "tosa.const"() {value = dense<6> : tensor<1xi8>} : () -> tensor<1xi8> + %mul = tosa.mul %arg0, %one, %shift : (tensor, tensor, tensor<1xi8>) -> tensor // CHECK: return %arg0 return %mul : tensor } @@ -345,7 +352,8 @@ func.func @fold_mul_one_rhs_i32(%arg0: tensor) -> tensor { // CHECK-LABEL: @fold_mul_one_lhs_i32 func.func @fold_mul_one_lhs_i32(%arg0: tensor) -> tensor { %one = "tosa.const"() {value = dense<64> : tensor} : () -> tensor - %mul = tosa.mul %one, %arg0 {shift = 6 : i8} : (tensor, tensor) -> tensor + %shift = "tosa.const"() {value = 
dense<6> : tensor<1xi8>} : () -> tensor<1xi8> + %mul = tosa.mul %one, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor // CHECK: return %arg0 return %mul : tensor } @@ -356,7 +364,8 @@ func.func @fold_mul_one_lhs_i32(%arg0: tensor) -> tensor { func.func @fold_mul_splat_i8() -> tensor<10xi32> { %one = "tosa.const"() {value = dense<17> : tensor<10xi8>} : () -> tensor<10xi8> %two = "tosa.const"() {value = dense<32> : tensor<10xi8>} : () -> tensor<10xi8> - %mul = tosa.mul %one, %two {shift = 3 : i8} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi32> + %shift = "tosa.const"() {value = dense<3> : tensor<1xi8>} : () -> tensor<1xi8> + %mul = tosa.mul %one, %two, %shift : (tensor<10xi8>, tensor<10xi8>, tensor<1xi8>) -> tensor<10xi32> // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi32>} // CHECK: return %[[THREE]] return %mul : tensor<10xi32> @@ -368,7 +377,8 @@ func.func @fold_mul_splat_i8() -> tensor<10xi32> { func.func @fold_mul_splat_f32() -> tensor<10xf32> { %one = "tosa.const"() {value = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32> %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> - %mul = tosa.mul %one, %two {shift = 0 : i8} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %one, %two, %shift : (tensor<10xf32>, tensor<10xf32>, tensor<1xi8>) -> tensor<10xf32> // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<6.000000e+00> : tensor<10xf32>} // CHECK: return %[[THREE]] return %mul : tensor<10xf32> diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index b1ca64447051..ba9bdd34e2b3 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -225,7 +225,7 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor) -> t // ----- func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) { - %padding = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> + %padding = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> // expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank.}} %1 = tosa.pad %arg0, %padding : (tensor<13x21xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32> return @@ -253,7 +253,7 @@ func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tenso // ----- func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { - %0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> + %0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}} %1 = tosa.pad %arg0, %0 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32> return %1 : tensor<13x21x3xf32> @@ -772,15 +772,35 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>, // ----- +// CHECK-LABEL: test_mul_type_mismatch +func.func @test_mul_type_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf16>) -> tensor<13x21x3xf32> { + %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + // expected-error@+1 {{'tosa.mul' op requires the same element type for all operands}} + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf16>, tensor<1xi8>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- + 
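The `fold_mul_one_*_i32` cases above encode an identity in the shift arithmetic rather than a literal one: with shift s, an i32 `tosa.mul` computes (a * b) >> s, so multiplying by 64 = 2^6 with shift 6 returns the input bit-for-bit (the shifted-out bits are exactly the six zeros introduced by the multiply, so rounding never triggers). A worked instance, ignoring overflow:

    // x = 5:  (5 * 64) >> 6  =  320 >> 6  =  5
    %shift = "tosa.const"() {value = dense<6> : tensor<1xi8>} : () -> tensor<1xi8>
    %r = tosa.mul %x, %c64, %shift : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>  // folds to %x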
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index b1ca64447051..ba9bdd34e2b3 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -225,7 +225,7 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> t
 // -----
 
 func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
-  %padding = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> 
+  %padding = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
   // expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank.}}
   %1 = tosa.pad %arg0, %padding : (tensor<13x21xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32>
   return
 
@@ -253,7 +253,7 @@ func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tenso
 // -----
 
 func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> 
+  %0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
   // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}}
   %1 = tosa.pad %arg0, %0 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32>
   return %1 : tensor<13x21x3xf32>
 
@@ -772,15 +772,35 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
 
 // -----
 
+// CHECK-LABEL: test_mul_type_mismatch
+func.func @test_mul_type_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf16>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.mul' op requires the same element type for all operands}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf16>, tensor<1xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: test_mul_invalid_shift
 func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type}}
-  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
 // -----
 
+// CHECK-LABEL: test_mul_missing_shift
+func.func @test_mul_missing_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
+  // expected-error@+1 {{'tosa.mul' op expected 3 operands, but found 2}}
+  %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
 // CHECK-LABEL: test_unsupported_int64_data_type
 func.func @test_unsupported_int64_data_type(%arg0: tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64> {
   // expected-error@+1 {{'tosa.argmax' op is not profile-aligned: element type 'i64' is not legal}}
@@ -1086,3 +1106,30 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1:
   %0 = tosa.sub %arg0, %arg1 : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32>
   return %0 : tensor<1x13x21x3xf32>
 }
+
+// -----
+// CHECK-LABEL: test_mul_non_scalar_shift_2d
+func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8>
+  // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1x1xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_non_scalar_shift_1d
+func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() <{value = dense<0> : tensor<2xi8>}> : () -> tensor<2xi8>
+  // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<2xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_non_broadcast
+func.func @test_mul_non_broadcast(%arg0: tensor<13x21x2xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x2xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 19b93d761185..4165786e7892 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -327,17 +327,35 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te
   return %0 : tensor<13x21x3xf32>
 }
 
+// -----
+// CHECK-LABEL: test_mul_scalar_with_unranked_output
+func.func @test_mul_scalar_with_unranked_output(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<*xf32> {
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
 // -----
 // CHECK-LABEL: mul
 func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
+// -----
+// CHECK-LABEL: i32_mul
+func.func @test_i32_mul(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi32>, tensor<13x1x3xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
 // -----
 // CHECK-LABEL: mul
 func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x1x3xi16>) -> tensor<13x21x3xi16> {
-  %0 = "tosa.mul"(%arg0, %arg1) { shift = 1 : i8 } : (tensor<13x21x3xi16>, tensor<13x1x3xi16>) -> tensor<13x21x3xi16>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi16>, tensor<13x1x3xi16>, tensor<1xi8>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index bbcc206e1490..5f36dd3b3d13 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -34,7 +34,7 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<
   // CHECK: %[[sIn:.+]] = tosa.sub %[[cIn]], %[[iZp]]
   // CHECK: %[[sWe:.+]] = tosa.sub %[[cWe]], %[[wZp]]
   // CHECK: %[[resWe:.+]] = tosa.reshape %[[sWe]] {new_shape = array<i64: 1, 1, 1, 2, 3>}
-  // CHECK: %[[mul:.+]] = tosa.mul %[[sIn]], %[[resWe]] {shift = 0 : i8}
+  // CHECK: %[[mul:.+]] = tosa.mul %[[sIn]], %[[resWe]]
   // CHECK: %[[reO:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 10, 10, 6>}
   // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 6>}
   // CHECK: %[[add:.+]] = tosa.add %[[reO]], %[[reArg2]]
@@ -51,7 +51,7 @@ func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: t
   // CHECK: %[[reIn:.+]] = tosa.reshape %arg0 {new_shape = array<i64: 4, 10, 10, 2, 1>}
   // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, !tosa.shape<10>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
   // CHECK: %[[reArg1:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 2, 3>}
-  // CHECK: %[[mul:.+]] = tosa.mul %3, %[[reArg1]] {shift = 0 : i8}
+  // CHECK: %[[mul:.+]] = tosa.mul %[[padded]], %[[reArg1]]
   // CHECK: %[[reOut:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 12, 12, 6>}
   // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 6>}
   // CHECK: %[[add:.+]] = tosa.add %[[reOut]], %[[reArg2]]
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index b7af09bdfad6..c21ddb69a842 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -114,23 +114,24 @@ func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>)
   // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
   %2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
-  // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %3 = tosa.mul %arg0, %arg1 { shift = 0 : i8 } : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+  %3 = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // CHECK: tosa.mul %arg0, %arg1, %3 : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<4xf32>
+  %4 = tosa.mul %arg0, %arg1, %3 : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<*xf32>
 
   // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+  %5 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
   // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %5 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+  %6 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
   // CHECK: tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
-  %6 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+  %7 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
 
   // CHECK: tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
-  %7 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+  %8 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
 
   // CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
-  %8 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+  %9 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
 
   return
 }
@@ -148,23 +149,24 @@ func.func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32
 
   // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
   %2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
-  // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %3 = tosa.mul %arg0, %arg1 { shift = 0 : i8 } : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+  %3 = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // CHECK: tosa.mul %arg0, %arg1, %3 : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<4xf32>
+  %4 = tosa.mul %arg0, %arg1, %3 : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<*xf32>
 
   // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+  %5 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
   // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %5 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+  %6 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
   // CHECK: tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
-  %6 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+  %7 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
 
   // CHECK: tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
-  %7 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+  %8 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
 
   // CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
-  %8 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+  %9 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
 
   return
 }
@@ -206,14 +208,15 @@ func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<1xi32>) -> () {
 
   // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
   %10 = tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
 
-  // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
-  %11 = tosa.mul %arg0, %arg1 { shift = 0 : i8 }: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
+  // CHECK: tosa.mul %arg0, %arg1, %{{.*}} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<4xi32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %11 = tosa.mul %arg0, %arg1, %shift : (tensor<4xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<*xi32>
 
   // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
-  %12 = tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
+  %13 = tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
 
   // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
-  %13 = tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
+  %14 = tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32>
 
   return
 }
@@ -1327,7 +1330,7 @@ func.func @test_non_tosa_consumer_shape(%arg0: tensor<4x4xf32>) -> !shape.shape
 
 // -----
 
-// CHECK-LABEL: test_non_tosa_consumer_shape2
+// CHECK-LABEL: test_non_tosa_consumer_shape
 func.func @test_non_tosa_consumer_shape2(%arg0: tensor<4x4xf32>) -> tensor<?xindex> {
   // CHECK: tosa.log %arg0 : (tensor<4x4xf32>) -> tensor<4x4xf32>
   %0 = tosa.log %arg0 : (tensor<4x4xf32>) -> tensor<*xf32>
diff --git a/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir b/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir
index 3f0d7544083a..6d6f68bd8287 100644
--- a/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir
@@ -122,13 +122,15 @@ func.func @test_torch_conv2d_with_elementwise_in_between(%arg0: tensor<3x3x10x10
 // -----
 
 // CHECK-LABEL: @test_mulop_conversion
-// CHECK-NEXT: %[[RES:.*]] = tosa.mul %arg0, %arg1
+// CHECK-NEXT: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK-NEXT: %[[RES:.*]] = tosa.mul %arg0, %arg1, %[[SHIFT]]
 // CHECK-NEXT: return %[[RES]]
 func.func @test_mulop_conversion(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
   %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
   %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
   %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
-  %mul = tosa.mul %transpose0, %transpose1 {shift = 0 : i8} : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %mul = tosa.mul %transpose0, %transpose1, %shift : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>, tensor<1xi8>) -> tensor<1x3x4x2xi32>
   %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
   %result = tosa.transpose %mul, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
   return %result : tensor<1x2x3x4xi32>
 }
 
@@ -185,6 +187,7 @@ func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2x
 
 // CHECK-LABEL: @test_resnet18_common_case
 // COM: note that %74 is now represented by %arg2
+// CHECK-DAG: %[[CONST0:.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
 // CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32> : tensor<64xf32>}> : () -> tensor<64xf32>
 // CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32_1> : tensor<64xf32>}> : () -> tensor<64xf32>
 // CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> tensor<1xf32>
@@ -195,15 +198,16 @@ func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2x
 // CHECK-DAG: %[[VAL_9:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
 // CHECK-DAG: %[[VAL_10:.*]] = tosa.sub %arg2, %[[VAL_9]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
 // CHECK-DAG: %[[VAL_11:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
-// CHECK-DAG: %[[VAL_12:.*]] = tosa.mul %[[VAL_10]], %[[VAL_11]] {shift = 0 : i8} : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_12:.*]] = tosa.mul %[[VAL_10]], %[[VAL_11]], %[[CONST0]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>, tensor<1xi8>) -> tensor<1x112x112x64xf32>
 // CHECK-DAG: %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
-// CHECK-DAG: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] {shift = 0 : i8} : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]], %[[CONST0]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>, tensor<1xi8>) -> tensor<1x112x112x64xf32>
 // CHECK-DAG: %[[VAL_15:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
 // CHECK-DAG: %[[VAL_16:.*]] = tosa.add %[[VAL_14]], %[[VAL_15]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
 // CHECK-DAG: %[[VAL_17:.*]] = tosa.clamp %[[VAL_16]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32>
 // CHECK: return %[[VAL_17]] : tensor<1x112x112x64xf32>
 func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32>, %74: tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> {
+  %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   %59 = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32> : tensor<64xf32>}> : () -> tensor<64xf32>
   %60 = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32_1> : tensor<64xf32>}> : () -> tensor<64xf32>
   %63 = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
@@ -217,9 +221,9 @@ func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32
   %79 = tosa.reshape %arg0 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
   %80 = tosa.sub %75, %79 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
   %81 = tosa.reshape %78 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
-  %82 = tosa.mul %80, %81 {shift = 0 : i8} : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+  %82 = tosa.mul %80, %81, %shift : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>, tensor<1xi8>) -> tensor<1x64x112x112xf32>
   %83 = tosa.reshape %60 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
-  %84 = tosa.mul %82, %83 {shift = 0 : i8} : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+  %84 = tosa.mul %82, %83, %shift : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>, tensor<1xi8>) -> tensor<1x64x112x112xf32>
   %85 = tosa.reshape %59 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
   %86 = tosa.add %84, %85 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
   %87 = tosa.clamp %86 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x112x112xf32>) -> tensor<1x64x112x112xf32>