From 303bebc655974eb78a6bab52022779a1b9d4bb65 Mon Sep 17 00:00:00 2001 From: TatWai Chong Date: Fri, 15 Mar 2024 15:52:40 -0700 Subject: [PATCH] [tosa] Change the type of profile option to a list In tosa valiation pass, change the type of profile option to ListOption. Now TOSA profiles is turned from hierarchical to composable. Each profile is an independent set, i.e. an target can implement multiple profiles. Set the option to none by default, and limit to profiles if requested. Change-Id: I1fb8d0c1b27eccd768349b6eb4234093313efb57 --- .../mlir/Conversion/TosaToLinalg/TosaToLinalg.h | 4 ++-- .../mlir/Dialect/Tosa/Transforms/Passes.td | 17 +++-------------- .../TosaToLinalg/TosaToLinalgPass.cpp | 2 +- .../Dialect/Tosa/Transforms/TosaValidation.cpp | 16 +++++++++++++++- mlir/test/Dialect/Tosa/invalid.mlir | 8 +++++++- mlir/test/Dialect/Tosa/level_check.mlir | 6 +++++- 6 files changed, 33 insertions(+), 20 deletions(-) diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h index c84e4f17c38d8..f5ae71cad1524 100644 --- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h +++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h @@ -39,8 +39,8 @@ void addTosaToLinalgPasses( TosaToLinalgNamedOptions(), // Note: Default to 'none' level unless otherwise specified. std::optional validationOptions = - tosa::TosaValidationOptions{tosa::TosaProfileEnum::Undefined, false, - tosa::TosaLevelEnum::None}); + tosa::TosaValidationOptions{ + {"none"}, false, tosa::TosaLevelEnum::None}); /// Populates TOSA to linalg pipelines /// Currently, this includes only the "tosa-to-linalg-pipeline". diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index c0352fa88fe08..dac67633769c7 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -76,7 +76,7 @@ def TosaProfileType : I32EnumAttr<"TosaProfileEnum", "Tosa profile", I32EnumAttrCase<"BaseInference", 0, "bi">, I32EnumAttrCase<"MainInference", 1, "mi">, I32EnumAttrCase<"MainTraining", 2, "mt">, - I32EnumAttrCase<"Undefined", 3> + I32EnumAttrCase<"Undefined", 3, "none"> ]>{ let cppNamespace = "mlir::tosa"; } @@ -97,19 +97,8 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> { }]; let options = [ - Option<"profile", "profile", "mlir::tosa::TosaProfileEnum", - /*default=*/"mlir::tosa::TosaProfileEnum::Undefined", - "Validate if operations match for the given profile", - [{::llvm::cl::values( - clEnumValN(mlir::tosa::TosaProfileEnum::BaseInference, "bi", - "Use Base Inference profile."), - clEnumValN(mlir::tosa::TosaProfileEnum::MainInference, "mi", - "Use Main Inference profile."), - clEnumValN(mlir::tosa::TosaProfileEnum::MainTraining, "mt", - "Use Main Training profile."), - clEnumValN(mlir::tosa::TosaProfileEnum::Undefined, "undefined", - "Do not define a profile.") - )}]>, + ListOption<"profile", "profile", "std::string", + "Validate if operations match for the given profile set">, Option<"StrictOperationSpecAlignment", "strict-op-spec-alignment", "bool", /*default=*/"false", "Verify if the properties of certain operations align the spec requirement">, diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp index 44036d7c31a91..06a7262c46742 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -115,7 +115,7 @@ void mlir::tosa::registerTosaToLinalgPipelines() { TosaToLinalgOptions tosaToLinalgOptions; TosaToLinalgNamedOptions tosaToLinalgNamedOptions; TosaValidationOptions validationOptions; - validationOptions.profile = tosa::TosaProfileEnum::BaseInference; + validationOptions.profile = {"none"}; validationOptions.StrictOperationSpecAlignment = true; validationOptions.level = tosa::TosaLevelEnum::EightK; tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions, diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index b78c372af77e6..e390a613b5807 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -405,14 +405,28 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { if (level == TosaLevelEnum::EightK) { tosaLevel = TOSA_LEVEL_EIGHTK; } + + if (!profile.empty()) { + for (std::string &prof : profile) { + auto profSymbol = symbolizeTosaProfileEnum(prof); + if (profSymbol) { + enabled_profiles.push_back(profSymbol.value()); + } + } + } } bool CheckVariable(Operation *op); bool CheckVariableReadOrWrite(Operation *op); bool isValidElementType(Type type); + bool isEnabledProfile(TosaProfileEnum prof) { + return std::find(enabled_profiles.begin(), enabled_profiles.end(), prof) != + std::end(enabled_profiles); + } SmallVector> constCheckers; + SmallVector enabled_profiles; TosaLevel tosaLevel; DenseMap variablesMap; }; @@ -507,7 +521,7 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) { bool TosaValidation::isValidElementType(Type type) { if (isa(type)) { - if (profile == TosaProfileEnum::BaseInference) + if (!isEnabledProfile(TosaProfileEnum::MainInference)) return false; return type.isF32() || type.isF16() || type.isBF16(); } diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index e5c5b9b366390..b9298b6664353 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -1,4 +1,10 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate=strict-op-spec-alignment +//-------------------------------------------------------------------------------------------------- +// Test expected errors in terms of the shape and type of tensor, and the argument type of +// operation. Excludes the profile compilance checking since it is performed earlier in the +// validation flow. +//-------------------------------------------------------------------------------------------------- + +// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=bi,mi,mt strict-op-spec-alignment" func.func @test_const() -> tensor<1xf32> { diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 9b652f2d0bd14..e851019362958 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1,4 +1,8 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate +//-------------------------------------------------------------------------------------------------- +// Enable all supported profiles to focus the verification of expected level errors. +//-------------------------------------------------------------------------------------------------- + +// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=bi,mi,mt" func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {