From 5f385cca5737ab493ea01f06c2c85f6776b8d904 Mon Sep 17 00:00:00 2001
From: chaitany
Date: Wed, 10 Jan 2024 11:44:55 +0530
Subject: [PATCH 01/12] feat: add verifiers for some TOSA operators

---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td |  18 +-
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp         | 317 +++++++++++++++----
 mlir/test/Dialect/Tosa/invalid.mlir          | 244 ++++++++++++++
 mlir/test/Dialect/Tosa/ops.mlir              |  16 +-
 4 files changed, 527 insertions(+), 68 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3331ca4cb8643..f51d496dcfa5c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -305,6 +305,7 @@ def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
   );

   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }

 //===----------------------------------------------------------------------===//
@@ -498,7 +499,8 @@ def Tosa_AddOp : Tosa_ElemWiseBinaryOp<"add", [Commutative]> {
     Tosa_Tensor:$output
   );

-  let hasFolder = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
 }

 //===----------------------------------------------------------------------===//
@@ -783,6 +785,9 @@ def Tosa_MulOp : Tosa_ElemWiseBinaryOp<"mul", [Commutative]> {
   );

   let hasFolder = 1;
+
+  let hasVerifier = 1;
+
 }

 //===----------------------------------------------------------------------===//
@@ -1125,6 +1130,7 @@ def Tosa_SelectOp : Tosa_Op<"select", [
   );
   let hasCanonicalizeMethod = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }

 //===----------------------------------------------------------------------===//
@@ -1209,6 +1215,7 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
   );

   let hasFolder = 1;
+  let hasVerifier = 1;
 }

 //===----------------------------------------------------------------------===//
@@ -1591,8 +1598,7 @@ def Tosa_TileOp: Tosa_Op<"tile", [
 // Operator: transpose
 //===----------------------------------------------------------------------===//
 def Tosa_TransposeOp : Tosa_Op<"transpose", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
+    InferTensorType,
     Pure]> {

   let summary = "Transpose operator";
@@ -1611,11 +1617,13 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [

   let extraClassDeclaration = [{
     LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
   }];

   let hasCanonicalizer = 1;
-  let hasFolder = 1;
-}
+  let hasFolder = 1;}

 //===----------------------------------------------------------------------===//
 // TOSA Spec Section 2.10
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 48dc95b3bed49..806df575cdea4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -100,7 +100,8 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//

-template <typename T> static LogicalResult verifyConvOp(T op) {
+template <typename T>
+static LogicalResult verifyConvOp(T op) {
   // All TOSA conv ops have an input() and weight().
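  // How these hooks fit together (a sketch, assuming default ODS codegen):
  // every `let hasVerifier = 1;` above makes TableGen declare a
  // `::mlir::LogicalResult verify();` member on the generated op class, and
  // the out-of-line `tosa::XxxOp::verify()` definitions in this file run
  // only after trait and structural verification have already succeeded.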
   auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
   auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
@@ -140,10 +141,11 @@ template <typename T> static LogicalResult verifyConvOp(T op) {

   return success();
 }
-
-LogicalResult tosa::AvgPool2dOp::verify() {
-  auto inputETy = llvm::cast<ShapedType>(getInput().getType()).getElementType();
-  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
+template <typename T>
+static LogicalResult verifyPoolOp(T op) {
+  auto inputETy =
+      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
+  auto resultETy = llvm::cast<ShapedType>(op.getType()).getElementType();

   if (auto quantType =
           llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
@@ -153,17 +155,78 @@ LogicalResult tosa::AvgPool2dOp::verify() {
           llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
     resultETy = quantType.getStorageType();

-  auto accType = getAccType();
-  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
-    return emitOpError("accumulator type for integer tensor is not i32");
-
-  if ((inputETy.isBF16() || inputETy.isF16()) &&
-      !(accType.isF16() || accType.isF32()))
-    return emitOpError("accumulator type for f16/bf16 tensor is not f16/f32");
+  // [kernel_y, kernel_x] <-> [0,1]
+  auto kernel = op.getKernel();
+  // [stride_y, stride_x]
+  auto stride = op.getStride();
+  // [pad_top, pad_bottom, pad_left, pad_right]
+  auto pad = op.getPad();
+  // ERROR_IF(kernel_y < 1 || kernel_x < 1); // kernel size must be >= 1
+  if (kernel[0] < 1 || kernel[1] < 1) {
+    return op.emitOpError("kernel should be greater than one.");
+  }
+  // ERROR_IF(stride_y < 1 || stride_x < 1);
+  if (stride[0] < 1 || stride[1] < 1) {
+    return op.emitOpError("stride should be greater than one.");
+  }
+  // ERROR_IF(pad_top < 0 || pad_bottom < 0 || pad_left < 0 || pad_right < 0);
+  if (pad[0] < 0 || pad[1] < 0 || pad[2] < 0 || pad[3] < 0) {
+    return op.emitOpError("pad should be positive.");
+  }
+  // Padding must be less than kernel size to avoid
+  // a divide-by-zero.
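  // A worked instance of the checks that follow, with illustrative values:
  // for IH = 6, pad_top = pad_bottom = 0, kernel_y = 2, stride_y = 2 the
  // spec demands OH == (6 + 0 + 0 - 2) / 2 + 1 == 3. The pad-vs-kernel
  // restriction matters because a window sitting entirely inside the padding
  // would leave avg_pool2d with zero countable input elements, which is the
  // divide-by-zero mentioned above.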
+  /*
+   ERROR_IF(pad_right >= kernel_x || pad_left >= kernel_x);
+   ERROR_IF(pad_top >= kernel_y || pad_bottom >= kernel_y);
+  */
+
+  if (pad[3] >= kernel[1] || pad[2] >= kernel[1] || pad[0] >= kernel[0] ||
+      pad[1] >= kernel[0]) {
+    return op.emitOpError("pad must be less than kernel size.");
+  }
+
+  //[N,IH,IW,C]
+  auto inputShapeType = llvm::cast<ShapedType>(op.getInput().getType());
+  //[N,OH,OW,C]
+  auto outputShapeType = llvm::cast<ShapedType>(op.getOutput().getType());
+  if (inputShapeType.hasStaticShape() && outputShapeType.hasStaticShape()) {
+    auto inputShape = inputShapeType.getShape();
+    auto outputShape = outputShapeType.getShape();
+    auto inputHeight = inputShape[1];
+    auto inputWidth = inputShape[2];
+    auto outputHeight = outputShape[1];
+    auto outputWidth = outputShape[2];
+    // IH + pad_top + pad_bottom - kernel_y
+    auto height = inputHeight + pad[0] + pad[1] - kernel[0];
+    // IW + pad_left + pad_right - kernel_x
+    auto width = inputWidth + pad[2] + pad[3] - kernel[1];
+    // idiv_check(IH + pad_top + pad_bottom - kernel_y, stride_y)
+    if (height % stride[0] != 0) {
+      return op.emitOpError("vertical stride is not in correct multiple.");
+    }
+    // idiv_check(IW + pad_left + pad_right - kernel_x, stride_x)
+    if (width % stride[1] != 0) {
+      return op.emitOpError("horizontal stride is not in correct multiple.");
+    }
+    /*
+     ERROR_IF(OH != idiv_check(IH + pad_top + pad_bottom - kernel_y, stride_y)
+     + 1);
+    */

-  if (inputETy.isF32() && !accType.isF32())
-    return emitOpError("accumulator type for f32 tensor is not f32");
+    if ((outputHeight != (height / stride[0]) + 1)) {
+      return op.emitOpError("output height is not correct, should be ")
+             << (height / stride[0]) + 1 << ".";
+    }
+    /*
+     ERROR_IF(OW != idiv_check(IW + pad_left + pad_right - kernel_x, stride_x)
+     + 1);
+    */
+    if (outputWidth != (width / stride[1]) + 1) {
+      return op.emitOpError("output width is not correct, should be ")
+             << (width / stride[1]) + 1 << ".";
+    }
+  }

   if (inputETy.isF32() && resultETy.isF32())
     return success();
   if (inputETy.isInteger(8) && resultETy.isInteger(8))
@@ -171,7 +234,51 @@ LogicalResult tosa::AvgPool2dOp::verify() {
   if (inputETy.isInteger(16) && resultETy.isInteger(16))
     return success();

-  return emitOpError("input/output element types are incompatible.");
+  return op.emitOpError("input/output element types are incompatible.");
+}
+// LogicalResult tosa::AddOp::verify() {

+//   auto input1ShapedType = llvm::cast<ShapedType>(getInput1().getType());
+//   auto input2ShapedType = llvm::cast<ShapedType>(getInput2().getType());
+//   auto resultShapedType = llvm::cast<ShapedType>(getType());

+//   if (input1ShapedType.hasStaticShape() && input2ShapedType.hasStaticShape()
+//   &&
+//       resultShapedType.hasStaticShape()) {
+//     if (input1ShapedType.getRank() != input2ShapedType.getRank()) {
+//       return emitOpError("input tensors must be of equal rank.");
+//     }
+//     return success();
+//   }
+//   return success();
+// }

+LogicalResult tosa::MaxPool2dOp::verify() { return verifyPoolOp(*this); }
+LogicalResult tosa::AvgPool2dOp::verify() {
+  auto inputETy = llvm::cast<ShapedType>(getInput().getType()).getElementType();
+  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
+    inputETy = quantType.getStorageType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
+    resultETy = quantType.getStorageType();
+
+  auto accType = getAccType();
+  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
+    return emitOpError("accumulator type for integer tensor is not i32");
+  auto result = verifyPoolOp(*this);
+  if (result.succeeded()) {
+    if
((inputETy.isF16()) && !(accType.isF16() || accType.isF32())) + return emitOpError("accumulator type for f16 tensor is not f16/f32"); + if ((inputETy.isBF16()) && !(accType.isF32())) + return emitOpError("accumulator type for bf16 tensor is not f32"); + if (inputETy.isF32() && !accType.isF32()) + return emitOpError("accumulator type for f32 tensor is not f32"); + } + return result; } //===----------------------------------------------------------------------===// @@ -202,7 +309,8 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, } } -/// Handles tosa.transpose_conv2d which has outpad and output shape attributes. +/// Handles tosa.transpose_conv2d which has outpad and output shape +/// attributes. static void buildTransConvOpWithQuantInfo( OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, @@ -239,9 +347,9 @@ static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result, } } -/// The tosa.matmul op is also intended to be generated where a fully_connected -/// op must be constructed where the weight is not a constant. In this case, -/// the fully_connected op must be expressed using matmul. +/// The tosa.matmul op is also intended to be generated where a +/// fully_connected op must be constructed where the weight is not a constant. +/// In this case, the fully_connected op must be expressed using matmul. /// TODO: Add link to the leglization document explaining this. static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, @@ -276,9 +384,9 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder, } } -/// Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr -/// but avg_pool operator has its own builder as it has additional parameters -/// not part of the unary ops. +/// Both the tosa.avg_pool2d and unary ops use the same +/// UnaruOpQuantizationAttr but avg_pool operator has its own builder as it +/// has additional parameters not part of the unary ops. static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, @@ -345,8 +453,8 @@ static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, for (int i = 0, e = operands.size(); i != e; ++i) { auto shape = operands.getShape(i); if (!shape.hasRank()) { - // TODO(jennik): Update function to have better case handling for invalid - // operands and for ranked tensors. + // TODO(jennik): Update function to have better case handling for + // invalid operands and for ranked tensors. return failure(); } outRank = std::max(outRank, shape.getRank()); @@ -601,8 +709,8 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents( return success(); } - // If the input rank is unknown we can info the output rank using the padding - // shape's first dim. + // If the input rank is unknown we can info the output rank using the + // padding shape's first dim. 
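  // Illustrative case (the type here is chosen for exposition): a padding
  // operand of type tensor<3x2xi32> holds one (before, after) pair per
  // dimension, so the padded result must have rank 3 even when the input
  // itself is unranked.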
if (!inputShape.hasRank()) { if (paddingShape.isDynamicDim(0)) { inferredReturnShapes.push_back(ShapedTypeComponents()); @@ -767,18 +875,18 @@ mlir::LogicalResult tosa::ReshapeOp::verify() { } if ((int64_t)getNewShape().size() != outputType.getRank()) { - return emitOpError() << "rank of newShape (" << getNewShape().size() - << ") and output (" - << outputType.getRank() + return emitOpError() << "rank of newShape (" << getNewShape().size() + << ") and output (" << outputType.getRank() << ") must match"; } - for (int64_t dim=0; dim < outputType.getRank(); ++dim) { - if (getNewShape()[dim] != -1 && getNewShape()[dim] != outputType.getShape()[dim]) { - return emitOpError() << "newShape attribute (" << getNewShape()[dim] - << ") does not match output type (" - << outputType.getShape()[dim] - << ") in dimension " << dim; + for (int64_t dim = 0; dim < outputType.getRank(); ++dim) { + if (getNewShape()[dim] != -1 && + getNewShape()[dim] != outputType.getShape()[dim]) { + return emitOpError() + << "newShape attribute (" << getNewShape()[dim] + << ") does not match output type (" << outputType.getShape()[dim] + << ") in dimension " << dim; } } } @@ -792,38 +900,34 @@ mlir::LogicalResult tosa::SliceOp::verify() { if (inputType.getRank() != outputType.getRank()) { return emitOpError() << "rank of input (" << inputType.getRank() - << ") and output (" - << outputType.getRank() - << ") must match"; + << ") and output (" << outputType.getRank() + << ") must match"; } if ((int64_t)getSize().size() != outputType.getRank()) { - return emitOpError() << "rank of size (" << getSize().size() - << ") and output (" - << outputType.getRank() - << ") must match"; + return emitOpError() << "rank of size (" << getSize().size() + << ") and output (" << outputType.getRank() + << ") must match"; } - for (int64_t dim=0; dim < outputType.getRank(); ++dim) { - if (getSize()[dim] != -1 && !outputType.isDynamicDim(dim) && - getSize()[dim] != outputType.getShape()[dim]) { + for (int64_t dim = 0; dim < outputType.getRank(); ++dim) { + if (getSize()[dim] != -1 && !outputType.isDynamicDim(dim) && + getSize()[dim] != outputType.getShape()[dim]) { return emitOpError() << "size attribute (" << getSize()[dim] << ") does not match output type (" << outputType.getShape()[dim] << ") in dimension " << dim; - } + } } if ((int64_t)getStart().size() != inputType.getRank()) { - return emitOpError() << "rank of start (" << getStart().size() - << ") and input (" - << inputType.getRank() - << ") must match"; + return emitOpError() << "rank of start (" << getStart().size() + << ") and input (" << inputType.getRank() + << ") must match"; } if ((int64_t)getSize().size() != inputType.getRank()) { - return emitOpError() << "rank of size (" << getSize().size() - << ") and input (" - << inputType.getRank() - << ") must match"; + return emitOpError() << "rank of size (" << getSize().size() + << ") and input (" << inputType.getRank() + << ") must match"; } for (int i = 0; i < outputType.getRank(); ++i) { @@ -860,6 +964,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents( SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape = operands.getShape(0); ShapeAdaptor permsShape = operands.getShape(1); + auto inputType = getElementTypeOrSelf(operands[0]); // If input rank and permutation length is unknown, the output rank is // unknown. 
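  // `inputType` is threaded into each ShapedTypeComponents below so that
  // inference reports an element type alongside the shape; InferTensorType
  // then reconciles the inferred and declared result types through the
  // isCompatibleReturnTypes hook declared in TosaOps.td.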
@@ -869,8 +974,8 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     return success();
   }

-  // This would imply the number of permutations does not match the rank of the
-  // input which is illegal.
+  // This would imply the number of permutations does not match the rank of
+  // the input which is illegal.
   if (permsShape.getDimSize(0) != inputShape.getRank()) {
     return failure();
   }
@@ -880,13 +985,15 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   SmallVector<int64_t> outputShape;
   if (!inputShape.hasRank()) {
     outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamic);
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(outputShape, inputType));
     return success();
   }

   // Rank-0 means no permutations matter.
   if (inputShape.getRank() == 0) {
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(outputShape, inputType));
     return success();
   }

@@ -903,7 +1010,8 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   // permutation.
   if (allTheSame) {
     outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(outputShape, inputType));
     return success();
   }

@@ -917,7 +1025,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     }
   }

-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }

@@ -1069,6 +1177,7 @@ REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
 #undef REDUCE_SHAPE_INFER
 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
+COMPATIBLE_RETURN_TYPES(tosa::TransposeOp)
 #undef COMPATIBLE_RETURN_TYPES

 static LogicalResult NAryInferReturnTypes(
@@ -1230,6 +1339,94 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
   return success();
 }

+template <typename T>
+static LogicalResult verifyBinaryOpWithEqualRank(T op) {
+  auto input1ShapeType = llvm::cast<ShapedType>(op.getInput1().getType());
+  auto input2ShapeType = llvm::cast<ShapedType>(op.getInput2().getType());
+
+  if (input1ShapeType.hasRank() && input2ShapeType.hasRank()) {
+    auto input1Rank = input1ShapeType.getRank();
+    auto input2Rank = input2ShapeType.getRank();
+    if (input1Rank != input2Rank) {
+      return op.emitOpError("both operands must have same rank.");
+    }
+  }
+  return success();
+}
+LogicalResult tosa::MulOp::verify() {
+  auto result = verifyBinaryOpWithEqualRank(*this);
+  if (result.failed()) {
+    return result;
+  }
+  auto shiftAttr = getShiftAttr().getInt();
+  auto input1ShapeType = llvm::cast<ShapedType>(getInput1().getType());
+  auto elementType = getElementTypeOrSelf(input1ShapeType);
+  if (!(elementType.isInteger(8) || elementType.isInteger(16))) {
+    if (shiftAttr != 0) {
+      return emitOpError(
+          "shift attribute should be 0 for non integer input types");
+    }
+  }
+  return success();
+}
+LogicalResult tosa::AddOp::verify() {
+  return verifyBinaryOpWithEqualRank(*this);
+}
+LogicalResult tosa::GreaterEqualOp::verify() {
+  return verifyBinaryOpWithEqualRank(*this);
+}
+template <typename T>
+LogicalResult verifyForSameRank(T op, ShapedType inputShape1,
+                                ShapedType inputShape2) {
+  if (inputShape1.hasRank() && inputShape2.hasRank()) {
+    auto input1Rank = inputShape1.getRank();
+    auto input2Rank = inputShape2.getRank();
+    if (input1Rank != input2Rank) {
+      return op.emitOpError("both operands must have same rank.");
+    }
+  }
+  return success();
+}
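// Why rank equality is checked before broadcast compatibility (shapes here
// are illustrative, in the style of the tests below): TOSA only broadcasts
// unit dimensions between equal-rank operands, e.g. 13x1x3 with 13x21x3
// yields 13x21x3, while implicit rank promotion such as 8400 against
// 3x4x8400 is exactly what "both operands must have same rank." rejects.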
+LogicalResult tosa::SelectOp::verify() { + + auto input1ShapeType = llvm::cast(getOperand(0).getType()); + auto input2ShapeType = llvm::cast(getOperand(1).getType()); + auto input3ShapeType = llvm::cast(getOperand(2).getType()); + auto outputShapeType = llvm::cast(getResult().getType()); + + auto input2ETy = + llvm::cast(getOperand(1).getType()).getElementType(); + auto input3ETy = + llvm::cast(getOperand(2).getType()).getElementType(); + auto resultETy = getElementTypeOrSelf(getResult()); + // auto resultETy = llvm::cast(getResult()).getElementType(); + + auto result1 = verifyForSameRank(*this, input1ShapeType, input2ShapeType); + if (result1.failed()) { + return result1; + } + auto result2 = verifyForSameRank(*this, input1ShapeType, input3ShapeType); + if (result2.failed()) { + return result2; + } + auto result3 = verifyForSameRank(*this, input1ShapeType, outputShapeType); + if (result3.failed()) { + return result3; + } + if (input2ETy != input3ETy) { + return emitOpError("inputs should be of same type."); + } + if ((input2ETy != resultETy) || (input3ETy != resultETy)) { + return emitOpError("inputs and result should be of same type."); + } + + auto result = OpTrait::impl::verifyCompatibleOperandBroadcast(getOperation()); + if (result.failed()) { + return result; + } + return success(); +} + LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); } LogicalResult Conv3DOp::inferReturnTypeComponents( diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index e285a9de1d66d..abd3e01464f63 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -45,6 +45,7 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens return %0 : tensor } + // ----- func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor { @@ -151,3 +152,246 @@ func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> { %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32> return %0 : tensor<100x100xf32> } + +// ----- +func.func @test_avg_pool2d_negative_kernel(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> { + // expected-error@+1 {{'tosa.avg_pool2d' op kernel should be greater than one.}} + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> + return %0 : tensor<1x7x7x9xi8> +} +// ----- +func.func @test_avg_pool2d_negative_stride(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> { + // expected-error@+1 {{'tosa.avg_pool2d' op stride should be greater than one.}} + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> + return %0 : tensor<1x7x7x9xi8> +} +// ----- +func.func @test_avg_pool2d_negative_pad(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> { + // expected-error@+1 {{'tosa.avg_pool2d' op pad should be positive}} + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> + return %0 : tensor<1x7x7x9xi8> +} +// ----- +func.func @test_avg_pool2d_kernel_lessthan_pad(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> { + // expected-error@+1 {{'tosa.avg_pool2d' op pad must be less than kernel size}} + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> + return %0 : tensor<1x7x7x9xi8> +} +// 
----- +func.func @test_avg_pool2d_vert_stride_incorrect_mul(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> { + // expected-error@+1 {{'tosa.avg_pool2d' op vertical stride is not in correct multiple.}} + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> + return %0 : tensor<1x7x7x9xi8> +} +// ----- +func.func @test_avg_pool2d_hor_stride_incorrect_mul(%arg0: tensor<1x6x6x9xi8>) -> tensor<1x7x4x9xi8> { + // expected-error@+1 {{'tosa.avg_pool2d' op horizontal stride is not in correct multiple.}} + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x6x6x9xi8>) -> tensor<1x7x4x9xi8> + return %0 : tensor<1x7x4x9xi8> +} +// ----- +func.func @test_max_pool2d_hor_stride_incorrect_mul(%arg0: tensor<1x6x6x9xi8>) -> tensor<1x7x4x9xi8> { + // expected-error@+1 {{'tosa.max_pool2d' op horizontal stride is not in correct multiple.}} + %0 = "tosa.max_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x6x6x9xi8>) -> tensor<1x7x4x9xi8> + return %0 : tensor<1x7x4x9xi8> +} +// ----- +func.func @test_avg_pool2d_output_height_incorrect(%arg0: tensor<1x6x6x9xi8>) -> tensor<1x7x8x9xi8> { + // expected-error@+1 {{'tosa.avg_pool2d' op output height is not correct, should be 3.}} + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x6x6x9xi8>) -> tensor<1x7x8x9xi8> + return %0 : tensor<1x7x8x9xi8> +} +// ----- +func.func @test_avg_pool2d_output_width_incorrect(%arg0: tensor<1x6x6x9xi8>) -> tensor<1x3x8x9xi8> { + // expected-error@+1 {{'tosa.avg_pool2d' op output width is not correct, should be 3.}} + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x6x6x9xi8>) -> tensor<1x3x8x9xi8> + return %0 : tensor<1x3x8x9xi8> +} +// ----- +func.func @test_max_pool2d_output_width_incorrect(%arg0: tensor<1x6x6x9xi8>) -> tensor<1x3x8x9xi8> { + // expected-error@+1 {{'tosa.max_pool2d' op output width is not correct, should be 3.}} + %0 = "tosa.max_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x6x6x9xi8>) -> tensor<1x3x8x9xi8> + return %0 : tensor<1x3x8x9xi8> +} +// ----- +func.func @test_add_incompabitble_type(%arg0: tensor<13x21xf32>, %arg1: tensor<13x21xi8>) -> tensor<13x21xf32> { + // expected-error@+1 {{'tosa.add' op requires the same element type for all operands and results}} + %0 = "tosa.add"(%arg0, %arg1) : (tensor<13x21xf32>, tensor<13x21xi8>) -> tensor<13x21xf32> + return %0 : tensor<13x21xf32> +} +// ----- +func.func @test_add_incorrect_output(%arg0: tensor<13x21xf32>, %arg1: tensor<13x21xf32>) -> tensor<13x2xf32> { + // expected-error@+1 {{'tosa.add' op result type '13x2' not broadcast compatible with broadcasted operands's shapes '13x21'}} + %0 = "tosa.add"(%arg0, %arg1) : (tensor<13x21xf32>, tensor<13x21xf32>) -> tensor<13x2xf32> + return %0 : tensor<13x2xf32> +} +// ----- +func.func @test_add_incorrect_output2(%arg0: tensor<13x21xf32>, %arg1: tensor<2x13x21xf32>) -> tensor<2x13x21xf32> { + // expected-error@+1 {{'tosa.add' op both operands must have same rank.}} + %0 = "tosa.add"(%arg0, %arg1) : (tensor<13x21xf32>, tensor<2x13x21xf32>) -> tensor<2x13x21xf32> + return %0 : tensor<2x13x21xf32> +} +// ----- +func.func @test_const_incorrect_output(%arg0 : index) -> tensor<4xi32> { + // expected-error@+1{{inferred shape of elements literal ([4]) does not match type ([3])}} + %0 = "tosa.const"() {value = 
dense<[3, 0, 1, 2]> : tensor<3xi32>} : () -> tensor<4xi32> + return %0 : tensor<4xi32> +} +// ----- +func.func @test_greater_equal_incompatible(%arg0: tensor<13x1x3x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> { + // expected-error@+1{{'tosa.greater_equal' op operands don't have broadcast-compatible shapes}} + %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<13x1x3x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} +// ----- +func.func @test_greater_equal_unequal_rank(%arg0: tensor<12x13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor { + // expected-error@+1{{'tosa.greater_equal' op both operands must have same rank.}} + %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<12x13x21x3xf32>, tensor<13x21x3xf32>) -> tensor + return %0 : tensor +} +// ----- +func.func @test_greater_equal(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // expected-error@+1{{'tosa.greater_equal' op result #0 must be tensor of 1-bit signless integer values, but got 'tensor<13x21x3xf32>'}} + %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<13x1x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} +// ----- +func.func @test_mul_incompabitble_type(%arg0: tensor<13x21xf32>, %arg1: tensor<13x21xi8>) -> tensor<13x21xf32> { + // expected-error@+1 {{'tosa.mul' op requires the same element type for all operands and results}} + %0 = "tosa.mul"(%arg0, %arg1) { shift = 1 : i32 }: (tensor<13x21xf32>, tensor<13x21xi8>) -> tensor<13x21xf32> + return %0 : tensor<13x21xf32> +} +// ----- +func.func @test_mul_unequal_rank(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x21x3xf32>) -> tensor { + // expected-error@+1{{'tosa.mul' op both operands must have same rank.}} + %0 = "tosa.mul"(%arg0, %arg1) { shift = 1 : i32 } : (tensor<13x21x3xf32>, tensor<13x1x21x3xf32>) -> tensor + return %0 : tensor +} +// ----- +func.func @test_add_unequal_rank(%arg0: tensor<3x4x8400xf32>, %arg1: tensor<8400xf32>) -> tensor<3x4x8400xf32> { + // expected-error@+1{{'tosa.add' op both operands must have same rank.}} + %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4x8400xf32>, tensor<8400xf32>) -> tensor<3x4x8400xf32> + return %0 : tensor<3x4x8400xf32> +} + +// ----- +func.func @test_mul_incompatible(%arg0: tensor<3x4x8400xf32>, %arg1: tensor<3x8400xf32>) -> tensor<1x4x8400xf32> { + // expected-error@+1{{'tosa.mul' op operands don't have broadcast-compatible shapes}} + %0 = "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<3x4x8400xf32>, tensor<3x8400xf32>) -> tensor<1x4x8400xf32> + return %0 : tensor<1x4x8400xf32> +} +// ----- +func.func @test_mul_need_shift(%arg0: tensor<3x4x8400xf32>, %arg1: tensor<3x4x8400xf32>) -> tensor<3x4x8400xf32> { + // expected-error@+1{{'tosa.mul' op requires attribute 'shift'}} + %0 = "tosa.mul"(%arg0, %arg1) : (tensor<3x4x8400xf32>, tensor<3x4x8400xf32>) -> tensor<3x4x8400xf32> + return %0 : tensor<3x4x8400xf32> +} +// ----- +func.func @test_mul_nonzero_shift(%arg0: tensor<3x4x8400xf32>, %arg1: tensor<3x4x8400xf32>) -> tensor<3x4x8400xf32> { + // expected-error@+1{{'tosa.mul' op shift attribute should be 0 for non integer input types}} + %0 = "tosa.mul"(%arg0, %arg1) {shift = 3 : i32}: (tensor<3x4x8400xf32>, tensor<3x4x8400xf32>) -> tensor<3x4x8400xf32> + return %0 : tensor<3x4x8400xf32> +} + + // ----- +func.func @test_select_unequal_rank_inputs(%arg0: tensor<2xi1>, %arg1: tensor<3x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<3x2xf32> { + // expected-error@+1{{'tosa.select' op both operands must have same rank.}} 
+ %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> +} +// ----- +func.func @test_select_unequal_rank_inputs2(%arg0: tensor<1x2xi1>, %arg1: tensor<1x3x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<3x2xf32> { + // expected-error@+1{{'tosa.select' op both operands must have same rank.}} + %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x2xi1>, tensor<1x3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> +} +// ----- +func.func @test_select_unequal_rank_inputs3(%arg0: tensor<1x2xi1>, %arg1: tensor<3x2xf32>, %arg2: tensor<1x3x2xf32>) -> tensor<3x2xf32> { + // expected-error@+1{{'tosa.select' op both operands must have same rank.}} + %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x2xi1>, tensor<3x2xf32>, tensor<1x3x2xf32>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> +} +// ----- +func.func @test_select_not_boardcastable_arg1(%arg0: tensor<2x2xi1>, %arg1: tensor<3x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<3x2xf32> { + // expected-error@+1{{'tosa.select' op operands don't have broadcast-compatible shapes}} + %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<2x2xi1>, tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> +} +// ----- +func.func @test_select_not_boardcastable_result(%arg0: tensor<1x2xi1>, %arg1: tensor<3x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<4x2xf32> { + // expected-error@+1{{'tosa.select' op result type '4x2' not broadcast compatible with broadcasted operands's shapes '3x2'}} + %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x2xi1>, tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<4x2xf32> + return %0 : tensor<4x2xf32> +} +// ----- +func.func @test_select_not_boardcastable_arg3(%arg0: tensor<1x2xi1>, %arg1: tensor<2x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<3x2xf32> { + // expected-error@+1{{'tosa.select' op operands don't have broadcast-compatible shapes}} + %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x2xi1>, tensor<2x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> +} +// ----- +func.func @test_select_incompatible_1(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xi8> { + // expected-error@+1{{'tosa.select' op inputs and result should be of same type.}} + %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi8> + return %0 : tensor<13x21x3xi8> +} +// ----- +func.func @test_select_incompatible_2(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xi8>) -> tensor<13x21x3xf32> { + // expected-error@+1{{'tosa.select' op inputs should be of same type.}} + %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xi8>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} +// ----- +func.func @test_select_incompatible_3(%arg0: tensor<1x1x1xi8>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xi8>) -> tensor<13x21x3xf32> { + // expected-error@+1{{'tosa.select' op operand #0 must be tensor of 1-bit signless integer values, but got 'tensor<1x1x1xi8>'}} + %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x1xi8>, tensor<13x21x3xf32>, tensor<13x21x3xi8>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} +// ----- +func.func @test_transpose_incorrect_result_shape(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x20xf32> { + %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + // 
expected-error@+2{{'tosa.transpose' op failed to infer returned types}} + // expected-error@+1{{'tosa.transpose' op inferred type(s) 'tensor<3x13x21xf32>' are incompatible with return type(s) of operation 'tensor<3x13x20xf32>'}} + %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x20xf32> + return %1 : tensor<3x13x20xf32> +} +// ----- +func.func @test_transpose_incorrect_result_rank(%arg0: tensor<13x21x3xf32>) -> tensor<3x13xf32> { + %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + // expected-error@+2{{'tosa.transpose' op failed to infer returned types}} + // expected-error@+1{{'tosa.transpose' op inferred type(s) 'tensor<3x13x21xf32>' are incompatible with return type(s) of operation 'tensor<3x13xf32>'}} + %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13xf32> + return %1 : tensor<3x13xf32> +} +// ----- +func.func @test_transpose_incorrect_result_type(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xi8> { + %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + // expected-error@+2 {{failed to infer returned types}} + // expected-error@+1{{'tosa.transpose' op inferred type(s) 'tensor<3x13x21xf32>' are incompatible with return type(s) of operation 'tensor<3x13x21xi8>'}} + %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xi8> + return %1 : tensor<3x13x21xi8> +} +// ----- +func.func @test_transpose_high_rank_perm(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21x4xf32> { + %0 = "tosa.const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi32>} : () -> tensor<4xi32> + // expected-error@+1 {{failed to infer returned types}} + %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<4xi32>) -> tensor<3x13x21x4xf32> + return %1 : tensor<3x13x21x4xf32> +} +// ----- +// CHECK-LABEL: transpose +func.func @test_transpose_low_rank_perm(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21x4xf32> { + %0 = "tosa.const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32> + // expected-error@+1 {{failed to infer returned types}} + %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<2xi32>) -> tensor<3x13x21x4xf32> + return %1 : tensor<3x13x21x4xf32> +} +// ----- +// CHECK-LABEL: transpose +func.func @test_transpose_result_high_rank(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21x4xf32> { + %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + // expected-error@+2 {{failed to infer returned types}} + // expected-error@+1 {{'tosa.transpose' op inferred type(s) 'tensor<3x13x21xf32>' are incompatible with return type(s) of operation 'tensor<3x13x21x4xf32>'}} + %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21x4xf32> + return %1 : tensor<3x13x21x4xf32> +} \ No newline at end of file diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 72f020336ff05..95f169f4923f2 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -215,10 +215,15 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te // ----- // CHECK-LABEL: mul func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { - %0 = "tosa.mul"(%arg0, %arg1) { shift = 1 : i32 } : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32> + %0 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 } : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32> return %0 : 
tensor<13x21x3xf32>
 }
-
+// -----
+// CHECK-LABEL: mul
+func.func @test_mul_nonzero_shift(%arg0: tensor<3x4x8400xi8>, %arg1: tensor<3x4x8400xi8>) -> tensor<3x4x8400xi8> {
+  %0 = "tosa.mul"(%arg0, %arg1) {shift = 3 : i32}: (tensor<3x4x8400xi8>, tensor<3x4x8400xi8>) -> tensor<3x4x8400xi8>
+  return %0 : tensor<3x4x8400xi8>
+}
 // -----
 // CHECK-LABEL: pow
 func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
@@ -323,7 +328,12 @@ func.func @test_select(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg
   %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
-
+// -----
+// CHECK-LABEL: select
+func.func @test_select_boardcastable(%arg0: tensor<1x2xi1>, %arg1: tensor<3x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x2xi1>, tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
+  return %0 : tensor<3x2xf32>
+}
 // -----
 // CHECK-LABEL: equal

From 398a6e452840e458f650243ecaa1ce039630576a Mon Sep 17 00:00:00 2001
From: chaitany
Date: Fri, 12 Jan 2024 20:20:12 +0530
Subject: [PATCH 02/12] refactor: remove the mul verifier, restore original
 formatting, and incorporate review comments

---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td |   8 +-
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp         | 122 +++++++------------
 mlir/test/Dialect/Tosa/invalid.mlir          | 107 ++++++++++------
 mlir/test/Dialect/Tosa/ops.mlir              |  12 +-
 4 files changed, 120 insertions(+), 129 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index f51d496dcfa5c..064b9503ac410 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -781,13 +781,10 @@ def Tosa_MulOp : Tosa_ElemWiseBinaryOp<"mul", [Commutative]> {
   );

   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_Tensor:$output
   );

   let hasFolder = 1;
-
-  let hasVerifier = 1;
-
 }

 //===----------------------------------------------------------------------===//
@@ -1623,7 +1620,8 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
   }];

   let hasCanonicalizer = 1;
-  let hasFolder = 1;}
+  let hasFolder = 1;
+}

 //===----------------------------------------------------------------------===//
 // TOSA Spec Section 2.10
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 806df575cdea4..5076cc71fe181 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -100,8 +100,7 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//

-template <typename T>
-static LogicalResult verifyConvOp(T op) {
+template <typename T> static LogicalResult verifyConvOp(T op) {
   // All TOSA conv ops have an input() and weight().
auto inputType = llvm::dyn_cast(op.getInput().getType()); auto weightType = llvm::dyn_cast(op.getWeight().getType()); @@ -141,10 +140,8 @@ static LogicalResult verifyConvOp(T op) { return success(); } -template -static LogicalResult verifyPoolOp(T op) { - auto inputETy = - llvm::cast(op.getInput().getType()).getElementType(); +template static LogicalResult verifyPoolOp(T op) { + auto inputETy = llvm::cast(op.getInput().getType()).getElementType(); auto resultETy = llvm::cast(op.getType()).getElementType(); if (auto quantType = @@ -214,7 +211,6 @@ static LogicalResult verifyPoolOp(T op) { */ if ((outputHeight != (height / stride[0]) + 1)) { - return op.emitOpError("output height is not correct, should be ") << (height / stride[0]) + 1 << "."; } @@ -236,22 +232,6 @@ static LogicalResult verifyPoolOp(T op) { return op.emitOpError("input/output element types are incompatible."); } -// LogicalResult tosa::AddOp::verify() { - -// auto input1ShapedType = llvm::cast(getInput1().getType()); -// auto input2ShapedType = llvm::cast(getInput2().getType()); -// auto resultShapedType = llvm::cast(getType()); - -// if (input1ShapedType.hasStaticShape() && input2ShapedType.hasStaticShape() -// && -// resultShapedType.hasStaticShape()) { -// if (input1ShapedType.getRank() != input2ShapedType.getRank()) { -// return emitOpError("input tensors must be of equal rank."); -// } -// return success(); -// } -// return success(); -// } LogicalResult tosa::MaxPool2dOp::verify() { return verifyPoolOp(*this); } LogicalResult tosa::AvgPool2dOp::verify() { @@ -269,14 +249,18 @@ LogicalResult tosa::AvgPool2dOp::verify() { auto accType = getAccType(); if (llvm::isa(inputETy) && !accType.isInteger(32)) return emitOpError("accumulator type for integer tensor is not i32"); + auto result = verifyPoolOp(*this); if (result.succeeded()) { if ((inputETy.isF16()) && !(accType.isF16() || accType.isF32())) return emitOpError("accumulator type for f16 tensor is not f16/f32"); + if ((inputETy.isBF16()) && !(accType.isF32())) return emitOpError("accumulator type for bf16 tensor is not f32"); + if (inputETy.isF32() && !accType.isF32()) return emitOpError("accumulator type for f32 tensor is not f32"); + } return result; } @@ -309,8 +293,7 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, } } -/// Handles tosa.transpose_conv2d which has outpad and output shape -/// attributes. +/// Handles tosa.transpose_conv2d which has outpad and output shape attributes. static void buildTransConvOpWithQuantInfo( OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, @@ -347,9 +330,9 @@ static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result, } } -/// The tosa.matmul op is also intended to be generated where a -/// fully_connected op must be constructed where the weight is not a constant. -/// In this case, the fully_connected op must be expressed using matmul. +/// The tosa.matmul op is also intended to be generated where a fully_connected +/// op must be constructed where the weight is not a constant. In this case, +/// the fully_connected op must be expressed using matmul. /// TODO: Add link to the leglization document explaining this. 
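/// (Illustrative shapes for the scenario above, assuming the usual lowering:
/// a fully_connected of a [N, IC] input with a non-constant [OC, IC] weight
/// becomes a tosa.matmul of [1, N, IC] by [1, IC, OC], with reshape and
/// transpose ops supplying the rank-3 operands.)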
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, @@ -384,9 +367,9 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder, } } -/// Both the tosa.avg_pool2d and unary ops use the same -/// UnaruOpQuantizationAttr but avg_pool operator has its own builder as it -/// has additional parameters not part of the unary ops. +/// Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr +/// but avg_pool operator has its own builder as it has additional parameters +/// not part of the unary ops. static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, @@ -453,8 +436,8 @@ static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, for (int i = 0, e = operands.size(); i != e; ++i) { auto shape = operands.getShape(i); if (!shape.hasRank()) { - // TODO(jennik): Update function to have better case handling for - // invalid operands and for ranked tensors. + // TODO(jennik): Update function to have better case handling for invalid + // operands and for ranked tensors. return failure(); } outRank = std::max(outRank, shape.getRank()); @@ -709,8 +692,8 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents( return success(); } - // If the input rank is unknown we can info the output rank using the - // padding shape's first dim. + // If the input rank is unknown we can info the output rank using the padding + // shape's first dim. if (!inputShape.hasRank()) { if (paddingShape.isDynamicDim(0)) { inferredReturnShapes.push_back(ShapedTypeComponents()); @@ -876,17 +859,17 @@ mlir::LogicalResult tosa::ReshapeOp::verify() { if ((int64_t)getNewShape().size() != outputType.getRank()) { return emitOpError() << "rank of newShape (" << getNewShape().size() - << ") and output (" << outputType.getRank() - << ") must match"; + << ") and output (" + << outputType.getRank() + << ") must match"; } - for (int64_t dim = 0; dim < outputType.getRank(); ++dim) { - if (getNewShape()[dim] != -1 && - getNewShape()[dim] != outputType.getShape()[dim]) { - return emitOpError() - << "newShape attribute (" << getNewShape()[dim] - << ") does not match output type (" << outputType.getShape()[dim] - << ") in dimension " << dim; + for (int64_t dim=0; dim < outputType.getRank(); ++dim) { + if (getNewShape()[dim] != -1 && getNewShape()[dim] != outputType.getShape()[dim]) { + return emitOpError() << "newShape attribute (" << getNewShape()[dim] + << ") does not match output type (" + << outputType.getShape()[dim] + << ") in dimension " << dim; } } } @@ -900,16 +883,18 @@ mlir::LogicalResult tosa::SliceOp::verify() { if (inputType.getRank() != outputType.getRank()) { return emitOpError() << "rank of input (" << inputType.getRank() - << ") and output (" << outputType.getRank() - << ") must match"; + << ") and output (" + << outputType.getRank() + << ") must match"; } if ((int64_t)getSize().size() != outputType.getRank()) { return emitOpError() << "rank of size (" << getSize().size() - << ") and output (" << outputType.getRank() - << ") must match"; + << ") and output (" + << outputType.getRank() + << ") must match"; } - for (int64_t dim = 0; dim < outputType.getRank(); ++dim) { + for (int64_t dim=0; dim < outputType.getRank(); ++dim) { if (getSize()[dim] != -1 && !outputType.isDynamicDim(dim) && getSize()[dim] != outputType.getShape()[dim]) { return emitOpError() << "size attribute (" << getSize()[dim] @@ -974,8 +959,8 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents( return success(); 
} - // This would imply the number of permutations does not match the rank of - // the input which is illegal. + // This would imply the number of permutations does not match the rank of the + // input which is illegal. if (permsShape.getDimSize(0) != inputShape.getRank()) { return failure(); } @@ -985,15 +970,13 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents( SmallVector outputShape; if (!inputShape.hasRank()) { outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamic); - inferredReturnShapes.push_back( - ShapedTypeComponents(outputShape, inputType)); + inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType)); return success(); } // Rank-0 means no permutations matter. if (inputShape.getRank() == 0) { - inferredReturnShapes.push_back( - ShapedTypeComponents(outputShape, inputType)); + inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType)); return success(); } @@ -1010,8 +993,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents( // permutation. if (allTheSame) { outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0)); - inferredReturnShapes.push_back( - ShapedTypeComponents(outputShape, inputType)); + inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType)); return success(); } @@ -1339,8 +1321,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents( return success(); } -template -static LogicalResult verifyBinaryOpWithEqualRank(T op) { +template static LogicalResult verifyBinaryOpWithEqualRank(T op) { auto input1ShapeType = llvm::cast(op.getInput1().getType()); auto input2ShapeType = llvm::cast(op.getInput2().getType()); @@ -1353,28 +1334,15 @@ static LogicalResult verifyBinaryOpWithEqualRank(T op) { } return success(); } -LogicalResult tosa::MulOp::verify() { - auto result = verifyBinaryOpWithEqualRank(*this); - if (result.failed()) { - return result; - } - auto shiftAttr = getShiftAttr().getInt(); - auto input1ShapeType = llvm::cast(getInput1().getType()); - auto elementType = getElementTypeOrSelf(input1ShapeType); - if (!(elementType.isInteger(8) || elementType.isInteger(16))) { - if (shiftAttr != 0) { - return emitOpError( - "shift attribute should be 0 for non integer input types"); - } - } - return success(); -} + LogicalResult tosa::AddOp::verify() { return verifyBinaryOpWithEqualRank(*this); } + LogicalResult tosa::GreaterEqualOp::verify() { return verifyBinaryOpWithEqualRank(*this); } + template LogicalResult verifyForSameRank(T op, ShapedType inputShape1, ShapedType inputShape2) { @@ -1387,6 +1355,7 @@ LogicalResult verifyForSameRank(T op, ShapedType inputShape1, } return success(); } + LogicalResult tosa::SelectOp::verify() { auto input1ShapeType = llvm::cast(getOperand(0).getType()); @@ -1398,8 +1367,7 @@ LogicalResult tosa::SelectOp::verify() { llvm::cast(getOperand(1).getType()).getElementType(); auto input3ETy = llvm::cast(getOperand(2).getType()).getElementType(); - auto resultETy = getElementTypeOrSelf(getResult()); - // auto resultETy = llvm::cast(getResult()).getElementType(); + auto resultETy = getElementTypeOrSelf(getResult()); auto result1 = verifyForSameRank(*this, input1ShapeType, input2ShapeType); if (result1.failed()) { diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index abd3e01464f63..87afad3eae66f 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -45,7 +45,6 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens return %0 : tensor } - // ----- 
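// Each `// -----` splits this file into separate verification units for
// mlir-opt's -split-input-file -verify-diagnostics mode; an
// `expected-error@+1 {{...}}` annotation then requires that exact diagnostic
// on the immediately following line, so each case below exercises a single
// verifier failure path.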
func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor { @@ -154,120 +153,143 @@ func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> { } // ----- + func.func @test_avg_pool2d_negative_kernel(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> { // expected-error@+1 {{'tosa.avg_pool2d' op kernel should be greater than one.}} %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> return %0 : tensor<1x7x7x9xi8> } + // ----- + func.func @test_avg_pool2d_negative_stride(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> { // expected-error@+1 {{'tosa.avg_pool2d' op stride should be greater than one.}} %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> return %0 : tensor<1x7x7x9xi8> } + // ----- + func.func @test_avg_pool2d_negative_pad(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> { // expected-error@+1 {{'tosa.avg_pool2d' op pad should be positive}} %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> return %0 : tensor<1x7x7x9xi8> } + // ----- + func.func @test_avg_pool2d_kernel_lessthan_pad(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> { // expected-error@+1 {{'tosa.avg_pool2d' op pad must be less than kernel size}} %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> return %0 : tensor<1x7x7x9xi8> } + // ----- -func.func @test_avg_pool2d_vert_stride_incorrect_mul(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> { + +func.func @test_avg_pool2d_vert_stride_incorrect_multiple(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> { // expected-error@+1 {{'tosa.avg_pool2d' op vertical stride is not in correct multiple.}} %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> return %0 : tensor<1x7x7x9xi8> } + // ----- -func.func @test_avg_pool2d_hor_stride_incorrect_mul(%arg0: tensor<1x6x6x9xi8>) -> tensor<1x7x4x9xi8> { + +func.func @test_avg_pool2d_hor_stride_incorrect_multiple(%arg0: tensor<1x6x6x9xi8>) -> tensor<1x7x4x9xi8> { // expected-error@+1 {{'tosa.avg_pool2d' op horizontal stride is not in correct multiple.}} %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x6x6x9xi8>) -> tensor<1x7x4x9xi8> return %0 : tensor<1x7x4x9xi8> } + // ----- -func.func @test_max_pool2d_hor_stride_incorrect_mul(%arg0: tensor<1x6x6x9xi8>) -> tensor<1x7x4x9xi8> { + +func.func @test_max_pool2d_hor_stride_incorrect_multiple(%arg0: tensor<1x6x6x9xi8>) -> tensor<1x7x4x9xi8> { // expected-error@+1 {{'tosa.max_pool2d' op horizontal stride is not in correct multiple.}} %0 = "tosa.max_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x6x6x9xi8>) -> tensor<1x7x4x9xi8> return %0 : tensor<1x7x4x9xi8> } + // ----- + func.func @test_avg_pool2d_output_height_incorrect(%arg0: tensor<1x6x6x9xi8>) -> tensor<1x7x8x9xi8> { // expected-error@+1 {{'tosa.avg_pool2d' op output height is not correct, should be 3.}} %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x6x6x9xi8>) -> tensor<1x7x8x9xi8> return %0 : tensor<1x7x8x9xi8> } + // ----- + func.func @test_avg_pool2d_output_width_incorrect(%arg0: tensor<1x6x6x9xi8>) -> 
tensor<1x3x8x9xi8> { // expected-error@+1 {{'tosa.avg_pool2d' op output width is not correct, should be 3.}} %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x6x6x9xi8>) -> tensor<1x3x8x9xi8> return %0 : tensor<1x3x8x9xi8> } + // ----- + func.func @test_max_pool2d_output_width_incorrect(%arg0: tensor<1x6x6x9xi8>) -> tensor<1x3x8x9xi8> { // expected-error@+1 {{'tosa.max_pool2d' op output width is not correct, should be 3.}} %0 = "tosa.max_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x6x6x9xi8>) -> tensor<1x3x8x9xi8> return %0 : tensor<1x3x8x9xi8> } + // ----- + func.func @test_add_incompabitble_type(%arg0: tensor<13x21xf32>, %arg1: tensor<13x21xi8>) -> tensor<13x21xf32> { // expected-error@+1 {{'tosa.add' op requires the same element type for all operands and results}} %0 = "tosa.add"(%arg0, %arg1) : (tensor<13x21xf32>, tensor<13x21xi8>) -> tensor<13x21xf32> return %0 : tensor<13x21xf32> } + // ----- + func.func @test_add_incorrect_output(%arg0: tensor<13x21xf32>, %arg1: tensor<13x21xf32>) -> tensor<13x2xf32> { // expected-error@+1 {{'tosa.add' op result type '13x2' not broadcast compatible with broadcasted operands's shapes '13x21'}} %0 = "tosa.add"(%arg0, %arg1) : (tensor<13x21xf32>, tensor<13x21xf32>) -> tensor<13x2xf32> return %0 : tensor<13x2xf32> } + // ----- + func.func @test_add_incorrect_output2(%arg0: tensor<13x21xf32>, %arg1: tensor<2x13x21xf32>) -> tensor<2x13x21xf32> { // expected-error@+1 {{'tosa.add' op both operands must have same rank.}} %0 = "tosa.add"(%arg0, %arg1) : (tensor<13x21xf32>, tensor<2x13x21xf32>) -> tensor<2x13x21xf32> return %0 : tensor<2x13x21xf32> } + // ----- + func.func @test_const_incorrect_output(%arg0 : index) -> tensor<4xi32> { // expected-error@+1{{inferred shape of elements literal ([4]) does not match type ([3])}} %0 = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<3xi32>} : () -> tensor<4xi32> return %0 : tensor<4xi32> } + // ----- + func.func @test_greater_equal_incompatible(%arg0: tensor<13x1x3x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> { // expected-error@+1{{'tosa.greater_equal' op operands don't have broadcast-compatible shapes}} %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<13x1x3x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> return %0 : tensor<13x21x3xi1> } + // ----- + func.func @test_greater_equal_unequal_rank(%arg0: tensor<12x13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor { // expected-error@+1{{'tosa.greater_equal' op both operands must have same rank.}} %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<12x13x21x3xf32>, tensor<13x21x3xf32>) -> tensor return %0 : tensor } + // ----- + func.func @test_greater_equal(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // expected-error@+1{{'tosa.greater_equal' op result #0 must be tensor of 1-bit signless integer values, but got 'tensor<13x21x3xf32>'}} %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<13x1x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } + // ----- -func.func @test_mul_incompabitble_type(%arg0: tensor<13x21xf32>, %arg1: tensor<13x21xi8>) -> tensor<13x21xf32> { - // expected-error@+1 {{'tosa.mul' op requires the same element type for all operands and results}} - %0 = "tosa.mul"(%arg0, %arg1) { shift = 1 : i32 }: (tensor<13x21xf32>, tensor<13x21xi8>) -> tensor<13x21xf32> - return %0 : tensor<13x21xf32> -} -// ----- -func.func @test_mul_unequal_rank(%arg0: tensor<13x21x3xf32>, %arg1: 
tensor<13x1x21x3xf32>) -> tensor { - // expected-error@+1{{'tosa.mul' op both operands must have same rank.}} - %0 = "tosa.mul"(%arg0, %arg1) { shift = 1 : i32 } : (tensor<13x21x3xf32>, tensor<13x1x21x3xf32>) -> tensor - return %0 : tensor -} -// ----- + func.func @test_add_unequal_rank(%arg0: tensor<3x4x8400xf32>, %arg1: tensor<8400xf32>) -> tensor<3x4x8400xf32> { // expected-error@+1{{'tosa.add' op both operands must have same rank.}} %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4x8400xf32>, tensor<8400xf32>) -> tensor<3x4x8400xf32> @@ -275,87 +297,89 @@ func.func @test_add_unequal_rank(%arg0: tensor<3x4x8400xf32>, %arg1: tensor<8400 } // ----- -func.func @test_mul_incompatible(%arg0: tensor<3x4x8400xf32>, %arg1: tensor<3x8400xf32>) -> tensor<1x4x8400xf32> { - // expected-error@+1{{'tosa.mul' op operands don't have broadcast-compatible shapes}} - %0 = "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<3x4x8400xf32>, tensor<3x8400xf32>) -> tensor<1x4x8400xf32> - return %0 : tensor<1x4x8400xf32> -} -// ----- -func.func @test_mul_need_shift(%arg0: tensor<3x4x8400xf32>, %arg1: tensor<3x4x8400xf32>) -> tensor<3x4x8400xf32> { - // expected-error@+1{{'tosa.mul' op requires attribute 'shift'}} - %0 = "tosa.mul"(%arg0, %arg1) : (tensor<3x4x8400xf32>, tensor<3x4x8400xf32>) -> tensor<3x4x8400xf32> - return %0 : tensor<3x4x8400xf32> -} -// ----- -func.func @test_mul_nonzero_shift(%arg0: tensor<3x4x8400xf32>, %arg1: tensor<3x4x8400xf32>) -> tensor<3x4x8400xf32> { - // expected-error@+1{{'tosa.mul' op shift attribute should be 0 for non integer input types}} - %0 = "tosa.mul"(%arg0, %arg1) {shift = 3 : i32}: (tensor<3x4x8400xf32>, tensor<3x4x8400xf32>) -> tensor<3x4x8400xf32> - return %0 : tensor<3x4x8400xf32> -} - // ----- func.func @test_select_unequal_rank_inputs(%arg0: tensor<2xi1>, %arg1: tensor<3x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<3x2xf32> { // expected-error@+1{{'tosa.select' op both operands must have same rank.}} %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> return %0 : tensor<3x2xf32> } + // ----- + func.func @test_select_unequal_rank_inputs2(%arg0: tensor<1x2xi1>, %arg1: tensor<1x3x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<3x2xf32> { // expected-error@+1{{'tosa.select' op both operands must have same rank.}} %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x2xi1>, tensor<1x3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> return %0 : tensor<3x2xf32> } + // ----- + func.func @test_select_unequal_rank_inputs3(%arg0: tensor<1x2xi1>, %arg1: tensor<3x2xf32>, %arg2: tensor<1x3x2xf32>) -> tensor<3x2xf32> { // expected-error@+1{{'tosa.select' op both operands must have same rank.}} %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x2xi1>, tensor<3x2xf32>, tensor<1x3x2xf32>) -> tensor<3x2xf32> return %0 : tensor<3x2xf32> -} +} + // ----- + func.func @test_select_not_boardcastable_arg1(%arg0: tensor<2x2xi1>, %arg1: tensor<3x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<3x2xf32> { // expected-error@+1{{'tosa.select' op operands don't have broadcast-compatible shapes}} %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<2x2xi1>, tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> return %0 : tensor<3x2xf32> } + // ----- + func.func @test_select_not_boardcastable_result(%arg0: tensor<1x2xi1>, %arg1: tensor<3x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<4x2xf32> { // expected-error@+1{{'tosa.select' op result type '4x2' not broadcast compatible with broadcasted operands's shapes '3x2'}} %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x2xi1>, 
tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<4x2xf32> return %0 : tensor<4x2xf32> } + // ----- + func.func @test_select_not_boardcastable_arg3(%arg0: tensor<1x2xi1>, %arg1: tensor<2x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<3x2xf32> { // expected-error@+1{{'tosa.select' op operands don't have broadcast-compatible shapes}} %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x2xi1>, tensor<2x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> return %0 : tensor<3x2xf32> } + // ----- + func.func @test_select_incompatible_1(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xi8> { // expected-error@+1{{'tosa.select' op inputs and result should be of same type.}} %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi8> return %0 : tensor<13x21x3xi8> } + // ----- + func.func @test_select_incompatible_2(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xi8>) -> tensor<13x21x3xf32> { // expected-error@+1{{'tosa.select' op inputs should be of same type.}} %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xi8>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } + // ----- + func.func @test_select_incompatible_3(%arg0: tensor<1x1x1xi8>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xi8>) -> tensor<13x21x3xf32> { // expected-error@+1{{'tosa.select' op operand #0 must be tensor of 1-bit signless integer values, but got 'tensor<1x1x1xi8>'}} %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x1xi8>, tensor<13x21x3xf32>, tensor<13x21x3xi8>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } + // ----- + func.func @test_transpose_incorrect_result_shape(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x20xf32> { %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> // expected-error@+2{{'tosa.transpose' op failed to infer returned types}} // expected-error@+1{{'tosa.transpose' op inferred type(s) 'tensor<3x13x21xf32>' are incompatible with return type(s) of operation 'tensor<3x13x20xf32>'}} %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x20xf32> return %1 : tensor<3x13x20xf32> -} +} + // ----- + func.func @test_transpose_incorrect_result_rank(%arg0: tensor<13x21x3xf32>) -> tensor<3x13xf32> { %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> // expected-error@+2{{'tosa.transpose' op failed to infer returned types}} @@ -363,7 +387,9 @@ func.func @test_transpose_incorrect_result_rank(%arg0: tensor<13x21x3xf32>) -> t %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13xf32> return %1 : tensor<3x13xf32> } + // ----- + func.func @test_transpose_incorrect_result_type(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xi8> { %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> // expected-error@+2 {{failed to infer returned types}} @@ -371,27 +397,30 @@ func.func @test_transpose_incorrect_result_type(%arg0: tensor<13x21x3xf32>) -> t %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xi8> return %1 : tensor<3x13x21xi8> } + // ----- + func.func @test_transpose_high_rank_perm(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21x4xf32> { %0 = "tosa.const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi32>} : () -> tensor<4xi32> // expected-error@+1 {{failed to infer returned types}} %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, 
tensor<4xi32>) -> tensor<3x13x21x4xf32> return %1 : tensor<3x13x21x4xf32> } + // ----- -// CHECK-LABEL: transpose + func.func @test_transpose_low_rank_perm(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21x4xf32> { %0 = "tosa.const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32> // expected-error@+1 {{failed to infer returned types}} %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<2xi32>) -> tensor<3x13x21x4xf32> return %1 : tensor<3x13x21x4xf32> } + // ----- -// CHECK-LABEL: transpose func.func @test_transpose_result_high_rank(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21x4xf32> { %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> // expected-error@+2 {{failed to infer returned types}} // expected-error@+1 {{'tosa.transpose' op inferred type(s) 'tensor<3x13x21xf32>' are incompatible with return type(s) of operation 'tensor<3x13x21x4xf32>'}} %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21x4xf32> return %1 : tensor<3x13x21x4xf32> -} \ No newline at end of file +} diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 95f169f4923f2..25c1ea2fed400 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -215,15 +215,10 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te // ----- // CHECK-LABEL: mul func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { - %0 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 } : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32> + %0 = "tosa.mul"(%arg0, %arg1) { shift = 1 : i32 } : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } -// ----- -// CHECK-LABEL: mul -func.func @test_mul_nonzero_shift(%arg0: tensor<3x4x8400xi8>, %arg1: tensor<3x4x8400xi8>) -> tensor<3x4x8400xi8> { - %0 = "tosa.mul"(%arg0, %arg1) {shift = 3 : i32}: (tensor<3x4x8400xi8>, tensor<3x4x8400xi8>) -> tensor<3x4x8400xi8> - return %0 : tensor<3x4x8400xi8> -} + // ----- // CHECK-LABEL: pow func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> { @@ -328,9 +323,10 @@ func.func @test_select(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } + // ----- // CHECK-LABEL: select -func.func @test_select_boardcastable(%arg0: tensor<1x2xi1>, %arg1: tensor<3x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<3x2xf32> { +func.func @test_select_broardcastable(%arg0: tensor<1x2xi1>, %arg1: tensor<3x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<3x2xf32> { %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x2xi1>, tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> return %0 : tensor<3x2xf32> } From 0a01023fcae791e90134e124b86afa1b94e801a2 Mon Sep 17 00:00:00 2001 From: chaitany Date: Fri, 12 Jan 2024 20:40:26 +0530 Subject: [PATCH 03/12] refactor: formatting changes --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 4 ++-- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 15 +++++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 064b9503ac410..a6b85c0ba2c29 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -499,7 +499,7 @@ def Tosa_AddOp : 
Tosa_ElemWiseBinaryOp<"add", [Commutative]> { Tosa_Tensor:$output ); - let hasFolder = 1; + let hasFolder = 1; let hasVerifier = 1; } @@ -781,7 +781,7 @@ def Tosa_MulOp : Tosa_ElemWiseBinaryOp<"mul", [Commutative]> { ); let results = (outs - Tosa_Tensor:$output + Tosa_Tensor:$output ); let hasFolder = 1; diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 5076cc71fe181..b7b3de3958754 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -905,16 +905,19 @@ mlir::LogicalResult tosa::SliceOp::verify() { } if ((int64_t)getStart().size() != inputType.getRank()) { - return emitOpError() << "rank of start (" << getStart().size() - << ") and input (" << inputType.getRank() - << ") must match"; + return emitOpError() << "rank of start (" << getStart().size() + << ") and input (" + << inputType.getRank() + << ") must match"; } if ((int64_t)getSize().size() != inputType.getRank()) { - return emitOpError() << "rank of size (" << getSize().size() - << ") and input (" << inputType.getRank() - << ") must match"; + return emitOpError() << "rank of size (" << getSize().size() + << ") and input (" + << inputType.getRank() + << ") must match"; } + for (int i = 0; i < outputType.getRank(); ++i) { auto dimSize = inputType.getShape()[i]; if (getSize()[i] != -1 && dimSize != ShapedType::kDynamic && From 292eeaa191a553584e38e01e9d6eaca874b537a2 Mon Sep 17 00:00:00 2001 From: chaitany Date: Fri, 12 Jan 2024 20:54:18 +0530 Subject: [PATCH 04/12] refactor: formatting changes --- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index b7b3de3958754..2ebbfde9e8845 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -858,10 +858,10 @@ mlir::LogicalResult tosa::ReshapeOp::verify() { } if ((int64_t)getNewShape().size() != outputType.getRank()) { - return emitOpError() << "rank of newShape (" << getNewShape().size() - << ") and output (" - << outputType.getRank() - << ") must match"; + return emitOpError() << "rank of newShape (" << getNewShape().size() + << ") and output (" + << outputType.getRank() + << ") must match"; } for (int64_t dim=0; dim < outputType.getRank(); ++dim) { @@ -889,19 +889,19 @@ mlir::LogicalResult tosa::SliceOp::verify() { } if ((int64_t)getSize().size() != outputType.getRank()) { - return emitOpError() << "rank of size (" << getSize().size() - << ") and output (" - << outputType.getRank() - << ") must match"; + return emitOpError() << "rank of size (" << getSize().size() + << ") and output (" + << outputType.getRank() + << ") must match"; } for (int64_t dim=0; dim < outputType.getRank(); ++dim) { - if (getSize()[dim] != -1 && !outputType.isDynamicDim(dim) && - getSize()[dim] != outputType.getShape()[dim]) { + if (getSize()[dim] != -1 && !outputType.isDynamicDim(dim) && + getSize()[dim] != outputType.getShape()[dim]) { return emitOpError() << "size attribute (" << getSize()[dim] << ") does not match output type (" << outputType.getShape()[dim] << ") in dimension " << dim; - } + } } if ((int64_t)getStart().size() != inputType.getRank()) { @@ -917,7 +917,6 @@ mlir::LogicalResult tosa::SliceOp::verify() { << ") must match"; } - for (int i = 0; i < outputType.getRank(); ++i) { auto dimSize = inputType.getShape()[i]; if (getSize()[i] != -1 && dimSize != ShapedType::kDynamic && From 9d9be06fc27792f5bbe9ffbd280189c6c70d45d3 Mon 
Sep 17 00:00:00 2001
From: chaitany
Date: Tue, 16 Jan 2024 10:48:58 +0530
Subject: [PATCH 05/12] refactor: i32 tensor is not supported as input for
 tosa.max_pool2d

---
 mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 901dcf78380bf..ec1666e6192bd 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -194,10 +194,10 @@ func.func @max_pool_i16(%arg0: tensor<1x6x34x62xi16>) -> () {
 }
 
 // CHECK-LABEL: @max_pool_i32
-func.func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
+func.func @max_pool_i32(%arg0: tensor<1x6x34x62xi16>) -> () {
   // CHECK: arith.constant -2147483648
   // CHECK: linalg.pooling_nhwc_max
-  %0 = "tosa.max_pool2d"(%arg0) {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi32>) -> (tensor<1x4x32x62xi32>)
+  %0 = "tosa.max_pool2d"(%arg0) {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi16>) -> (tensor<1x4x32x62xi16>)
   return
 }
 
From 78b73eb3efb88c5bfd16bea9ed80cbbd0dc3dfaf Mon Sep 17 00:00:00 2001
From: chaitany
Date: Tue, 16 Jan 2024 12:30:34 +0530
Subject: [PATCH 06/12] refactor: removing the i32 version of tosa.max_pool2d
 test

---
 .../Conversion/TosaToLinalg/tosa-to-linalg-named.mlir | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index ec1666e6192bd..a255e795121b4 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -193,14 +193,6 @@ func.func @max_pool_i16(%arg0: tensor<1x6x34x62xi16>) -> () {
   return
 }
 
-// CHECK-LABEL: @max_pool_i32
-func.func @max_pool_i32(%arg0: tensor<1x6x34x62xi16>) -> () {
-  // CHECK: arith.constant -2147483648
-  // CHECK: linalg.pooling_nhwc_max
-  %0 = "tosa.max_pool2d"(%arg0) {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi16>) -> (tensor<1x4x32x62xi16>)
-  return
-}
-
 // -----
 
 // CHECK-LABEL: @avg_pool_f32

From 0ccab4afa0acc1871874ba6e9dbf2d99f0a81cef Mon Sep 17 00:00:00 2001
From: chaitany
Date: Wed, 17 Jan 2024 21:13:47 +0530
Subject: [PATCH 07/12] refactor: fix the failing tests

---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          |  42 +++++-
 .../TosaToLinalg/tosa-to-linalg-named.mlir    |   8 ++
 .../TosaToLinalg/tosa-to-linalg.mlir          |   6 +-
 mlir/test/Dialect/Tosa/inlining.mlir          |  14 +-
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 120 +++++++++---------
 5 files changed, 117 insertions(+), 73 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 2ebbfde9e8845..72975f4aefdb2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -229,6 +229,8 @@ template <typename T> static LogicalResult verifyPoolOp(T op) {
     return success();
   if (inputETy.isInteger(16) && resultETy.isInteger(16))
     return success();
+  if (inputETy.isInteger(32) && resultETy.isInteger(32))
+    return success();
 
   return op.emitOpError("input/output element types are incompatible.");
 }
@@ -260,7 +262,6 @@ LogicalResult tosa::AvgPool2dOp::verify() {
 
     if (inputETy.isF32() && !accType.isF32())
       return emitOpError("accumulator type for f32 tensor is not f32");
-  }
 
   return result;
 }
@@ -803,6 +804,33 @@ bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
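   // ReshapeOp return types are considered compatible when their element
   // types match; the shapes themselves are not compared here.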
   return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
 }
 
+bool tosa::TransposeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+
+  if (l.size() != r.size() || l.size() != 1)
+    return false;
+
+  auto left = getElementTypeOrSelf(l[0]);
+  auto right = getElementTypeOrSelf(r[0]);
+
+  if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(left))
+    left = quantType.getStorageType();
+
+  if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedPerAxisType>(left))
+    left = quantType.getStorageType();
+
+  if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(right)) {
+    right = quantType.getStorageType();
+  }
+
+  if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedPerAxisType>(right)) {
+    right = quantType.getStorageType();
+  }
+
+  if (left != right)
+    return false;
+  return succeeded(verifyCompatibleShape(l[0], r[0]));
+}
+
 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes,
@@ -953,6 +981,15 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   ShapeAdaptor permsShape = operands.getShape(1);
   auto inputType = getElementTypeOrSelf(operands[0]);
 
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputType))
+    inputType = quantType.getStorageType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedPerAxisType>(inputType))
+    inputType = quantType.getStorageType();
+
+
   // If input rank and permutation length is unknown, the output rank is
   // unknown.
   if (!inputShape.hasRank() || !permsShape.hasRank() ||
@@ -1161,7 +1198,6 @@ REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
 #undef REDUCE_SHAPE_INFER
 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
-COMPATIBLE_RETURN_TYPES(tosa::TransposeOp)
 #undef COMPATIBLE_RETURN_TYPES
 
 static LogicalResult NAryInferReturnTypes(
@@ -1369,7 +1405,7 @@ LogicalResult tosa::SelectOp::verify() {
       llvm::cast<ShapedType>(getOperand(1).getType()).getElementType();
   auto input3ETy =
       llvm::cast<ShapedType>(getOperand(2).getType()).getElementType();
-  auto resultETy = getElementTypeOrSelf(getResult()); 
+  auto resultETy = getElementTypeOrSelf(getResult());
 
   auto result1 = verifyForSameRank(*this, input1ShapeType, input2ShapeType);
   if (result1.failed()) {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index a255e795121b4..901dcf78380bf 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -193,6 +193,14 @@ func.func @max_pool_i16(%arg0: tensor<1x6x34x62xi16>) -> () {
   return
 }
 
+// CHECK-LABEL: @max_pool_i32
+func.func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
+  // CHECK: arith.constant -2147483648
+  // CHECK: linalg.pooling_nhwc_max
+  %0 = "tosa.max_pool2d"(%arg0) {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi32>) -> (tensor<1x4x32x62xi32>)
+  return
+}
+
 // -----
 
 // CHECK-LABEL: @avg_pool_f32
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index b320f35aab87b..7a83fbd6436d7 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1415,11 +1415,11 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
 
 // CHECK-DAG: affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
 // CHECK-DAG: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: affine_map<(d0, d1, d2, d3) -> ()>
+// CHECK-DAG: affine_map<(d0, d1, d2, d3) -> (d0)>
 // CHECK-LABEL: @select_fp32
-func.func @select_fp32(%arg0: tensor<1x1x5x5xi1>, %arg1: 
tensor<1x12x5x5xf32>, %arg2: tensor) -> tensor<1x12x5x5xf32> { +func.func @select_fp32(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, %arg2: tensor<1x1x1x1xf32>) -> tensor<1x12x5x5xf32> { // CHECK: linalg.generic - %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor) -> tensor<1x12x5x5xf32> + %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<1x1x1x1xf32>) -> tensor<1x12x5x5xf32> return %0 : tensor<1x12x5x5xf32> } diff --git a/mlir/test/Dialect/Tosa/inlining.mlir b/mlir/test/Dialect/Tosa/inlining.mlir index d57b5cbcf475c..4552f4e7180db 100644 --- a/mlir/test/Dialect/Tosa/inlining.mlir +++ b/mlir/test/Dialect/Tosa/inlining.mlir @@ -39,16 +39,16 @@ func.func @inlined_while_fn(%arg0: tensor, %arg1: tensor, %arg2: tenso %2 = call @while_cond_40(%arg4, %arg5, %arg6, %arg7) : (tensor, tensor, tensor, tensor<10xi32>) -> tensor "tosa.yield"(%2) : (tensor) -> () }, { - ^bb0(%arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor<10xi32>): - %2:4 = call @while_body_50(%arg4, %arg5, %arg6, %arg7) : (tensor, tensor, tensor, tensor<10xi32>) -> (tensor, tensor, tensor, tensor<10xi32>) - "tosa.yield"(%2#0, %2#1, %2#2, %2#3) : (tensor, tensor, tensor, tensor<10xi32>) -> () + ^bb0(%arg4: tensor<1xi32>, %arg5: tensor<1xi32>, %arg6: tensor<1xi32>, %arg7: tensor<10xi32>): + %2:4 = call @while_body_50(%arg4, %arg5, %arg6, %arg7) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<10xi32>) -> (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<10xi32>) + "tosa.yield"(%2#0, %2#1, %2#2, %2#3) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<10xi32>) -> () }) : (tensor, tensor, tensor, tensor<10xi32>) -> (tensor, tensor, tensor, tensor<10xi32>) return %1#3 : tensor<10xi32> } -func.func private @while_body_50(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<10xi32>) -> (tensor, tensor, tensor, tensor<10xi32>) { - %1 = "tosa.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - %2 = "tosa.add"(%arg3, %1) : (tensor<10xi32>, tensor) -> tensor<10xi32> - return %1, %arg1, %arg2, %2: tensor, tensor, tensor, tensor<10xi32> +func.func private @while_body_50(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<10xi32>) -> (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<10xi32>) { + %1 = "tosa.add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %2 = "tosa.add"(%arg3, %1) : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32> + return %1, %arg1, %arg2, %2: tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<10xi32> } func.func private @while_cond_40(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<10xi32>) -> tensor { %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor, tensor) -> tensor diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index bf913363039d7..2c650062b70a4 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -8,7 +8,7 @@ func.func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> { return %0 : tensor<*xf32> } -// ----- +// ----- // CHECK-LABEL: @test_multiple func.func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor) -> tensor<*xf32> { @@ -104,33 +104,33 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>) -> () { // ----- // CHECK-LABEL: @test_binary_scalar_f32 -func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor) -> () { - // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor) 
-> tensor<4xf32> - %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> +func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>) -> () { + // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> - // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> - %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> + // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> - // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> - %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> + // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> - // CHECK: "tosa.mul"(%arg0, %arg1) <{shift = 0 : i32}> : (tensor<4xf32>, tensor) -> tensor<4xf32> - %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xf32>, tensor) -> tensor<*xf32> + // CHECK: "tosa.mul"(%arg0, %arg1) <{shift = 0 : i32}> : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> - // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> - %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> + // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> - // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> - %5 = "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> + // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %5 = "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> - // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xi1> - %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xi1> + // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> + %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> - // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xi1> - %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xi1> + // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> + %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> - // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xi1> - %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xi1> + // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> + %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> return } @@ -172,48 +172,48 @@ func.func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32 // ----- // CHECK-LABEL: @test_binary_i32 -func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor) -> () { - // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> - %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> +func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : 
tensor<1xi32>) -> () { + // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: "tosa.bitwise_and"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> - %1 = "tosa.bitwise_and"(%arg0, %arg1): (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: "tosa.bitwise_and"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %1 = "tosa.bitwise_and"(%arg0, %arg1): (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: "tosa.bitwise_or"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> - %2 = "tosa.bitwise_or"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: "tosa.bitwise_or"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %2 = "tosa.bitwise_or"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: "tosa.bitwise_xor"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> - %3 = "tosa.bitwise_xor"(%arg0, %arg1): (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: "tosa.bitwise_xor"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %3 = "tosa.bitwise_xor"(%arg0, %arg1): (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi1> - %4 = "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi1> + // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi1> + %4 = "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1> - // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi1> - %5 = "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi1> + // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi1> + %5 = "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1> - // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi1> - %6 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi1> + // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi1> + %6 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1> - // CHECK: "tosa.logical_left_shift"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> - %7 = "tosa.logical_left_shift"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: "tosa.logical_left_shift"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %7 = "tosa.logical_left_shift"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: "tosa.logical_right_shift"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> - %8 = "tosa.logical_right_shift"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: "tosa.logical_right_shift"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %8 = "tosa.logical_right_shift"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> - %9 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %9 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: 
"tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> - %10 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %10 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: "tosa.mul"(%arg0, %arg1) <{shift = 0 : i32}> : (tensor<4xi32>, tensor) -> tensor<4xi32> - %11 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: "tosa.mul"(%arg0, %arg1) <{shift = 0 : i32}> : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %11 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> - %12 = "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %12 = "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> - %13 = "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %13 = "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> return } @@ -237,9 +237,9 @@ func.func @test_binary_i1(%arg0 : tensor<4xi1>, %arg1 : tensor) -> () { // ----- // CHECK-LABEL: @test_select_i32 -func.func @test_select_i32(%arg0 : tensor<4xi1>, %arg1 : tensor, %arg2 : tensor<4xi32>) -> () { - // CHECK: "tosa.select"(%arg0, %arg1, %arg2) : (tensor<4xi1>, tensor, tensor<4xi32>) -> tensor<4xi32> - %0 = "tosa.select"(%arg0, %arg1, %arg2): (tensor<4xi1>, tensor, tensor<4xi32>) -> tensor<*xi32> +func.func @test_select_i32(%arg0 : tensor<4xi1>, %arg1 : tensor<1xi32>, %arg2 : tensor<4xi32>) -> () { + // CHECK: "tosa.select"(%arg0, %arg1, %arg2) : (tensor<4xi1>, tensor<1xi32>, tensor<4xi32>) -> tensor<4xi32> + %0 = "tosa.select"(%arg0, %arg1, %arg2): (tensor<4xi1>, tensor<1xi32>, tensor<4xi32>) -> tensor<*xi32> return } @@ -703,11 +703,11 @@ func.func @test_pool_dynamic_input(%arg0: tensor) { // CHECK-LABEL: @test_pool_padded func.func @test_pool_padded(%arg0: tensor<3x5x6x7xf32>) { - // CHECK: -> tensor<3x5x11x7xf32> - %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor + // CHECK: -> tensor<3x5x8x7xf32> + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor - // CHECK: -> tensor<3x5x11x7xf32> - %1 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor + // CHECK: -> tensor<3x5x8x7xf32> + %1 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor return } @@ -733,11 +733,11 @@ func.func @conv2d_dynamic_bias(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3 // CHECK-LABEL: @test_pool_stride func.func @test_pool_stride(%arg0: tensor<3x11x12x7xf32>) { - // CHECK: -> tensor<3x4x4x7xf32> - %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<3x11x12x7xf32>) -> tensor + // CHECK: -> tensor<3x5x4x7xf32> + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<3x11x12x7xf32>) -> tensor - // CHECK: -> tensor<3x4x4x7xf32> - %1 = "tosa.max_pool2d"(%arg0) 
{kernel = array, pad = array, stride = array} : (tensor<3x11x12x7xf32>) -> tensor + // CHECK: -> tensor<3x5x4x7xf32> + %1 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<3x11x12x7xf32>) -> tensor return } From 351c39aec26fb68901042a1f5052014cd4ebe2a6 Mon Sep 17 00:00:00 2001 From: chaitany Date: Tue, 23 Jan 2024 15:19:05 +0530 Subject: [PATCH 08/12] refactor: removing add,select and greater_equal verifiers --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 9 +-- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 74 -------------------- 2 files changed, 3 insertions(+), 80 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index a6b85c0ba2c29..7808ea7da5fc7 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -499,8 +499,7 @@ def Tosa_AddOp : Tosa_ElemWiseBinaryOp<"add", [Commutative]> { Tosa_Tensor:$output ); - let hasFolder = 1; - let hasVerifier = 1; + let hasFolder = 1; } //===----------------------------------------------------------------------===// @@ -1126,8 +1125,7 @@ def Tosa_SelectOp : Tosa_Op<"select", [ Tosa_Tensor:$output ); let hasCanonicalizeMethod = 1; - let hasFolder = 1; - let hasVerifier = 1; + let hasFolder = 1; } //===----------------------------------------------------------------------===// @@ -1211,8 +1209,7 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [ I1Tensor:$output ); - let hasFolder = 1; - let hasVerifier = 1; + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 72975f4aefdb2..040eac24c85d0 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -1359,80 +1359,6 @@ LogicalResult Conv2DOp::inferReturnTypeComponents( return success(); } -template static LogicalResult verifyBinaryOpWithEqualRank(T op) { - auto input1ShapeType = llvm::cast(op.getInput1().getType()); - auto input2ShapeType = llvm::cast(op.getInput2().getType()); - - if (input1ShapeType.hasRank() && input2ShapeType.hasRank()) { - auto input1Rank = input1ShapeType.getRank(); - auto input2Rank = input2ShapeType.getRank(); - if (input1Rank != input2Rank) { - return op.emitOpError("both operands must have same rank."); - } - } - return success(); -} - -LogicalResult tosa::AddOp::verify() { - return verifyBinaryOpWithEqualRank(*this); -} - -LogicalResult tosa::GreaterEqualOp::verify() { - return verifyBinaryOpWithEqualRank(*this); -} - -template -LogicalResult verifyForSameRank(T op, ShapedType inputShape1, - ShapedType inputShape2) { - if (inputShape1.hasRank() && inputShape2.hasRank()) { - auto input1Rank = inputShape1.getRank(); - auto input2Rank = inputShape2.getRank(); - if (input1Rank != input2Rank) { - return op.emitOpError("both operands must have same rank."); - } - } - return success(); -} - -LogicalResult tosa::SelectOp::verify() { - - auto input1ShapeType = llvm::cast(getOperand(0).getType()); - auto input2ShapeType = llvm::cast(getOperand(1).getType()); - auto input3ShapeType = llvm::cast(getOperand(2).getType()); - auto outputShapeType = llvm::cast(getResult().getType()); - - auto input2ETy = - llvm::cast(getOperand(1).getType()).getElementType(); - auto input3ETy = - llvm::cast(getOperand(2).getType()).getElementType(); - auto resultETy = getElementTypeOrSelf(getResult()); - - auto result1 = verifyForSameRank(*this, input1ShapeType, input2ShapeType); 
- if (result1.failed()) { - return result1; - } - auto result2 = verifyForSameRank(*this, input1ShapeType, input3ShapeType); - if (result2.failed()) { - return result2; - } - auto result3 = verifyForSameRank(*this, input1ShapeType, outputShapeType); - if (result3.failed()) { - return result3; - } - if (input2ETy != input3ETy) { - return emitOpError("inputs should be of same type."); - } - if ((input2ETy != resultETy) || (input3ETy != resultETy)) { - return emitOpError("inputs and result should be of same type."); - } - - auto result = OpTrait::impl::verifyCompatibleOperandBroadcast(getOperation()); - if (result.failed()) { - return result; - } - return success(); -} - LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); } LogicalResult Conv3DOp::inferReturnTypeComponents( From 56666789ee25222f591d0591ed41b63fdaeab81f Mon Sep 17 00:00:00 2001 From: chaitany Date: Tue, 23 Jan 2024 17:04:29 +0530 Subject: [PATCH 09/12] refactor: removing the testcases covering the removed verifiers --- mlir/test/Dialect/Tosa/invalid.mlir | 128 ---------------------------- 1 file changed, 128 deletions(-) diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 87afad3eae66f..4a6d7576b38a0 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -234,30 +234,6 @@ func.func @test_max_pool2d_output_width_incorrect(%arg0: tensor<1x6x6x9xi8>) -> // ----- -func.func @test_add_incompabitble_type(%arg0: tensor<13x21xf32>, %arg1: tensor<13x21xi8>) -> tensor<13x21xf32> { - // expected-error@+1 {{'tosa.add' op requires the same element type for all operands and results}} - %0 = "tosa.add"(%arg0, %arg1) : (tensor<13x21xf32>, tensor<13x21xi8>) -> tensor<13x21xf32> - return %0 : tensor<13x21xf32> -} - -// ----- - -func.func @test_add_incorrect_output(%arg0: tensor<13x21xf32>, %arg1: tensor<13x21xf32>) -> tensor<13x2xf32> { - // expected-error@+1 {{'tosa.add' op result type '13x2' not broadcast compatible with broadcasted operands's shapes '13x21'}} - %0 = "tosa.add"(%arg0, %arg1) : (tensor<13x21xf32>, tensor<13x21xf32>) -> tensor<13x2xf32> - return %0 : tensor<13x2xf32> -} - -// ----- - -func.func @test_add_incorrect_output2(%arg0: tensor<13x21xf32>, %arg1: tensor<2x13x21xf32>) -> tensor<2x13x21xf32> { - // expected-error@+1 {{'tosa.add' op both operands must have same rank.}} - %0 = "tosa.add"(%arg0, %arg1) : (tensor<13x21xf32>, tensor<2x13x21xf32>) -> tensor<2x13x21xf32> - return %0 : tensor<2x13x21xf32> -} - -// ----- - func.func @test_const_incorrect_output(%arg0 : index) -> tensor<4xi32> { // expected-error@+1{{inferred shape of elements literal ([4]) does not match type ([3])}} %0 = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<3xi32>} : () -> tensor<4xi32> @@ -266,110 +242,6 @@ func.func @test_const_incorrect_output(%arg0 : index) -> tensor<4xi32> { // ----- -func.func @test_greater_equal_incompatible(%arg0: tensor<13x1x3x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> { - // expected-error@+1{{'tosa.greater_equal' op operands don't have broadcast-compatible shapes}} - %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<13x1x3x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> - return %0 : tensor<13x21x3xi1> -} - -// ----- - -func.func @test_greater_equal_unequal_rank(%arg0: tensor<12x13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor { - // expected-error@+1{{'tosa.greater_equal' op both operands must have same rank.}} - %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<12x13x21x3xf32>, tensor<13x21x3xf32>) -> tensor 
- return %0 : tensor -} - -// ----- - -func.func @test_greater_equal(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { - // expected-error@+1{{'tosa.greater_equal' op result #0 must be tensor of 1-bit signless integer values, but got 'tensor<13x21x3xf32>'}} - %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<13x1x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> - return %0 : tensor<13x21x3xf32> -} - -// ----- - -func.func @test_add_unequal_rank(%arg0: tensor<3x4x8400xf32>, %arg1: tensor<8400xf32>) -> tensor<3x4x8400xf32> { - // expected-error@+1{{'tosa.add' op both operands must have same rank.}} - %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4x8400xf32>, tensor<8400xf32>) -> tensor<3x4x8400xf32> - return %0 : tensor<3x4x8400xf32> -} - -// ----- - -func.func @test_select_unequal_rank_inputs(%arg0: tensor<2xi1>, %arg1: tensor<3x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<3x2xf32> { - // expected-error@+1{{'tosa.select' op both operands must have same rank.}} - %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> - return %0 : tensor<3x2xf32> -} - -// ----- - -func.func @test_select_unequal_rank_inputs2(%arg0: tensor<1x2xi1>, %arg1: tensor<1x3x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<3x2xf32> { - // expected-error@+1{{'tosa.select' op both operands must have same rank.}} - %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x2xi1>, tensor<1x3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> - return %0 : tensor<3x2xf32> -} - -// ----- - -func.func @test_select_unequal_rank_inputs3(%arg0: tensor<1x2xi1>, %arg1: tensor<3x2xf32>, %arg2: tensor<1x3x2xf32>) -> tensor<3x2xf32> { - // expected-error@+1{{'tosa.select' op both operands must have same rank.}} - %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x2xi1>, tensor<3x2xf32>, tensor<1x3x2xf32>) -> tensor<3x2xf32> - return %0 : tensor<3x2xf32> -} - -// ----- - -func.func @test_select_not_boardcastable_arg1(%arg0: tensor<2x2xi1>, %arg1: tensor<3x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<3x2xf32> { - // expected-error@+1{{'tosa.select' op operands don't have broadcast-compatible shapes}} - %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<2x2xi1>, tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> - return %0 : tensor<3x2xf32> -} - -// ----- - -func.func @test_select_not_boardcastable_result(%arg0: tensor<1x2xi1>, %arg1: tensor<3x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<4x2xf32> { - // expected-error@+1{{'tosa.select' op result type '4x2' not broadcast compatible with broadcasted operands's shapes '3x2'}} - %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x2xi1>, tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<4x2xf32> - return %0 : tensor<4x2xf32> -} - -// ----- - -func.func @test_select_not_boardcastable_arg3(%arg0: tensor<1x2xi1>, %arg1: tensor<2x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<3x2xf32> { - // expected-error@+1{{'tosa.select' op operands don't have broadcast-compatible shapes}} - %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x2xi1>, tensor<2x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> - return %0 : tensor<3x2xf32> -} - -// ----- - -func.func @test_select_incompatible_1(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xi8> { - // expected-error@+1{{'tosa.select' op inputs and result should be of same type.}} - %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi8> - return %0 : tensor<13x21x3xi8> -} - -// ----- - -func.func 
@test_select_incompatible_2(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xi8>) -> tensor<13x21x3xf32> { - // expected-error@+1{{'tosa.select' op inputs should be of same type.}} - %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xi8>) -> tensor<13x21x3xf32> - return %0 : tensor<13x21x3xf32> -} - -// ----- - -func.func @test_select_incompatible_3(%arg0: tensor<1x1x1xi8>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xi8>) -> tensor<13x21x3xf32> { - // expected-error@+1{{'tosa.select' op operand #0 must be tensor of 1-bit signless integer values, but got 'tensor<1x1x1xi8>'}} - %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x1xi8>, tensor<13x21x3xf32>, tensor<13x21x3xi8>) -> tensor<13x21x3xf32> - return %0 : tensor<13x21x3xf32> -} - -// ----- - func.func @test_transpose_incorrect_result_shape(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x20xf32> { %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> // expected-error@+2{{'tosa.transpose' op failed to infer returned types}} From b6bd43fedd00bb91cfcbf9201d6c36524974b397 Mon Sep 17 00:00:00 2001 From: chaitany Date: Tue, 23 Jan 2024 17:07:43 +0530 Subject: [PATCH 10/12] refactor: removing the testcases covering the removed verifiers --- mlir/test/Dialect/Tosa/ops.mlir | 7 ------- 1 file changed, 7 deletions(-) diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 25c1ea2fed400..2ac0633d2d918 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -324,13 +324,6 @@ func.func @test_select(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg return %0 : tensor<13x21x3xf32> } -// ----- -// CHECK-LABEL: select -func.func @test_select_broardcastable(%arg0: tensor<1x2xi1>, %arg1: tensor<3x2xf32>, %arg2: tensor<3x2xf32>) -> tensor<3x2xf32> { - %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x2xi1>, tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> - return %0 : tensor<3x2xf32> -} - // ----- // CHECK-LABEL: equal func.func @test_equal(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xi1> { From 5561ffef55a122e3999d161663614c3aac06ce1d Mon Sep 17 00:00:00 2001 From: chaitany Date: Tue, 23 Jan 2024 17:09:04 +0530 Subject: [PATCH 11/12] refactor: formatting --- mlir/test/Dialect/Tosa/ops.mlir | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 2ac0633d2d918..72f020336ff05 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -324,6 +324,7 @@ func.func @test_select(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg return %0 : tensor<13x21x3xf32> } + // ----- // CHECK-LABEL: equal func.func @test_equal(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xi1> { From 549a6c5c97b8c2a826da7aae86fc857795822c65 Mon Sep 17 00:00:00 2001 From: chaitany Date: Tue, 23 Jan 2024 12:34:35 +0000 Subject: [PATCH 12/12] refactor: reverting changes to lit tests --- .../TosaToLinalg/tosa-to-linalg.mlir | 6 +- mlir/test/Dialect/Tosa/inlining.mlir | 14 +-- mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 104 +++++++++--------- 3 files changed, 62 insertions(+), 62 deletions(-) diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 7a83fbd6436d7..b320f35aab87b 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ 
b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1415,11 +1415,11 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor) -> () { // CHECK-DAG: affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> // CHECK-DAG: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK-DAG: affine_map<(d0, d1, d2, d3) -> (d0)> +// CHECK-DAG: affine_map<(d0, d1, d2, d3) -> ()> // CHECK-LABEL: @select_fp32 -func.func @select_fp32(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, %arg2: tensor<1x1x1x1xf32>) -> tensor<1x12x5x5xf32> { +func.func @select_fp32(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, %arg2: tensor) -> tensor<1x12x5x5xf32> { // CHECK: linalg.generic - %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<1x1x1x1xf32>) -> tensor<1x12x5x5xf32> + %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor) -> tensor<1x12x5x5xf32> return %0 : tensor<1x12x5x5xf32> } diff --git a/mlir/test/Dialect/Tosa/inlining.mlir b/mlir/test/Dialect/Tosa/inlining.mlir index 4552f4e7180db..d57b5cbcf475c 100644 --- a/mlir/test/Dialect/Tosa/inlining.mlir +++ b/mlir/test/Dialect/Tosa/inlining.mlir @@ -39,16 +39,16 @@ func.func @inlined_while_fn(%arg0: tensor, %arg1: tensor, %arg2: tenso %2 = call @while_cond_40(%arg4, %arg5, %arg6, %arg7) : (tensor, tensor, tensor, tensor<10xi32>) -> tensor "tosa.yield"(%2) : (tensor) -> () }, { - ^bb0(%arg4: tensor<1xi32>, %arg5: tensor<1xi32>, %arg6: tensor<1xi32>, %arg7: tensor<10xi32>): - %2:4 = call @while_body_50(%arg4, %arg5, %arg6, %arg7) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<10xi32>) -> (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<10xi32>) - "tosa.yield"(%2#0, %2#1, %2#2, %2#3) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<10xi32>) -> () + ^bb0(%arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor<10xi32>): + %2:4 = call @while_body_50(%arg4, %arg5, %arg6, %arg7) : (tensor, tensor, tensor, tensor<10xi32>) -> (tensor, tensor, tensor, tensor<10xi32>) + "tosa.yield"(%2#0, %2#1, %2#2, %2#3) : (tensor, tensor, tensor, tensor<10xi32>) -> () }) : (tensor, tensor, tensor, tensor<10xi32>) -> (tensor, tensor, tensor, tensor<10xi32>) return %1#3 : tensor<10xi32> } -func.func private @while_body_50(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<10xi32>) -> (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<10xi32>) { - %1 = "tosa.add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> - %2 = "tosa.add"(%arg3, %1) : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32> - return %1, %arg1, %arg2, %2: tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<10xi32> +func.func private @while_body_50(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<10xi32>) -> (tensor, tensor, tensor, tensor<10xi32>) { + %1 = "tosa.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %2 = "tosa.add"(%arg3, %1) : (tensor<10xi32>, tensor) -> tensor<10xi32> + return %1, %arg1, %arg2, %2: tensor, tensor, tensor, tensor<10xi32> } func.func private @while_cond_40(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<10xi32>) -> tensor { %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor, tensor) -> tensor diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index 2c650062b70a4..7f32b7db76258 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -8,7 +8,7 @@ func.func @test_return(%arg0 : tensor<4xf32>) -> 
tensor<*xf32> { return %0 : tensor<*xf32> } -// ----- +// ----- // CHECK-LABEL: @test_multiple func.func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor) -> tensor<*xf32> { @@ -104,33 +104,33 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>) -> () { // ----- // CHECK-LABEL: @test_binary_scalar_f32 -func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>) -> () { - // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> +func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor) -> () { + // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> + %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> - // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> + %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> - // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> + %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> - // CHECK: "tosa.mul"(%arg0, %arg1) <{shift = 0 : i32}> : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + // CHECK: "tosa.mul"(%arg0, %arg1) <{shift = 0 : i32}> : (tensor<4xf32>, tensor) -> tensor<4xf32> + %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xf32>, tensor) -> tensor<*xf32> - // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> + %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> - // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %5 = "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> + %5 = "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> - // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> - %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xi1> + %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xi1> - // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> - %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xi1> + %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xi1> - // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> - %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xi1> + %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, 
tensor) -> tensor<*xi1> return } @@ -172,48 +172,48 @@ func.func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32 // ----- // CHECK-LABEL: @test_binary_i32 -func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<1xi32>) -> () { - // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> - %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> +func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor) -> () { + // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> + %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> - // CHECK: "tosa.bitwise_and"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> - %1 = "tosa.bitwise_and"(%arg0, %arg1): (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> + // CHECK: "tosa.bitwise_and"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> + %1 = "tosa.bitwise_and"(%arg0, %arg1): (tensor<4xi32>, tensor) -> tensor<*xi32> - // CHECK: "tosa.bitwise_or"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> - %2 = "tosa.bitwise_or"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> + // CHECK: "tosa.bitwise_or"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> + %2 = "tosa.bitwise_or"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> - // CHECK: "tosa.bitwise_xor"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> - %3 = "tosa.bitwise_xor"(%arg0, %arg1): (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> + // CHECK: "tosa.bitwise_xor"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> + %3 = "tosa.bitwise_xor"(%arg0, %arg1): (tensor<4xi32>, tensor) -> tensor<*xi32> - // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi1> - %4 = "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1> + // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi1> + %4 = "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi1> - // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi1> - %5 = "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1> + // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi1> + %5 = "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi1> - // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi1> - %6 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1> + // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi1> + %6 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi1> - // CHECK: "tosa.logical_left_shift"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> - %7 = "tosa.logical_left_shift"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> + // CHECK: "tosa.logical_left_shift"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> + %7 = "tosa.logical_left_shift"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> - // CHECK: "tosa.logical_right_shift"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> - %8 = "tosa.logical_right_shift"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> + // CHECK: "tosa.logical_right_shift"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> + 
%8 = "tosa.logical_right_shift"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> - // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> - %9 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> + // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> + %9 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> - // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> - %10 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> + // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> + %10 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> - // CHECK: "tosa.mul"(%arg0, %arg1) <{shift = 0 : i32}> : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> - %11 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> + // CHECK: "tosa.mul"(%arg0, %arg1) <{shift = 0 : i32}> : (tensor<4xi32>, tensor) -> tensor<4xi32> + %11 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> - // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> - %12 = "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> + // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> + %12 = "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> - // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> - %13 = "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> + // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> + %13 = "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> return } @@ -237,9 +237,9 @@ func.func @test_binary_i1(%arg0 : tensor<4xi1>, %arg1 : tensor) -> () { // ----- // CHECK-LABEL: @test_select_i32 -func.func @test_select_i32(%arg0 : tensor<4xi1>, %arg1 : tensor<1xi32>, %arg2 : tensor<4xi32>) -> () { - // CHECK: "tosa.select"(%arg0, %arg1, %arg2) : (tensor<4xi1>, tensor<1xi32>, tensor<4xi32>) -> tensor<4xi32> - %0 = "tosa.select"(%arg0, %arg1, %arg2): (tensor<4xi1>, tensor<1xi32>, tensor<4xi32>) -> tensor<*xi32> +func.func @test_select_i32(%arg0 : tensor<4xi1>, %arg1 : tensor, %arg2 : tensor<4xi32>) -> () { + // CHECK: "tosa.select"(%arg0, %arg1, %arg2) : (tensor<4xi1>, tensor, tensor<4xi32>) -> tensor<4xi32> + %0 = "tosa.select"(%arg0, %arg1, %arg2): (tensor<4xi1>, tensor, tensor<4xi32>) -> tensor<*xi32> return }