From b0f576c51b4782801af1da7655de1616af22e5b7 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Thu, 29 May 2025 13:59:05 +0000 Subject: [PATCH] [mlir][tosa] Fix MulOp verifier handling for unranked operands The previous verifier checks did not correctly handle unranked operands. For example, it could incorrectly assume the number of `rankedOperandTypes` would be >= 2, which isn't the case when both a and b are unranked. This change simplifies these checks such that they only operate over the intended a and b operands as opposed to the shift operand as well. Change-Id: I0d0b7f7e8058f9a25dcb6c051aa0375cf780b80c --- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 85 +++++++++++----------------- mlir/test/Dialect/Tosa/invalid.mlir | 29 +++++++++- mlir/test/Dialect/Tosa/ops.mlir | 16 ++++++ 3 files changed, 77 insertions(+), 53 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 3ee5a85a21dca..298802fc7fa6c 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -1779,7 +1779,8 @@ LogicalResult tosa::MulOp::inferReturnTypeComponents( } LogicalResult tosa::MulOp::verify() { - auto resElemType = getElementTypeOrSelf(getOutput()); + const Value output = getOutput(); + auto resElemType = getElementTypeOrSelf(output); // Verify if the element type among operands and result match tosa // specification. @@ -1819,59 +1820,39 @@ LogicalResult tosa::MulOp::verify() { // Verify the op has same ranks for all main operands (excludes extra operands // such as shift of mul op, so this is the only difference with the built-in // `SameOperandsAndResultRank` trait) and results types, if known. - - // delegate function that returns true if type is a shaped type with known - // rank - auto hasRank = [](const Type type) { - if (auto shaped_type = dyn_cast(type)) - return shaped_type.hasRank(); - - return false; - }; - - auto rankedOperandTypes = - llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank)); - - auto rankedResultTypes = - llvm::make_filter_range(getOperation()->getResultTypes(), hasRank); - - // If all operands and results are unranked, then no further verification. - if (rankedOperandTypes.empty() && rankedResultTypes.empty()) + TypeRange operandTypes = getOperandTypes(); + ShapedType aType = cast(operandTypes[0]); + ShapedType bType = cast(operandTypes[1]); + + const bool aHasRank = aType.hasRank(); + const bool bHasRank = bType.hasRank(); + if (aHasRank && bHasRank) { + const int64_t aRank = aType.getRank(); + const int64_t bRank = bType.getRank(); + if (aRank != bRank) + return emitOpError("a and b operands don't have matching ranks, got ") + << aRank << " and " << bRank; + + // check for broadcast compatible shapes + SmallVector resultShape; + if (!mlir::OpTrait::util::getBroadcastedShape( + aType.getShape(), bType.getShape(), resultShape)) + return emitOpError("a and b operands don't have broadcast-compatible " + "shapes, got ") + << aType << " and " << bType; + } + + ShapedType resultType = cast(output.getType()); + if (!resultType.hasRank()) return success(); - // delegate function that returns rank of shaped type with known rank - auto getRank = [](const Type type) { - return cast(type).getRank(); - }; - - auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin()) - : getRank(*rankedResultTypes.begin()); - - for (size_t i = 0; i < 2; ++i) { - if (rank != getRank(rankedOperandTypes[i])) { - return emitOpError("operands don't have matching ranks"); - } - } - - for (const auto type : rankedResultTypes) { - if (rank != getRank(type)) { - return emitOpError("result type has different rank than operands"); - } - } - - // check for broadcast compatible shapes in first two operands (ignoring - // shift) - - // delegate function that returns shape of shaped type - auto getShape = [](const Type type) { - return mlir::cast(type).getShape(); - }; - SmallVector resultShape; - if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]), - getShape(rankedOperandTypes[1]), - resultShape)) { - return emitOpError("operands don't have broadcast-compatible shapes"); - } + const int64_t resultRank = resultType.getRank(); + if (aHasRank && resultRank != aType.getRank()) + return emitOpError("result type has different rank than a, got ") + << resultRank << " vs " << aType.getRank(); + if (bHasRank && resultRank != bType.getRank()) + return emitOpError("result type has different rank than b, got ") + << resultRank << " vs " << bType.getRank(); return success(); } diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 7b589fa839b44..3298e518de2f5 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -1107,11 +1107,38 @@ func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tenso // CHECK-LABEL: test_mul_non_broadcast func.func @test_mul_non_broadcast(%arg0: tensor<13x21x2xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> { %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> - // expected-error@+1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}} + // expected-error@+1 {{'tosa.mul' op a and b operands don't have broadcast-compatible shapes, got 'tensor<13x21x2xf32>' and 'tensor<3x1x3xf32>'}} %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x2xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } +// ----- +// CHECK-LABEL: test_mul_different_operand_ranks +func.func @test_mul_different_operand_ranks(%arg0: tensor<13x21xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> { + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + // expected-error@+1 {{'tosa.mul' op a and b operands don't have matching ranks, got 2 and 3}} + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: test_mul_different_a_and_result_ranks +func.func @test_mul_different_a_and_result_ranks(%arg0: tensor<13x21xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> { + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + // expected-error@+1 {{'tosa.mul' op result type has different rank than a, got 3 vs 2}} + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: test_mul_different_b_and_result_ranks +func.func @test_mul_different_b_and_result_ranks(%arg0: tensor<*xf32>, %arg1: tensor<13x12xf32>) -> tensor<13x21x3xf32> { + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + // expected-error@+1 {{'tosa.mul' op result type has different rank than b, got 3 vs 2}} + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<*xf32>, tensor<13x12xf32>, tensor<1xi8>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + // ----- // CHECK-LABEL: test_resize_invalid_scale_values func.func @test_resize_invalid_scale_values(%arg0: tensor<1x8x8x8xf32>) -> tensor { diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 5ec506a45b3ad..882b59d029a4a 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -424,6 +424,22 @@ func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tenso return %0 : tensor<13x21x3xi16> } +// ----- +// CHECK-LABEL: test_mul_unranked_b +func.func @test_mul_unranked_b(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> { + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: test_mul_unranked_a_and_b +func.func @test_mul_unranked_a_and_b(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> { + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<*xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + // ----- // CHECK-LABEL: pow func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {