Skip to content

Commit b42709f

Browse files
committed
[mlir][tosa] Add specification versioning to target environment
This commit adds a new "specification_version" field to the TOSA target environment attribute. This allows a user to specify which version of the TOSA specification they would like to target during lowering. A leading example in the validation pass has also been added. This addition adds a version to each profile compliance entry to track which version of the specification the entry was added. This allows a backwards compatibility check to be implemented between the target version and the profile compliance entry version. For now a default version of "1.0" is assumed. "1.1.draft" is added to denote an in-development version of the specification targeting the next release. Change-Id: I6549e05bd4fe975d12ea31e8acc783233db66171
1 parent 16ad97e commit b42709f

File tree

11 files changed

+821
-354
lines changed

11 files changed

+821
-354
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,28 +50,63 @@ TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
5050
/// returned by getDefaultTargetEnv() if not provided.
5151
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
5252

53+
/// A thin wrapper around the SpecificationVersion enum to represent
54+
/// and provide utilities around the TOSA specification version.
55+
class TosaSpecificationVersion {
56+
public:
57+
TosaSpecificationVersion(uint32_t major, uint32_t minor)
58+
: majorVersion(major), minorVersion(minor) {}
59+
TosaSpecificationVersion(SpecificationVersion version)
60+
: TosaSpecificationVersion(fromVersionEnum(version)) {}
61+
62+
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const {
63+
return this->majorVersion == baseVersion.majorVersion &&
64+
this->minorVersion >= baseVersion.minorVersion;
65+
}
66+
67+
uint32_t getMajor() const { return majorVersion; }
68+
uint32_t getMinor() const { return minorVersion; }
69+
70+
private:
71+
uint32_t majorVersion = 0;
72+
uint32_t minorVersion = 0;
73+
74+
static TosaSpecificationVersion
75+
fromVersionEnum(SpecificationVersion version) {
76+
switch (version) {
77+
case SpecificationVersion::V_1_0:
78+
return TosaSpecificationVersion(1, 0);
79+
case SpecificationVersion::V_1_1_DRAFT:
80+
return TosaSpecificationVersion(1, 1);
81+
}
82+
llvm_unreachable("Unknown TOSA version");
83+
}
84+
};
85+
86+
llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version);
87+
5388
/// This class represents the capability enabled in the target implementation
5489
/// such as profile, extension, and level. It's a wrapper class around
5590
/// tosa::TargetEnvAttr.
5691
class TargetEnv {
5792
public:
5893
TargetEnv() {}
59-
explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles,
94+
explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
95+
const ArrayRef<Profile> &profiles,
6096
const ArrayRef<Extension> &extensions)
61-
: level(level) {
97+
: specificationVersion(specificationVersion), level(level) {
6298
enabledProfiles.insert_range(profiles);
6399
enabledExtensions.insert_range(extensions);
64100
}
65101

66102
explicit TargetEnv(TargetEnvAttr targetAttr)
67-
: TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(),
68-
targetAttr.getExtensions()) {}
103+
: TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
104+
targetAttr.getProfiles(), targetAttr.getExtensions()) {}
69105

70106
void addProfile(Profile p) { enabledProfiles.insert(p); }
71107
void addExtension(Extension e) { enabledExtensions.insert(e); }
72108

73-
// TODO implement the following utilities.
74-
// Version getSpecVersion() const;
109+
SpecificationVersion getSpecVersion() const { return specificationVersion; }
75110

76111
TosaLevel getLevel() const {
77112
if (level == Level::eightK)
@@ -105,6 +140,7 @@ class TargetEnv {
105140
}
106141

107142
private:
143+
SpecificationVersion specificationVersion;
108144
Level level;
109145
llvm::SmallSet<Profile, 3> enabledProfiles;
110146
llvm::SmallSet<Extension, 13> enabledExtensions;

mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

Lines changed: 639 additions & 301 deletions
Large diffs are not rendered by default.

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
221221
}
222222

223223
//===----------------------------------------------------------------------===//
224-
// TOSA Spec Section 1.5.
224+
// TOSA Profiles and extensions
225225
//
226226
// Profile:
227227
// INT : Integer Inference. Integer operations, primarily 8 and 32-bit values.
@@ -293,12 +293,6 @@ def Tosa_ExtensionAttr
293293
def Tosa_ExtensionArrayAttr
294294
: TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">;
295295

296-
def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
297-
def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
298-
299-
def Tosa_LevelAttr
300-
: Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
301-
302296
// The base class for defining op availability dimensions.
303297
class Availability {
304298
// The following are fields for controlling the generated C++ OpInterface.
@@ -404,18 +398,41 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
404398
let instance = "ref";
405399
}
406400

401+
//===----------------------------------------------------------------------===//
402+
// TOSA Levels
403+
//===----------------------------------------------------------------------===//
404+
405+
def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
406+
def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
407+
408+
def Tosa_LevelAttr
409+
: Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
410+
411+
//===----------------------------------------------------------------------===//
412+
// TOSA Specification versions
413+
//===----------------------------------------------------------------------===//
414+
415+
def Tosa_V_1_0 : I32EnumAttrCase<"V_1_0", 0, "1.0">;
416+
def Tosa_V_1_1_DRAFT : I32EnumAttrCase<"V_1_1_DRAFT", 1, "1.1.draft">;
417+
418+
def Tosa_SpecificationVersion : Tosa_I32EnumAttr<
419+
"SpecificationVersion", "TOSA specification version", "specification_version",
420+
[Tosa_V_1_0, Tosa_V_1_1_DRAFT]>;
421+
407422
//===----------------------------------------------------------------------===//
408423
// TOSA target environment.
409424
//===----------------------------------------------------------------------===//
410425
def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
411426
let summary = "Target environment information.";
412427
let parameters = ( ins
428+
"SpecificationVersion": $specification_version,
413429
"Level": $level,
414430
ArrayRefParameter<"Profile">: $profiles,
415431
ArrayRefParameter<"Extension">: $extensions
416432
);
417433

418-
let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
434+
let assemblyFormat = "`<` `specification_version` `=` $specification_version `,` "
435+
"`level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
419436
"`extensions` `=` `[` $extensions `]` `>`";
420437
}
421438

mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,15 @@ enum CheckCondition {
3636
allOf
3737
};
3838

39+
using VersionedTypeInfo =
40+
std::pair<SmallVector<TypeInfo>, SpecificationVersion>;
41+
3942
template <typename T>
4043
struct OpComplianceInfo {
4144
// Certain operations require multiple modes enabled.
4245
// e.g. cast bf16 to fp8e4m3 requires EXT-BF16 and EXT-FP8E4M3.
4346
SmallVector<T> mode;
44-
SmallVector<SmallVector<TypeInfo>> operandTypeInfoSet;
47+
SmallVector<VersionedTypeInfo> operandTypeInfoSet;
4548
CheckCondition condition = CheckCondition::anyOf;
4649
};
4750

@@ -130,9 +133,8 @@ class TosaProfileCompliance {
130133
// Find the required profiles or extensions from the compliance info according
131134
// to the operand type combination.
132135
template <typename T>
133-
SmallVector<T> findMatchedProfile(Operation *op,
134-
SmallVector<OpComplianceInfo<T>> compInfo,
135-
CheckCondition &condition);
136+
OpComplianceInfo<T>
137+
findMatchedEntry(Operation *op, SmallVector<OpComplianceInfo<T>> compInfo);
136138

137139
SmallVector<Profile> getCooperativeProfiles(Extension ext) {
138140
switch (ext) {
@@ -168,8 +170,7 @@ class TosaProfileCompliance {
168170

169171
private:
170172
template <typename T>
171-
FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
172-
CheckCondition &condition);
173+
FailureOr<OpComplianceInfo<T>> getOperatorDefinition(Operation *op);
173174

174175
OperationProfileComplianceMap profileComplianceMap;
175176
OperationExtensionComplianceMap extensionComplianceMap;

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,13 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
137137
];
138138

139139
let options = [
140+
Option<"specificationVersion", "specification_version", "mlir::tosa::SpecificationVersion",
141+
/*default=*/"mlir::tosa::SpecificationVersion::V_1_0",
142+
"The specification version that TOSA operators should conform to.",
143+
[{::llvm::cl::values(
144+
clEnumValN(mlir::tosa::SpecificationVersion::V_1_0, "1.0", "TOSA Specification version 1.0"),
145+
clEnumValN(mlir::tosa::SpecificationVersion::V_1_1_DRAFT, "1.1.draft", "TOSA Specification version 1.1.draft")
146+
)}]>,
140147
Option<"level", "level", "mlir::tosa::Level",
141148
/*default=*/"mlir::tosa::Level::eightK",
142149
"The TOSA level that operators should conform to. A TOSA level defines "

mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
10+
#include "llvm/Support/FormatVariadic.h"
1011

1112
namespace mlir {
1213
namespace tosa {
@@ -27,7 +28,7 @@ TargetEnvAttr lookupTargetEnv(Operation *op) {
2728
}
2829

2930
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
30-
return TargetEnvAttr::get(context, Level::eightK,
31+
return TargetEnvAttr::get(context, SpecificationVersion::V_1_0, Level::eightK,
3132
{Profile::pro_int, Profile::pro_fp}, {});
3233
}
3334

@@ -38,5 +39,9 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
3839
return getDefaultTargetEnv(op->getContext());
3940
}
4041

42+
llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
43+
return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
44+
}
45+
4146
} // namespace tosa
4247
} // namespace mlir

mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ class TosaAttachTarget
6161

6262
ModuleOp mod = getOperation();
6363
MLIRContext *ctx = &getContext();
64-
const auto targetEnvAttr =
65-
TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions);
64+
const auto targetEnvAttr = TargetEnvAttr::get(
65+
ctx, specificationVersion, level, selectedProfiles, selectedExtensions);
6666
mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
6767
}
6868

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 47 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -335,16 +335,15 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
335335
//===----------------------------------------------------------------------===//
336336

337337
template <typename T>
338-
FailureOr<SmallVector<T>>
339-
TosaProfileCompliance::getOperatorDefinition(Operation *op,
340-
CheckCondition &condition) {
338+
FailureOr<OpComplianceInfo<T>>
339+
TosaProfileCompliance::getOperatorDefinition(Operation *op) {
341340
const std::string opName = op->getName().getStringRef().str();
342341
const auto complianceMap = getProfileComplianceMap<T>();
343342
const auto it = complianceMap.find(opName);
344343
if (it == complianceMap.end())
345344
return {};
346345

347-
return findMatchedProfile<T>(op, it->second, condition);
346+
return findMatchedEntry<T>(op, it->second);
348347
}
349348

350349
template <typename T>
@@ -356,22 +355,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
356355
if (specRequiredModeSet.size() == 0)
357356
return success();
358357

359-
CheckCondition condition = CheckCondition::invalid;
360-
const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
361-
if (failed(maybeOpRequiredMode)) {
358+
const auto maybeOpDefinition = getOperatorDefinition<T>(op);
359+
if (failed(maybeOpDefinition)) {
362360
// Operators such as control-flow and shape ops do not have an operand type
363361
// restriction. When the profile compliance information of operation is not
364362
// found, confirm if the target have enabled the profile required from the
365363
// specification.
366-
int mode_count = 0;
364+
int modeCount = 0;
367365
for (const auto &cands : specRequiredModeSet) {
368366
if (targetEnv.allowsAnyOf(cands))
369367
return success();
370-
mode_count += cands.size();
368+
modeCount += cands.size();
371369
}
372370

373371
op->emitOpError() << "illegal: requires"
374-
<< (mode_count > 1 ? " any of " : " ") << "["
372+
<< (modeCount > 1 ? " any of " : " ") << "["
375373
<< llvm::join(stringifyProfile<T>(specRequiredModeSet),
376374
", ")
377375
<< "] but not enabled in target\n";
@@ -381,7 +379,10 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
381379

382380
// Find the required profiles or extensions according to the operand type
383381
// combination.
384-
const auto opRequiredMode = maybeOpRequiredMode.value();
382+
const auto opDefinition = maybeOpDefinition.value();
383+
const SmallVector<T> opRequiredMode = opDefinition.mode;
384+
const CheckCondition condition = opDefinition.condition;
385+
385386
if (opRequiredMode.size() == 0) {
386387
// No matched restriction found.
387388
return success();
@@ -437,6 +438,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
437438
}
438439
}
439440

441+
// Ensure the matched op compliance version does not exceed the target
442+
// specification version.
443+
const VersionedTypeInfo versionedTypeInfo =
444+
opDefinition.operandTypeInfoSet[0];
445+
const TosaSpecificationVersion complianceVersion{versionedTypeInfo.second};
446+
const TosaSpecificationVersion targetVersion{targetEnv.getSpecVersion()};
447+
if (!targetVersion.isBackwardsCompatibleWith(complianceVersion)) {
448+
op->emitOpError() << "illegal: requires the target specification version ("
449+
<< stringifyVersion(targetVersion)
450+
<< ") be backwards compatible with the op compliance "
451+
"specification version ("
452+
<< stringifyVersion(complianceVersion) << ")\n";
453+
return failure();
454+
}
455+
440456
return success();
441457
}
442458

@@ -461,14 +477,14 @@ TosaProfileCompliance::checkExtension(Operation *op,
461477
}
462478

463479
LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
464-
CheckCondition condition = CheckCondition::invalid;
465-
const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
466-
const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
480+
const auto maybeProfDef = getOperatorDefinition<Profile>(op);
481+
const auto maybeExtDef = getOperatorDefinition<Extension>(op);
467482
if (failed(maybeProfDef) && failed(maybeExtDef))
468483
return success();
469484

470-
const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
471-
(succeeded(maybeExtDef) && !maybeExtDef->empty());
485+
const bool hasEntry =
486+
(succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) ||
487+
(succeeded(maybeExtDef) && !maybeExtDef->mode.empty());
472488
if (!hasEntry) {
473489
std::string message;
474490
llvm::raw_string_ostream os(message);
@@ -488,7 +504,9 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
488504
SmallVector<TypeInfo> bestTypeInfo;
489505
const auto searchBestMatch = [&](auto map) {
490506
for (const auto &complianceInfos : map[opName]) {
491-
for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
507+
for (const auto &versionedTypeInfos :
508+
complianceInfos.operandTypeInfoSet) {
509+
const SmallVector<TypeInfo> typeInfos = versionedTypeInfos.first;
492510
const int matches = llvm::count_if(
493511
llvm::zip_equal(current, typeInfos), [&](const auto zipType) {
494512
return isSameTypeInfo(std::get<0>(zipType),
@@ -520,9 +538,8 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
520538
// Find the profiles or extensions requirement according to the signature of
521539
// type of the operand list.
522540
template <typename T>
523-
SmallVector<T> TosaProfileCompliance::findMatchedProfile(
524-
Operation *op, SmallVector<OpComplianceInfo<T>> compInfo,
525-
CheckCondition &condition) {
541+
OpComplianceInfo<T> TosaProfileCompliance::findMatchedEntry(
542+
Operation *op, SmallVector<OpComplianceInfo<T>> compInfo) {
526543
assert(compInfo.size() != 0 &&
527544
"profile-based compliance information is empty");
528545

@@ -533,27 +550,30 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile(
533550
return {};
534551

535552
for (size_t i = 0; i < compInfo.size(); i++) {
536-
SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
537-
for (SmallVector<TypeInfo> expected : sets) {
553+
SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet;
554+
for (const auto &set : sets) {
555+
SmallVector<TypeInfo> expected = set.first;
538556
assert(present.size() == expected.size() &&
539557
"the entries for profile-based compliance do not match between "
540558
"the generated metadata and the type definition retrieved from "
541559
" the operation");
542560

543-
bool is_found = true;
561+
bool isFound = true;
544562
// Compare the type signature between the given operation and the
545563
// compliance metadata.
546564
for (size_t j = 0; j < expected.size(); j++) {
547565
if (!isSameTypeInfo(present[j], expected[j])) {
548566
// Verify the next mode set from the list.
549-
is_found = false;
567+
isFound = false;
550568
break;
551569
}
552570
}
553571

554-
if (is_found == true) {
555-
condition = compInfo[i].condition;
556-
return compInfo[i].mode;
572+
if (isFound == true) {
573+
SmallVector<VersionedTypeInfo> typeInfoSet{set};
574+
OpComplianceInfo<T> info{compInfo[i].mode, typeInfoSet,
575+
compInfo[i].condition};
576+
return info;
557577
}
558578
}
559579
}

0 commit comments

Comments
 (0)