From 9c7ebd6f309b3730c7589bbbdc1e24f9a3d1cf39 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Thu, 14 Aug 2025 16:18:10 +0000 Subject: [PATCH] [mlir][tosa] Add the concept of a TOSA target envrionment This commit introduces a new module-level attribute `tosa.target_env`. IT encapsulates target information for use during compilation such as: level, profiles and extensions. For example: ```mlir module attributes {tosa.target_env = #tosa.target_env} { } ``` Previously the validation pass accepted target infomation as a series of command line pass options. This commit changes the behaviour to query the attached target environment from the module attribute. This refactoring allows other passes to query the same target information. A new target environment can be atached using the `--tosa-attach-target` pass, which takes the same command line options as the previous validation pass arguments. For example: ```bash mlir-opt --tosa-attach-target="profiles=pro_int extensions=int4,int16 level=none" test.mlir ``` Change-Id: I74a254855f6320dc70b29ae3509997764e3e5d95 --- .../Conversion/TosaToLinalg/TosaToLinalg.h | 3 +- mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h | 56 +++++++++-- .../mlir/Dialect/Tosa/IR/TosaOpBase.td | 50 ++++++++-- .../Dialect/Tosa/Transforms/CMakeLists.txt | 2 - .../mlir/Dialect/Tosa/Transforms/Passes.h | 1 - .../mlir/Dialect/Tosa/Transforms/Passes.td | 64 ++++++++----- .../TosaToLinalg/TosaToLinalgPass.cpp | 3 - mlir/lib/Dialect/Tosa/CMakeLists.txt | 1 + mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp | 42 +++++++++ .../Dialect/Tosa/Transforms/CMakeLists.txt | 1 + .../Tosa/Transforms/TosaAttachTarget.cpp | 87 +++++++++++++++++ .../Tosa/Transforms/TosaValidation.cpp | 94 ++++--------------- mlir/test/Dialect/Tosa/dynamic_extension.mlir | 2 +- mlir/test/Dialect/Tosa/error_if_check.mlir | 2 +- mlir/test/Dialect/Tosa/invalid.mlir | 2 +- mlir/test/Dialect/Tosa/invalid_extension.mlir | 2 +- mlir/test/Dialect/Tosa/level_check.mlir | 2 +- .../Dialect/Tosa/profile_all_unsupported.mlir | 2 +- .../Tosa/profile_pro_fp_unsupported.mlir | 2 +- .../Tosa/profile_pro_int_unsupported.mlir | 2 +- .../test/Dialect/Tosa/tosa-attach-target.mlir | 14 +++ .../Tosa/tosa-validation-valid-strict.mlir | 2 +- .../Dialect/Tosa/tosa-validation-valid.mlir | 2 +- 23 files changed, 307 insertions(+), 131 deletions(-) create mode 100644 mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp create mode 100644 mlir/test/Dialect/Tosa/tosa-attach-target.mlir diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h index f4823858e3893..ab9b9f24ef3dd 100644 --- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h +++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h @@ -39,8 +39,7 @@ void addTosaToLinalgPasses( TosaToLinalgNamedOptions(), // Note: Default to 'none' level unless otherwise specified. std::optional validationOptions = - tosa::TosaValidationOptions{ - {"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None}); + tosa::TosaValidationOptions{false, false}); /// Populates TOSA to linalg pipelines /// Currently, this includes only the "tosa-to-linalg-pipeline". diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h index 9ee5079559d2b..10491f65d37af 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h @@ -20,24 +20,67 @@ namespace mlir { namespace tosa { +struct TosaLevel { + int32_t MAX_RANK = 0; + int32_t MAX_KERNEL = 0; + int32_t MAX_STRIDE = 0; + int32_t MAX_SCALE = 0; + int32_t MAX_LOG2_SIZE = 0; + int32_t MAX_NESTING = 0; + int32_t MAX_TENSOR_LIST_SIZE = 0; + + bool operator==(const TosaLevel &rhs) { + return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL && + MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE && + MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE && + MAX_NESTING == rhs.MAX_NESTING && + MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE; + } +}; + +static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64}; +static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048, + 63, 256, 256}; + +TargetEnvAttr lookupTargetEnv(Operation *op); +TargetEnvAttr getDefaultTargetEnv(MLIRContext *context); + +/// Queries the target environment recursively from enclosing symbol table ops +/// containing the given `op` or returns the default target environment as +/// returned by getDefaultTargetEnv() if not provided. +TargetEnvAttr lookupTargetEnvOrDefault(Operation *op); + /// This class represents the capability enabled in the target implementation -/// such as profile, extension, and level. +/// such as profile, extension, and level. It's a wrapper class around +/// tosa::TargetEnvAttr. class TargetEnv { public: TargetEnv() {} - explicit TargetEnv(const SmallVectorImpl &profiles, - const SmallVectorImpl &extensions) { + explicit TargetEnv(Level level, const ArrayRef &profiles, + const ArrayRef &extensions) + : level(level) { enabledProfiles.insert_range(profiles); - enabledExtensions.insert_range(extensions); } + explicit TargetEnv(TargetEnvAttr targetAttr) + : TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(), + targetAttr.getExtensions()) {} + void addProfile(Profile p) { enabledProfiles.insert(p); } void addExtension(Extension e) { enabledExtensions.insert(e); } // TODO implement the following utilities. // Version getSpecVersion() const; - // TosaLevel getLevel() const; + + TosaLevel getLevel() const { + if (level == Level::eightK) + return TOSA_LEVEL_EIGHTK; + else if (level == Level::none) + return TOSA_LEVEL_NONE; + else + llvm_unreachable("Unknown TOSA level"); + }; // Returns true if the given profile is allowed. bool allows(Profile prof) const { return enabledProfiles.count(prof) != 0; } @@ -62,8 +105,9 @@ class TargetEnv { } private: + Level level; llvm::SmallSet enabledProfiles; - llvm::SmallSet enabledExtensions; + llvm::SmallSet enabledExtensions; }; } // namespace tosa diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 115a11b346780..8e5f0e3b19391 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -245,6 +245,19 @@ def Tosa_NONE : I32EnumAttrCase<"none", 0>; def Tosa_PRO_INT : I32EnumAttrCase<"pro_int", 1>; def Tosa_PRO_FP : I32EnumAttrCase<"pro_fp", 2>; +def Tosa_ProfileAttr + : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof", + [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]> { + let extraClassDeclaration = [{ + static llvm::SmallVector getAllValues() { + return {Profile::pro_int, Profile::pro_fp}; + } + }]; +} + +def Tosa_ProfileArrayAttr + : TypedArrayAttrBase; + def Tosa_EXT_NONE : I32EnumAttrCase<"none", 0>; def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>; def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>; @@ -264,17 +277,27 @@ def Tosa_ExtensionAttr Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_DYNAMIC - ]>; + ]> { + let extraClassDeclaration = [{ + static llvm::SmallVector getAllValues() { + return { + Extension::int16, Extension::int4, Extension::bf16, + Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft, + Extension::variable, Extension::controlflow, Extension::doubleround, + Extension::inexactround, Extension::dynamic + }; + } + }]; +} def Tosa_ExtensionArrayAttr : TypedArrayAttrBase; -def Tosa_ProfileAttr - : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof", - [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]>; +def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>; +def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">; -def Tosa_ProfileArrayAttr - : TypedArrayAttrBase; +def Tosa_LevelAttr + : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>; // The base class for defining op availability dimensions. class Availability { @@ -381,6 +404,21 @@ class Extension extensions> : Availability { let instance = "ref"; } +//===----------------------------------------------------------------------===// +// TOSA target environment. +//===----------------------------------------------------------------------===// +def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> { + let summary = "Target environment information."; + let parameters = ( ins + "Level": $level, + ArrayRefParameter<"Profile">: $profiles, + ArrayRefParameter<"Extension">: $extensions + ); + + let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` " + "`extensions` `=` `[` $extensions `]` `>`"; +} + //===----------------------------------------------------------------------===// // Iterable attributes. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt index 7484473c0db23..f52b82a964da9 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,7 +1,5 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt) -mlir_tablegen(PassesEnums.h.inc -gen-enum-decls) -mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs) add_mlir_dialect_tablegen_target(MLIRTosaPassIncGen) add_mlir_doc(Passes TosaPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index 306e4b1f218e7..ba99d2f1d2727 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -15,7 +15,6 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index b96682843538c..6ae19d81e0820 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -65,14 +65,6 @@ def TosaOptionalDecompositionsPass }]; } -def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level", - [ - I32EnumAttrCase<"None", 0, "none">, - I32EnumAttrCase<"EightK", 1, "8k">, - ]>{ - let cppNamespace = "mlir::tosa"; -} - def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> { let summary = "Validates TOSA dialect"; let description = [{ @@ -81,10 +73,6 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> { }]; let options = [ - ListOption<"profile", "profile", "std::string", - "Validate if operations match for the given profile set">, - ListOption<"extension", "extension", "std::string", - "Validate if operations match for the given extension set">, Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool", /*default=*/"false", "Verify if the properties of certain operations align the spec requirement">, @@ -92,17 +80,7 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> { /*default=*/"false", "Disable checks for operations that are determined to be invalid due to their " "operand/result datatypes not aligning with the 'Supported Data Types' " - "sections of the specifciation">, - Option<"level", "level", "mlir::tosa::TosaLevelEnum", - /*default=*/"mlir::tosa::TosaLevelEnum::EightK", - "Validate if operator parameters are within specfication for the given level", - [{::llvm::cl::values( - clEnumValN(mlir::tosa::TosaLevelEnum::EightK, "8k", - "Ranges are expected to be sufficient for applications with frame sizes up to 8K."), - clEnumValN(mlir::tosa::TosaLevelEnum::None, "none", - "Allows the full range of arguments specified by the operations according " - "to the operation data types.") - )}]> + "sections of the specifciation"> ]; } @@ -141,4 +119,44 @@ def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signle }]; } +def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> { + let summary = "Attach tosa.target_env information to the given module."; + + let description = [{ + This pass allows the user to specify a TOSA target environment consisting of + the following components: level, profiles and extensions. + + The target environment is attached to the module as an attribute, allowing other + transformations to query the selected target and adapt their behaviour based on + this information. + }]; + + let dependentDialects = [ + "func::FuncDialect", + "tosa::TosaDialect", + ]; + + let options = [ + Option<"level", "level", "mlir::tosa::Level", + /*default=*/"mlir::tosa::Level::eightK", + "The TOSA level that operators should conform to. A TOSA level defines " + "operator argument ranges that an implementation shall support.", + [{::llvm::cl::values( + clEnumValN(mlir::tosa::Level::eightK, "8k", + "Ranges are expected to be sufficient for applications with frame " + "sizes up to 8K."), + clEnumValN(mlir::tosa::Level::none, "none", + "Allows the full range of arguments specified by the operations according " + "to the operation data types.") + )}]>, + ListOption<"profiles", "profiles", "std::string", + "The TOSA profile(s) that operators should conform to. TOSA profiles " + "enable efficient implementation on different classes of device. Each " + "profile is an independent set of operations and data type combinations.">, + ListOption<"extensions", "extensions", "std::string", + "The TOSA extension(s) that operators should conform to. TOSA profile " + "extensions define optional operation and data type combinations."> + ]; +} + #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp index c6a3ba9f1439f..e7602b4508cf1 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -115,11 +115,8 @@ void mlir::tosa::registerTosaToLinalgPipelines() { TosaToLinalgOptions tosaToLinalgOptions; TosaToLinalgNamedOptions tosaToLinalgNamedOptions; TosaValidationOptions validationOptions; - validationOptions.profile = {"none"}; - validationOptions.extension = {"none"}; validationOptions.strictOpSpecAlignment = false; validationOptions.allowInvalidOpDatatypeCombinations = false; - validationOptions.level = tosa::TosaLevelEnum::EightK; tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions, tosaToLinalgNamedOptions, validationOptions); diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt index c6a438d348946..a95906aa8352e 100644 --- a/mlir/lib/Dialect/Tosa/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRTosaDialect IR/TosaOps.cpp IR/TosaCanonicalizations.cpp + IR/TargetEnv.cpp Utils/ConversionUtils.cpp Utils/QuantUtils.cpp diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp new file mode 100644 index 0000000000000..5aad67173cc61 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -0,0 +1,42 @@ +//===-------------- TosaTarget.cpp - TOSA Target utilities ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TargetEnv.h" + +namespace mlir { +namespace tosa { + +TargetEnvAttr lookupTargetEnv(Operation *op) { + while (op) { + op = SymbolTable::getNearestSymbolTable(op); + if (!op) + break; + + if (auto attr = op->getAttrOfType(TargetEnvAttr::name)) + return attr; + + op = op->getParentOp(); + } + + return {}; +} + +TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) { + return TargetEnvAttr::get(context, Level::eightK, + {Profile::pro_int, Profile::pro_fp}, {}); +} + +TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) { + if (auto attr = lookupTargetEnv(op)) + return attr; + + return getDefaultTargetEnv(op->getContext()); +} + +} // namespace tosa +} // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index 803993bb1008d..41b338d6e7189 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRTosaTransforms + TosaAttachTarget.cpp TosaConvertIntegerTypeToSignless.cpp TosaDecomposeTransposeConv.cpp TosaDecomposeDepthwise.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp new file mode 100644 index 0000000000000..bcb880a808b36 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp @@ -0,0 +1,87 @@ +//===- TosaAttachTarget.cpp +//------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Attach target information to a TOSA module. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tosa/IR/TargetEnv.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace tosa { + +#define GEN_PASS_DEF_TOSAATTACHTARGET +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" + +namespace { + +class TosaAttachTarget + : public tosa::impl::TosaAttachTargetBase { + using Base::Base; + +public: + void runOnOperation() override { + llvm::SmallVector selectedProfiles; + if (!profiles.empty()) { + for (const std::string &prof : profiles) { + std::optional profSymbol = symbolizeProfile(prof); + if (!profSymbol) { + llvm::SmallVector allProfiles = ProfileAttr::getAllValues(); + llvm::errs() << buildUnkownParameterErrorMessage(allProfiles, + "profile", prof); + return signalPassFailure(); + } + selectedProfiles.push_back(profSymbol.value()); + } + } + + llvm::SmallVector selectedExtensions; + if (!extensions.empty()) { + for (const std::string &ext : extensions) { + std::optional extSymbol = symbolizeExtension(ext); + if (!extSymbol) { + llvm::SmallVector allExtensions = + ExtensionAttr::getAllValues(); + llvm::errs() << buildUnkownParameterErrorMessage(allExtensions, + "extension", ext); + return signalPassFailure(); + } + selectedExtensions.push_back(extSymbol.value()); + } + } + + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + const auto targetEnvAttr = + TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions); + mod->setAttr(TargetEnvAttr::name, targetEnvAttr); + } + +private: + template + std::string buildUnkownParameterErrorMessage(llvm::SmallVector &enumValues, + std::string enumName, + std::string unknownArgument) { + std::string message; + llvm::raw_string_ostream os(message); + os << "Unknown TOSA " << enumName << " name passed in '" << unknownArgument + << "', supported " << enumName << "s are: "; + llvm::interleaveComma(enumValues, os); + os << "\n"; + return message; + } +}; + +} // namespace + +} // namespace tosa +} // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 4fc7ce81d9821..82f2f7eb17af4 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -14,7 +14,6 @@ #include "mlir/Dialect/Tosa/IR/TargetEnv.h" #include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" -#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc" #include @@ -130,28 +129,6 @@ static LogicalResult checkConstantOperandNegate(Operation *op, return success(); } -struct TosaLevel { - int32_t MAX_RANK = 0; - int32_t MAX_KERNEL = 0; - int32_t MAX_STRIDE = 0; - int32_t MAX_SCALE = 0; - int32_t MAX_LOG2_SIZE = 0; - int32_t MAX_NESTING = 0; - int32_t MAX_TENSOR_LIST_SIZE = 0; - - bool operator==(const TosaLevel &rhs) { - return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL && - MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE && - MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE && - MAX_NESTING == rhs.MAX_NESTING && - MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE; - } -}; - -static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64}; -static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048, - 63, 256, 256}; - //===----------------------------------------------------------------------===// // TOSA Validation Pass. //===----------------------------------------------------------------------===// @@ -162,12 +139,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { explicit TosaValidation(const TosaValidationOptions &options) : TosaValidation() { - this->profile = options.profile; - this->extension = options.extension; this->strictOpSpecAlignment = options.strictOpSpecAlignment; this->allowInvalidOpDatatypeCombinations = options.allowInvalidOpDatatypeCombinations; - this->level = options.level; } void runOnOperation() final; @@ -207,28 +181,28 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { LogicalResult levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_KERNEL) + if (v > targetEnv.getLevel().MAX_KERNEL) return op->emitOpError() << "failed level check: " << checkDesc; return success(); } LogicalResult levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_STRIDE) + if (v > targetEnv.getLevel().MAX_STRIDE) return op->emitOpError() << "failed level check: " << checkDesc; return success(); } LogicalResult levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_SCALE) + if (v > targetEnv.getLevel().MAX_SCALE) return op->emitOpError() << "failed level check: " << checkDesc; return success(); } LogicalResult levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) + if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE) return op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc; return success(); @@ -285,6 +259,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { template LogicalResult levelCheckRanks(T tosaOp) { auto op = tosaOp.getOperation(); + const TosaLevel tosaLevel = targetEnv.getLevel(); for (auto v : op->getOperands()) { if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))) return failure(); @@ -466,7 +441,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { int32_t maxNestedDepth = 0; getMaxNestedDepth(op, maxNestedDepth); - if (maxNestedDepth >= tosaLevel.MAX_NESTING) { + if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) { op->emitOpError() << "failed level check: " << maxNestedDepth << " >= MAX_NESTING"; return failure(); @@ -523,43 +498,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { return success(); } - // configure profile and level values from pass options profileName and - // levelName - void configLevelAndProfile() { - tosaLevel = TOSA_LEVEL_NONE; - if (level == TosaLevelEnum::EightK) { - tosaLevel = TOSA_LEVEL_EIGHTK; - } - - if (!profile.empty()) { - for (std::string &prof : profile) { - auto profSymbol = symbolizeProfile(prof); - if (profSymbol) { - targetEnv.addProfile(profSymbol.value()); - } else { - llvm::errs() << "unknown TOSA profile name passed in: " << prof - << ", supported profiles are `pro_int` and `pro_fp`\n"; - return signalPassFailure(); - } - } - } - - if (!extension.empty()) { - for (std::string &ext : extension) { - auto extSymbol = symbolizeExtension(ext); - if (extSymbol) { - targetEnv.addExtension(extSymbol.value()); - } else { - llvm::errs() << "unknown TOSA extension name passed in: " << ext - << ", supported extension are int16, int4, bf16, " - << "fp8e4m3, fp8e5m2, fft, variable, controlflow, " - << "doubleround, inexactround and dynamic\n"; - return signalPassFailure(); - } - } - } - } - LogicalResult CheckVariable(Operation *op); LogicalResult CheckVariableReadOrWrite(Operation *op); bool isValidElementType(Type type, const bool allowUnsigned = false); @@ -567,7 +505,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { SmallVector< std::function> constCheckers; - TosaLevel tosaLevel; DenseMap variablesMap; TosaProfileCompliance profileComp; tosa::TargetEnv targetEnv; @@ -576,13 +513,13 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { template <> LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) { auto *op = tosaOp.getOperation(); - if (failed( - levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))) + if (failed(levelCheckRank(op, tosaOp.getInput(), "operand", + targetEnv.getLevel().MAX_RANK))) return failure(); // rank(output) = rank(input) - 1 if (failed(levelCheckRank(op, tosaOp.getOutput(), "result", - tosaLevel.MAX_RANK - 1))) + targetEnv.getLevel().MAX_RANK - 1))) return failure(); return success(); @@ -594,7 +531,7 @@ LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) { // Only the condition input has rank limitation. if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand", - tosaLevel.MAX_RANK))) + targetEnv.getLevel().MAX_RANK))) return failure(); return success(); @@ -605,7 +542,7 @@ LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) { auto *op = tosaOp.getOperation(); auto variableType = getVariableType(tosaOp); if (failed(levelCheckRank(op, variableType, "variable type", - tosaLevel.MAX_RANK))) + targetEnv.getLevel().MAX_RANK))) return failure(); return success(); @@ -762,7 +699,8 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op, // defined in 1.7. Levels. // For each tensor, the number of tensor elements multiplied by the // element size in bytes must be representable as a tensor_size_t. - const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1; + const int64_t max_size = + (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1; if (size > max_size) return op->emitOpError() << "failed level check: " << operandOrResult @@ -772,7 +710,7 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op, } LogicalResult TosaValidation::applyLevelCheck(Operation *op) { - if (tosaLevel == TOSA_LEVEL_NONE) { + if (targetEnv.getLevel() == TOSA_LEVEL_NONE) { // no need to do level checks return success(); } @@ -1282,12 +1220,12 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { } void TosaValidation::runOnOperation() { - configLevelAndProfile(); - TosaDialect *tosaDialect = getContext().getLoadedDialect(); if (!tosaDialect) return; + targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation())); + getOperation().walk([&](Operation *op) { if (op->getDialect() != tosaDialect) return; diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir index aaf8371beb1e6..60b70b8754611 100644 --- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir +++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir @@ -2,7 +2,7 @@ // Check operations when the dynamic extension is enabled. //-------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic allow-invalid-op-datatype-combinations" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=dynamic" -tosa-validate="strict-op-spec-alignment allow-invalid-op-datatype-combinations" // ----- diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir index 2f9421c43d2fb..334f52a3407c7 100644 --- a/mlir/test/Dialect/Tosa/error_if_check.mlir +++ b/mlir/test/Dialect/Tosa/error_if_check.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="level=none profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="level=none profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic" -tosa-validate="strict-op-spec-alignment" // ----- diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 41c3243792259..e92ba93ca5312 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -4,7 +4,7 @@ // validation flow. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" func.func @test_cast(%arg0: tensor) -> tensor<5xi32> { diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index 3138ce2621a3a..d6816c6752d01 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -2,7 +2,7 @@ // Enable all supported profiles to focus the verification of expected extension requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_argmax(%arg0: tensor<14x19xbf16>) -> tensor<14xi32> { diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 3742adf650408..90efb48abb37c 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="extension=dynamic" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=dynamic" -tosa-validate func.func @test_argmax_rank_invalid(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> { // expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}} diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir index 225b962589df9..09e96eca776e2 100644 --- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_add_i32(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir index 58a73d6e2e52b..7ff8065ee41fd 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_const_f16() -> tensor<3x11x11x3xf16> { diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir index a5784b381534a..48e79e4000d56 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_const_i1() -> tensor<3x11x11x3xi1> { diff --git a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir new file mode 100644 index 0000000000000..d6c886c44b013 --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s -split-input-file -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,dynamic level=none" | FileCheck %s --check-prefix=CHECK-ALL +// RUN: mlir-opt %s -split-input-file -tosa-attach-target="level=8k" | FileCheck %s --check-prefix=CHECK-LVL-8K +// RUN: mlir-opt %s -split-input-file -tosa-attach-target | FileCheck %s --check-prefix=CHECK-DEFAULT + +// ----- + +// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env} +// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env} +// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env} +// CHECK-LABEL: test_simple +func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> { + %1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> + return %1 : tensor<1x1x1x1xf32> +} diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir index f05ae7f58261d..8e0ad0a5e46a7 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment" | FileCheck %s +// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" --tosa-validate="strict-op-spec-alignment" | FileCheck %s // ----- diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir index 88ec027277e4f..663159e75d1a6 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir @@ -4,7 +4,7 @@ // validation flow. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" | FileCheck %s +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate | FileCheck %s // -----