Skip to content

Commit 16ad97e

Browse files
authored
[mlir][tosa] Add the concept of a TOSA target environment (llvm#153771)
This commit introduces a new module-level attribute `tosa.target_env`. It encapsulates target information for use during compilation such as: level, profiles and extensions. For example: ```mlir module attributes {tosa.target_env = #tosa.target_env<level = none, profiles = [pro_int], extensions = [int16, int4]>} { <my-tosa-program> } ``` Previously the validation pass accepted target information as a series of command line pass options. This commit changes the behaviour to query the attached target environment from the module attribute. This refactoring allows other passes to query the same target information. A new target environment can be attached using the `--tosa-attach-target` pass, which takes the same command line options as the previous validation pass arguments. For example: ```bash mlir-opt --tosa-attach-target="profiles=pro_int extensions=int4,int16 level=none" test.mlir ```
1 parent fddd1b6 commit 16ad97e

23 files changed

+307
-131
lines changed

mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ void addTosaToLinalgPasses(
3939
TosaToLinalgNamedOptions(),
4040
// Note: Default to 'none' level unless otherwise specified.
4141
std::optional<tosa::TosaValidationOptions> validationOptions =
42-
tosa::TosaValidationOptions{
43-
{"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None});
42+
tosa::TosaValidationOptions{false, false});
4443

4544
/// Populates TOSA to linalg pipelines
4645
/// Currently, this includes only the "tosa-to-linalg-pipeline".

mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,67 @@
2020
namespace mlir {
2121
namespace tosa {
2222

23+
struct TosaLevel {
24+
int32_t MAX_RANK = 0;
25+
int32_t MAX_KERNEL = 0;
26+
int32_t MAX_STRIDE = 0;
27+
int32_t MAX_SCALE = 0;
28+
int32_t MAX_LOG2_SIZE = 0;
29+
int32_t MAX_NESTING = 0;
30+
int32_t MAX_TENSOR_LIST_SIZE = 0;
31+
32+
bool operator==(const TosaLevel &rhs) {
33+
return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
34+
MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
35+
MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
36+
MAX_NESTING == rhs.MAX_NESTING &&
37+
MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
38+
}
39+
};
40+
41+
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
42+
static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
43+
63, 256, 256};
44+
45+
TargetEnvAttr lookupTargetEnv(Operation *op);
46+
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
47+
48+
/// Queries the target environment recursively from enclosing symbol table ops
49+
/// containing the given `op` or returns the default target environment as
50+
/// returned by getDefaultTargetEnv() if not provided.
51+
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
52+
2353
/// This class represents the capability enabled in the target implementation
24-
/// such as profile, extension, and level.
54+
/// such as profile, extension, and level. It's a wrapper class around
55+
/// tosa::TargetEnvAttr.
2556
class TargetEnv {
2657
public:
2758
TargetEnv() {}
28-
explicit TargetEnv(const SmallVectorImpl<Profile> &profiles,
29-
const SmallVectorImpl<Extension> &extensions) {
59+
explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles,
60+
const ArrayRef<Extension> &extensions)
61+
: level(level) {
3062
enabledProfiles.insert_range(profiles);
31-
3263
enabledExtensions.insert_range(extensions);
3364
}
3465

66+
explicit TargetEnv(TargetEnvAttr targetAttr)
67+
: TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(),
68+
targetAttr.getExtensions()) {}
69+
3570
void addProfile(Profile p) { enabledProfiles.insert(p); }
3671
void addExtension(Extension e) { enabledExtensions.insert(e); }
3772

3873
// TODO implement the following utilities.
3974
// Version getSpecVersion() const;
40-
// TosaLevel getLevel() const;
75+
76+
TosaLevel getLevel() const {
77+
if (level == Level::eightK)
78+
return TOSA_LEVEL_EIGHTK;
79+
else if (level == Level::none)
80+
return TOSA_LEVEL_NONE;
81+
else
82+
llvm_unreachable("Unknown TOSA level");
83+
};
4184

4285
// Returns true if the given profile is allowed.
4386
bool allows(Profile prof) const { return enabledProfiles.count(prof) != 0; }
@@ -62,8 +105,9 @@ class TargetEnv {
62105
}
63106

64107
private:
108+
Level level;
65109
llvm::SmallSet<Profile, 3> enabledProfiles;
66-
llvm::SmallSet<Extension, 8> enabledExtensions;
110+
llvm::SmallSet<Extension, 13> enabledExtensions;
67111
};
68112

69113
} // namespace tosa

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,19 @@ def Tosa_NONE : I32EnumAttrCase<"none", 0>;
245245
def Tosa_PRO_INT : I32EnumAttrCase<"pro_int", 1>;
246246
def Tosa_PRO_FP : I32EnumAttrCase<"pro_fp", 2>;
247247

248+
def Tosa_ProfileAttr
249+
: Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof",
250+
[Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]> {
251+
let extraClassDeclaration = [{
252+
static llvm::SmallVector<Profile, 2> getAllValues() {
253+
return {Profile::pro_int, Profile::pro_fp};
254+
}
255+
}];
256+
}
257+
258+
def Tosa_ProfileArrayAttr
259+
: TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
260+
248261
def Tosa_EXT_NONE : I32EnumAttrCase<"none", 0>;
249262
def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>;
250263
def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>;
@@ -264,17 +277,27 @@ def Tosa_ExtensionAttr
264277
Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
265278
Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
266279
Tosa_EXT_DYNAMIC
267-
]>;
280+
]> {
281+
let extraClassDeclaration = [{
282+
static llvm::SmallVector<Extension, 11> getAllValues() {
283+
return {
284+
Extension::int16, Extension::int4, Extension::bf16,
285+
Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
286+
Extension::variable, Extension::controlflow, Extension::doubleround,
287+
Extension::inexactround, Extension::dynamic
288+
};
289+
}
290+
}];
291+
}
268292

269293
def Tosa_ExtensionArrayAttr
270294
: TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">;
271295

272-
def Tosa_ProfileAttr
273-
: Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof",
274-
[Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]>;
296+
def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
297+
def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
275298

276-
def Tosa_ProfileArrayAttr
277-
: TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
299+
def Tosa_LevelAttr
300+
: Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
278301

279302
// The base class for defining op availability dimensions.
280303
class Availability {
@@ -381,6 +404,21 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
381404
let instance = "ref";
382405
}
383406

407+
//===----------------------------------------------------------------------===//
408+
// TOSA target environment.
409+
//===----------------------------------------------------------------------===//
410+
def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
411+
let summary = "Target environment information.";
412+
let parameters = ( ins
413+
"Level": $level,
414+
ArrayRefParameter<"Profile">: $profiles,
415+
ArrayRefParameter<"Extension">: $extensions
416+
);
417+
418+
let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
419+
"`extensions` `=` `[` $extensions `]` `>`";
420+
}
421+
384422
//===----------------------------------------------------------------------===//
385423
// Iterable attributes.
386424
//===----------------------------------------------------------------------===//
Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
set(LLVM_TARGET_DEFINITIONS Passes.td)
22
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt)
3-
mlir_tablegen(PassesEnums.h.inc -gen-enum-decls)
4-
mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs)
53
add_mlir_dialect_tablegen_target(MLIRTosaPassIncGen)
64

75
add_mlir_doc(Passes TosaPasses ./ -gen-pass-doc)

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1717
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
18-
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc"
1918
#include "mlir/Pass/Pass.h"
2019

2120
namespace mlir {

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,6 @@ def TosaOptionalDecompositionsPass
6565
}];
6666
}
6767

68-
def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
69-
[
70-
I32EnumAttrCase<"None", 0, "none">,
71-
I32EnumAttrCase<"EightK", 1, "8k">,
72-
]>{
73-
let cppNamespace = "mlir::tosa";
74-
}
75-
7668
def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
7769
let summary = "Validates TOSA dialect";
7870
let description = [{
@@ -81,28 +73,14 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
8173
}];
8274

8375
let options = [
84-
ListOption<"profile", "profile", "std::string",
85-
"Validate if operations match for the given profile set">,
86-
ListOption<"extension", "extension", "std::string",
87-
"Validate if operations match for the given extension set">,
8876
Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool",
8977
/*default=*/"false",
9078
"Verify if the properties of certain operations align the spec requirement">,
9179
Option<"allowInvalidOpDatatypeCombinations", "allow-invalid-op-datatype-combinations", "bool",
9280
/*default=*/"false",
9381
"Disable checks for operations that are determined to be invalid due to their "
9482
"operand/result datatypes not aligning with the 'Supported Data Types' "
95-
"sections of the specifciation">,
96-
Option<"level", "level", "mlir::tosa::TosaLevelEnum",
97-
/*default=*/"mlir::tosa::TosaLevelEnum::EightK",
98-
"Validate if operator parameters are within specfication for the given level",
99-
[{::llvm::cl::values(
100-
clEnumValN(mlir::tosa::TosaLevelEnum::EightK, "8k",
101-
"Ranges are expected to be sufficient for applications with frame sizes up to 8K."),
102-
clEnumValN(mlir::tosa::TosaLevelEnum::None, "none",
103-
"Allows the full range of arguments specified by the operations according "
104-
"to the operation data types.")
105-
)}]>
83+
"sections of the specifciation">
10684
];
10785
}
10886

@@ -141,4 +119,44 @@ def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signle
141119
}];
142120
}
143121

122+
def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
123+
let summary = "Attach tosa.target_env information to the given module.";
124+
125+
let description = [{
126+
This pass allows the user to specify a TOSA target environment consisting of
127+
the following components: level, profiles and extensions.
128+
129+
The target environment is attached to the module as an attribute, allowing other
130+
transformations to query the selected target and adapt their behaviour based on
131+
this information.
132+
}];
133+
134+
let dependentDialects = [
135+
"func::FuncDialect",
136+
"tosa::TosaDialect",
137+
];
138+
139+
let options = [
140+
Option<"level", "level", "mlir::tosa::Level",
141+
/*default=*/"mlir::tosa::Level::eightK",
142+
"The TOSA level that operators should conform to. A TOSA level defines "
143+
"operator argument ranges that an implementation shall support.",
144+
[{::llvm::cl::values(
145+
clEnumValN(mlir::tosa::Level::eightK, "8k",
146+
"Ranges are expected to be sufficient for applications with frame "
147+
"sizes up to 8K."),
148+
clEnumValN(mlir::tosa::Level::none, "none",
149+
"Allows the full range of arguments specified by the operations according "
150+
"to the operation data types.")
151+
)}]>,
152+
ListOption<"profiles", "profiles", "std::string",
153+
"The TOSA profile(s) that operators should conform to. TOSA profiles "
154+
"enable efficient implementation on different classes of device. Each "
155+
"profile is an independent set of operations and data type combinations.">,
156+
ListOption<"extensions", "extensions", "std::string",
157+
"The TOSA extension(s) that operators should conform to. TOSA profile "
158+
"extensions define optional operation and data type combinations.">
159+
];
160+
}
161+
144162
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,8 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
115115
TosaToLinalgOptions tosaToLinalgOptions;
116116
TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
117117
TosaValidationOptions validationOptions;
118-
validationOptions.profile = {"none"};
119-
validationOptions.extension = {"none"};
120118
validationOptions.strictOpSpecAlignment = false;
121119
validationOptions.allowInvalidOpDatatypeCombinations = false;
122-
validationOptions.level = tosa::TosaLevelEnum::EightK;
123120
tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
124121
tosaToLinalgNamedOptions,
125122
validationOptions);

mlir/lib/Dialect/Tosa/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_mlir_dialect_library(MLIRTosaDialect
22
IR/TosaOps.cpp
33
IR/TosaCanonicalizations.cpp
4+
IR/TargetEnv.cpp
45
Utils/ConversionUtils.cpp
56
Utils/QuantUtils.cpp
67

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//===-------------- TosaTarget.cpp - TOSA Target utilities ----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
10+
11+
namespace mlir {
12+
namespace tosa {
13+
14+
TargetEnvAttr lookupTargetEnv(Operation *op) {
15+
while (op) {
16+
op = SymbolTable::getNearestSymbolTable(op);
17+
if (!op)
18+
break;
19+
20+
if (auto attr = op->getAttrOfType<TargetEnvAttr>(TargetEnvAttr::name))
21+
return attr;
22+
23+
op = op->getParentOp();
24+
}
25+
26+
return {};
27+
}
28+
29+
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
30+
return TargetEnvAttr::get(context, Level::eightK,
31+
{Profile::pro_int, Profile::pro_fp}, {});
32+
}
33+
34+
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
35+
if (auto attr = lookupTargetEnv(op))
36+
return attr;
37+
38+
return getDefaultTargetEnv(op->getContext());
39+
}
40+
41+
} // namespace tosa
42+
} // namespace mlir

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_dialect_library(MLIRTosaTransforms
2+
TosaAttachTarget.cpp
23
TosaConvertIntegerTypeToSignless.cpp
34
TosaDecomposeTransposeConv.cpp
45
TosaDecomposeDepthwise.cpp

0 commit comments

Comments
 (0)