diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h index 195a58432737b..f4823858e3893 100644 --- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h +++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h @@ -40,7 +40,7 @@ void addTosaToLinalgPasses( // Note: Default to 'none' level unless otherwise specified. std::optional validationOptions = tosa::TosaValidationOptions{ - {"none"}, {"none"}, false, tosa::TosaLevelEnum::None}); + {"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None}); /// Populates TOSA to linalg pipelines /// Currently, this includes only the "tosa-to-linalg-pipeline". diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h index 1df1761d38455..d73b288d2c8bf 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h @@ -115,6 +115,7 @@ class TosaProfileCompliance { // environment. LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv); LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv); + LogicalResult checkInvalid(Operation *op); template LogicalResult checkProfileOrExtension( @@ -163,6 +164,10 @@ class TosaProfileCompliance { stringifyProfile(const SmallVector> &profileSet); private: + template + FailureOr> getOperatorDefinition(Operation *op, + CheckCondition &condition); + OperationProfileComplianceMap profileComplianceMap; OperationExtensionComplianceMap extensionComplianceMap; }; diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index f6ead2b6ba3dd..2d5b0b39df078 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -94,6 +94,11 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> { Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool", /*default=*/"false", "Verify if the properties of certain operations align the spec requirement">, + Option<"allowInvalidOpDatatypeCombinations", "allow-invalid-op-datatype-combinations", "bool", + /*default=*/"false", + "Disable checks for operations that are determined to be invalid due to their " + "operand/result datatypes not aligning with the 'Supported Data Types' " + "sections of the specifciation">, Option<"level", "level", "mlir::tosa::TosaLevelEnum", /*default=*/"mlir::tosa::TosaLevelEnum::EightK", "Validate if operator parameters are within specfication for the given level", diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp index bfadebba12708..4cf232a7bc767 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -119,6 +119,7 @@ void mlir::tosa::registerTosaToLinalgPipelines() { validationOptions.profile = {"none"}; validationOptions.extension = {"none"}; validationOptions.strictOpSpecAlignment = false; + validationOptions.allowInvalidOpDatatypeCombinations = false; validationOptions.level = tosa::TosaLevelEnum::EightK; tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions, tosaToLinalgNamedOptions, diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index 4aeb095ffff07..eb7981b313d1d 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -140,6 +140,7 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) { template <> LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) { addValue(op.getValues()); + addValue(op.getIndices()); addValue(op.getOutput()); return success(); } @@ -147,6 +148,7 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) { template <> LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) { addValue(op.getValuesIn()); + addValue(op.getIndices()); addValue(op.getInput()); addValue(op.getValuesOut()); return success(); @@ -347,6 +349,19 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { // Tosa Profile And Extension Compliance Checker //===----------------------------------------------------------------------===// +template +FailureOr> +TosaProfileCompliance::getOperatorDefinition(Operation *op, + CheckCondition &condition) { + const std::string opName = op->getName().getStringRef().str(); + const auto complianceMap = getProfileComplianceMap(); + const auto it = complianceMap.find(opName); + if (it == complianceMap.end()) + return {}; + + return findMatchedProfile(op, it->second, condition); +} + template LogicalResult TosaProfileCompliance::checkProfileOrExtension( Operation *op, const tosa::TargetEnv &targetEnv, @@ -356,11 +371,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( if (specRequiredModeSet.size() == 0) return success(); - auto opName = op->getName().getStringRef().str(); - auto compMap = getProfileComplianceMap(); - auto it = compMap.find(opName); - - if (it == compMap.end()) { + CheckCondition condition = CheckCondition::invalid; + const auto maybeOpRequiredMode = getOperatorDefinition(op, condition); + if (failed(maybeOpRequiredMode)) { // Operators such as control-flow and shape ops do not have an operand type // restriction. When the profile compliance information of operation is not // found, confirm if the target have enabled the profile required from the @@ -381,12 +394,9 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( return failure(); } - CheckCondition condition = CheckCondition::invalid; - // Find the profiles or extensions requirement according to the signature of - // type of the operand list. - SmallVector opRequiredMode = - findMatchedProfile(op, it->second, condition); - + // Find the required profiles or extensions according to the operand type + // combination. + const auto opRequiredMode = maybeOpRequiredMode.value(); if (opRequiredMode.size() == 0) { // No matched restriction found. return success(); @@ -466,6 +476,17 @@ TosaProfileCompliance::checkExtension(Operation *op, return success(); } +LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { + CheckCondition condition = CheckCondition::invalid; + const auto maybeProfDef = getOperatorDefinition(op, condition); + const auto maybeExtDef = getOperatorDefinition(op, condition); + if (!failed(maybeProfDef) && !failed(maybeExtDef) && + !maybeProfDef.value().size() && !maybeExtDef.value().size()) + return failure(); + + return success(); +} + // Find the profiles or extensions requirement according to the signature of // type of the operand list. template @@ -483,7 +504,6 @@ SmallVector TosaProfileCompliance::findMatchedProfile( for (size_t i = 0; i < compInfo.size(); i++) { SmallVector> sets = compInfo[i].operandTypeInfoSet; - for (SmallVector expected : sets) { assert(present.size() == expected.size() && "the entries for profile-based compliance do not match between " diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 79c13793d7713..3ec7354562d23 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -165,6 +165,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { this->profile = options.profile; this->extension = options.extension; this->strictOpSpecAlignment = options.strictOpSpecAlignment; + this->allowInvalidOpDatatypeCombinations = + options.allowInvalidOpDatatypeCombinations; this->level = options.level; } void runOnOperation() final; @@ -1042,6 +1044,12 @@ void TosaValidation::runOnOperation() { } } + if (!allowInvalidOpDatatypeCombinations && + failed(profileComp.checkInvalid(op))) { + op->emitOpError("illegal: operand/result data types not supported"); + return signalPassFailure(); + } + // Some uses of TOSA rely on the constant operands of particular // operations. if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op))) diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir index ecd5c792e08b6..731e134ed1a07 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir @@ -4,20 +4,20 @@ // ----- // check that -tosa-validate of stateful ops kick in -func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () { - tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> +func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () { + tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> // expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}} - tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32> + tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi8> return } // ----- // check that -tosa-validate level checking kick in -func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> { +func.func @tensor_with_unknown_rank(%arg0: tensor<*xi32>) -> tensor<*xi32> { // expected-error@+1 {{'tosa.abs' op failed level check: unranked tensor}} - %0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8> - return %0 : tensor<*xi8> + %0 = "tosa.abs"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> + return %0 : tensor<*xi32> } // ----- diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir index fd9b3d5f23483..0ec46022157d7 100644 --- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir +++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir @@ -2,13 +2,13 @@ // Check operations when the dynamic extension is enabled. //-------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment allow-invalid-op-datatype-combinations" // ----- -func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> { - %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8> - return %0 : tensor<13x21x3xi8> +func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> { + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16> + return %0 : tensor<13x21x3xi16> } // ----- diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 3203c64b439da..ac8a247da24a7 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -616,17 +616,17 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten // ----- -func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi32>) -> () { - tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> +func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () { + tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> // expected-error@+1 {{'tosa.variable' op name has already been declared}} - tosa.variable @stored_var = dense<3> : tensor<1x4x8xi32> + tosa.variable @stored_var = dense<3> : tensor<1x4x8xi8> return } // ----- -func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () { - tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> +func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () { + tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> // expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}} %0 = tosa.variable.read @stored_var : tensor<2x4x8xi16> return @@ -634,8 +634,8 @@ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () { // ----- -func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () { - tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> +func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () { + tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> // expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}} %0 = tosa.variable.read @stored_var : tensor<1x4x8xi32> return @@ -644,7 +644,7 @@ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () { // ----- func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () { - tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> + tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> // expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}} tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16> return @@ -652,10 +652,10 @@ func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () { // ----- -func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () { - tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> +func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () { + tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> // expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}} - tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32> + tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi8> return } @@ -1921,3 +1921,21 @@ func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor { %1 = tosa.transpose %arg0 {perms = array} : (tensor<*xf32>) -> tensor return %1 : tensor } + +// ----- + +// CHECK-LABEL: test_add_i1 +func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> { + // expected-error@+1 {{'tosa.add' op illegal: operand/result data types not supported}} + %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- + +// CHECK-LABEL: test_mul_out_i16 +func.func @test_mul_out_i16(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi16> { + // expected-error@+1 {{'tosa.mul' op illegal: operand/result data types not supported}} + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi16> + return %0 : tensor<13x21x3xi16> +} diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index bde5b5ec7cffe..d1594232e4e1d 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -191,10 +191,10 @@ func.func @test_matmul_non_const_b_zp(%arg0: tensor<1x14x19xf32>, %arg1: tensor< // ----- -func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi8> { +func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8>, %shift: tensor<1xi8>) -> tensor<13x21x3xi32> { // expected-error@+1 {{'tosa.mul' op expected compile time resolvable constant, but got variable value for operand #2}} - %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi8> - return %0 : tensor<13x21x3xi8> + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi8>, tensor<13x1x3xi8>, tensor<1xi8>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> } // ----- diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index bdf18ec823128..0f469761d89e3 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -169,10 +169,10 @@ func.func @test_sub_rank_invalid(%arg0: tensor<1x1x1x1x1x21x3xf32>, %arg1: tenso // ----- -func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi32>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16> { +func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi16>, %arg1: tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32> { // expected-error@+1 {{'tosa.table' op failed level check: operand rank(shape) <= MAX_RANK}} - %0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi32>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi16> - return %0 : tensor<1x1x1x1x1x1x64xi16> + %0 = tosa.table %arg0, %arg1 : (tensor<1x1x1x1x1x1x64xi16>, tensor<513xi16>) -> tensor<1x1x1x1x1x1x64xi32> + return %0 : tensor<1x1x1x1x1x1x64xi32> } // -----