diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h index c84e4f17c38d8..f5ae71cad1524 100644 --- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h +++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h @@ -39,8 +39,8 @@ void addTosaToLinalgPasses( TosaToLinalgNamedOptions(), // Note: Default to 'none' level unless otherwise specified. std::optional validationOptions = - tosa::TosaValidationOptions{tosa::TosaProfileEnum::Undefined, false, - tosa::TosaLevelEnum::None}); + tosa::TosaValidationOptions{ + {"none"}, false, tosa::TosaLevelEnum::None}); /// Populates TOSA to linalg pipelines /// Currently, this includes only the "tosa-to-linalg-pipeline". diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index c0352fa88fe08..dac67633769c7 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -76,7 +76,7 @@ def TosaProfileType : I32EnumAttr<"TosaProfileEnum", "Tosa profile", I32EnumAttrCase<"BaseInference", 0, "bi">, I32EnumAttrCase<"MainInference", 1, "mi">, I32EnumAttrCase<"MainTraining", 2, "mt">, - I32EnumAttrCase<"Undefined", 3> + I32EnumAttrCase<"Undefined", 3, "none"> ]>{ let cppNamespace = "mlir::tosa"; } @@ -97,19 +97,8 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> { }]; let options = [ - Option<"profile", "profile", "mlir::tosa::TosaProfileEnum", - /*default=*/"mlir::tosa::TosaProfileEnum::Undefined", - "Validate if operations match for the given profile", - [{::llvm::cl::values( - clEnumValN(mlir::tosa::TosaProfileEnum::BaseInference, "bi", - "Use Base Inference profile."), - clEnumValN(mlir::tosa::TosaProfileEnum::MainInference, "mi", - "Use Main Inference profile."), - clEnumValN(mlir::tosa::TosaProfileEnum::MainTraining, "mt", - "Use Main Training profile."), - clEnumValN(mlir::tosa::TosaProfileEnum::Undefined, "undefined", - "Do not define a profile.") - )}]>, + ListOption<"profile", "profile", "std::string", + "Validate if operations match for the given profile set">, Option<"StrictOperationSpecAlignment", "strict-op-spec-alignment", "bool", /*default=*/"false", "Verify if the properties of certain operations align the spec requirement">, diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp index 44036d7c31a91..06a7262c46742 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -115,7 +115,7 @@ void mlir::tosa::registerTosaToLinalgPipelines() { TosaToLinalgOptions tosaToLinalgOptions; TosaToLinalgNamedOptions tosaToLinalgNamedOptions; TosaValidationOptions validationOptions; - validationOptions.profile = tosa::TosaProfileEnum::BaseInference; + validationOptions.profile = {"none"}; validationOptions.StrictOperationSpecAlignment = true; validationOptions.level = tosa::TosaLevelEnum::EightK; tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions, diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index b78c372af77e6..e390a613b5807 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -405,14 +405,28 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { if (level == TosaLevelEnum::EightK) { tosaLevel = TOSA_LEVEL_EIGHTK; } + + if (!profile.empty()) { + for (std::string &prof : profile) { + auto profSymbol = symbolizeTosaProfileEnum(prof); + if (profSymbol) { + enabled_profiles.push_back(profSymbol.value()); + } + } + } } bool CheckVariable(Operation *op); bool CheckVariableReadOrWrite(Operation *op); bool isValidElementType(Type type); + bool isEnabledProfile(TosaProfileEnum prof) { + return std::find(enabled_profiles.begin(), enabled_profiles.end(), prof) != + std::end(enabled_profiles); + } SmallVector> constCheckers; + SmallVector enabled_profiles; TosaLevel tosaLevel; DenseMap variablesMap; }; @@ -507,7 +521,7 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) { bool TosaValidation::isValidElementType(Type type) { if (isa(type)) { - if (profile == TosaProfileEnum::BaseInference) + if (!isEnabledProfile(TosaProfileEnum::MainInference)) return false; return type.isF32() || type.isF16() || type.isBF16(); } diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index e5c5b9b366390..b9298b6664353 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -1,4 +1,10 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate=strict-op-spec-alignment +//-------------------------------------------------------------------------------------------------- +// Test expected errors in terms of the shape and type of tensor, and the argument type of +// operation. Excludes the profile compilance checking since it is performed earlier in the +// validation flow. +//-------------------------------------------------------------------------------------------------- + +// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=bi,mi,mt strict-op-spec-alignment" func.func @test_const() -> tensor<1xf32> { diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 9b652f2d0bd14..e851019362958 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1,4 +1,8 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate +//-------------------------------------------------------------------------------------------------- +// Enable all supported profiles to focus the verification of expected level errors. +//-------------------------------------------------------------------------------------------------- + +// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=bi,mi,mt" func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {