Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 42 additions & 6 deletions mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,28 +50,63 @@ TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
/// returned by getDefaultTargetEnv() if not provided.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);

/// A thin wrapper around the SpecificationVersion enum to represent
/// and provide utilities around the TOSA specification version.
class TosaSpecificationVersion {
public:
TosaSpecificationVersion(uint32_t major, uint32_t minor)
: majorVersion(major), minorVersion(minor) {}
TosaSpecificationVersion(SpecificationVersion version)
: TosaSpecificationVersion(fromVersionEnum(version)) {}

bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const {
return this->majorVersion == baseVersion.majorVersion &&
this->minorVersion >= baseVersion.minorVersion;
}

uint32_t getMajor() const { return majorVersion; }
uint32_t getMinor() const { return minorVersion; }

private:
uint32_t majorVersion = 0;
uint32_t minorVersion = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tosa spec also has patch, draft fields.
Are those less useful here ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I feel patch versions of the spec likely won't be useful due to the nature of the changes, since they cannot add functionality, only make clarifications. It should be relatively easy to extend in the future if there's a case however.

Similarly for draft, though I'd be more hesitant to add since any feature checks for draft will need to become the non-draft version at some point anyway (that is, we should not be holding onto draft versions). Note that this patch does use 1.1.draft for the user-facing API, at least until 1.1 is released.


static TosaSpecificationVersion
fromVersionEnum(SpecificationVersion version) {
switch (version) {
case SpecificationVersion::V_1_0:
return TosaSpecificationVersion(1, 0);
case SpecificationVersion::V_1_1_DRAFT:
return TosaSpecificationVersion(1, 1);
}
llvm_unreachable("Unknown TOSA version");
}
};

llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version);

/// This class represents the capability enabled in the target implementation
/// such as profile, extension, and level. It's a wrapper class around
/// tosa::TargetEnvAttr.
class TargetEnv {
public:
TargetEnv() {}
explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles,
explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
const ArrayRef<Profile> &profiles,
const ArrayRef<Extension> &extensions)
: level(level) {
: specificationVersion(specificationVersion), level(level) {
enabledProfiles.insert_range(profiles);
enabledExtensions.insert_range(extensions);
}

explicit TargetEnv(TargetEnvAttr targetAttr)
: TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(),
targetAttr.getExtensions()) {}
: TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
targetAttr.getProfiles(), targetAttr.getExtensions()) {}

void addProfile(Profile p) { enabledProfiles.insert(p); }
void addExtension(Extension e) { enabledExtensions.insert(e); }

// TODO implement the following utilities.
// Version getSpecVersion() const;
SpecificationVersion getSpecVersion() const { return specificationVersion; }

TosaLevel getLevel() const {
if (level == Level::eightK)
Expand Down Expand Up @@ -105,6 +140,7 @@ class TargetEnv {
}

private:
SpecificationVersion specificationVersion;
Level level;
llvm::SmallSet<Profile, 3> enabledProfiles;
llvm::SmallSet<Extension, 13> enabledExtensions;
Expand Down
940 changes: 639 additions & 301 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

Large diffs are not rendered by default.

33 changes: 25 additions & 8 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
}

//===----------------------------------------------------------------------===//
// TOSA Spec Section 1.5.
// TOSA Profiles and extensions
//
// Profile:
// INT : Integer Inference. Integer operations, primarily 8 and 32-bit values.
Expand Down Expand Up @@ -293,12 +293,6 @@ def Tosa_ExtensionAttr
def Tosa_ExtensionArrayAttr
: TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">;

def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;

def Tosa_LevelAttr
: Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;

// The base class for defining op availability dimensions.
class Availability {
// The following are fields for controlling the generated C++ OpInterface.
Expand Down Expand Up @@ -404,18 +398,41 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
let instance = "ref";
}

//===----------------------------------------------------------------------===//
// TOSA Levels
//===----------------------------------------------------------------------===//

def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;

def Tosa_LevelAttr
: Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;

//===----------------------------------------------------------------------===//
// TOSA Specification versions
//===----------------------------------------------------------------------===//

def Tosa_V_1_0 : I32EnumAttrCase<"V_1_0", 0, "1.0">;
def Tosa_V_1_1_DRAFT : I32EnumAttrCase<"V_1_1_DRAFT", 1, "1.1.draft">;

def Tosa_SpecificationVersion : Tosa_I32EnumAttr<
"SpecificationVersion", "TOSA specification version", "specification_version",
[Tosa_V_1_0, Tosa_V_1_1_DRAFT]>;

//===----------------------------------------------------------------------===//
// TOSA target environment.
//===----------------------------------------------------------------------===//
def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
let summary = "Target environment information.";
let parameters = ( ins
"SpecificationVersion": $specification_version,
"Level": $level,
ArrayRefParameter<"Profile">: $profiles,
ArrayRefParameter<"Extension">: $extensions
);

let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
let assemblyFormat = "`<` `specification_version` `=` $specification_version `,` "
"`level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
"`extensions` `=` `[` $extensions `]` `>`";
}

Expand Down
13 changes: 7 additions & 6 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,15 @@ enum CheckCondition {
allOf
};

using VersionedTypeInfo =
std::pair<SmallVector<TypeInfo>, SpecificationVersion>;

template <typename T>
struct OpComplianceInfo {
// Certain operations require multiple modes enabled.
// e.g. cast bf16 to fp8e4m3 requires EXT-BF16 and EXT-FP8E4M3.
SmallVector<T> mode;
SmallVector<SmallVector<TypeInfo>> operandTypeInfoSet;
SmallVector<VersionedTypeInfo> operandTypeInfoSet;
CheckCondition condition = CheckCondition::anyOf;
};

Expand Down Expand Up @@ -130,9 +133,8 @@ class TosaProfileCompliance {
// Find the required profiles or extensions from the compliance info according
// to the operand type combination.
template <typename T>
SmallVector<T> findMatchedProfile(Operation *op,
SmallVector<OpComplianceInfo<T>> compInfo,
CheckCondition &condition);
OpComplianceInfo<T>
findMatchedEntry(Operation *op, SmallVector<OpComplianceInfo<T>> compInfo);

SmallVector<Profile> getCooperativeProfiles(Extension ext) {
switch (ext) {
Expand Down Expand Up @@ -168,8 +170,7 @@ class TosaProfileCompliance {

private:
template <typename T>
FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
CheckCondition &condition);
FailureOr<OpComplianceInfo<T>> getOperatorDefinition(Operation *op);

OperationProfileComplianceMap profileComplianceMap;
OperationExtensionComplianceMap extensionComplianceMap;
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,13 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
];

let options = [
Option<"specificationVersion", "specification_version", "mlir::tosa::SpecificationVersion",
/*default=*/"mlir::tosa::SpecificationVersion::V_1_0",
"The specification version that TOSA operators should conform to.",
[{::llvm::cl::values(
clEnumValN(mlir::tosa::SpecificationVersion::V_1_0, "1.0", "TOSA Specification version 1.0"),
clEnumValN(mlir::tosa::SpecificationVersion::V_1_1_DRAFT, "1.1.draft", "TOSA Specification version 1.1.draft")
)}]>,
Option<"level", "level", "mlir::tosa::Level",
/*default=*/"mlir::tosa::Level::eightK",
"The TOSA level that operators should conform to. A TOSA level defines "
Expand Down
7 changes: 6 additions & 1 deletion mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
#include "llvm/Support/FormatVariadic.h"

namespace mlir {
namespace tosa {
Expand All @@ -27,7 +28,7 @@ TargetEnvAttr lookupTargetEnv(Operation *op) {
}

TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
return TargetEnvAttr::get(context, Level::eightK,
return TargetEnvAttr::get(context, SpecificationVersion::V_1_0, Level::eightK,
{Profile::pro_int, Profile::pro_fp}, {});
}

Expand All @@ -38,5 +39,9 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
return getDefaultTargetEnv(op->getContext());
}

llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
}

} // namespace tosa
} // namespace mlir
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ class TosaAttachTarget

ModuleOp mod = getOperation();
MLIRContext *ctx = &getContext();
const auto targetEnvAttr =
TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions);
const auto targetEnvAttr = TargetEnvAttr::get(
ctx, specificationVersion, level, selectedProfiles, selectedExtensions);
mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
}

Expand Down
74 changes: 47 additions & 27 deletions mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,16 +335,15 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
//===----------------------------------------------------------------------===//

template <typename T>
FailureOr<SmallVector<T>>
TosaProfileCompliance::getOperatorDefinition(Operation *op,
CheckCondition &condition) {
FailureOr<OpComplianceInfo<T>>
TosaProfileCompliance::getOperatorDefinition(Operation *op) {
const std::string opName = op->getName().getStringRef().str();
const auto complianceMap = getProfileComplianceMap<T>();
const auto it = complianceMap.find(opName);
if (it == complianceMap.end())
return {};

return findMatchedProfile<T>(op, it->second, condition);
return findMatchedEntry<T>(op, it->second);
}

template <typename T>
Expand All @@ -356,22 +355,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
if (specRequiredModeSet.size() == 0)
return success();

CheckCondition condition = CheckCondition::invalid;
const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
if (failed(maybeOpRequiredMode)) {
const auto maybeOpDefinition = getOperatorDefinition<T>(op);
if (failed(maybeOpDefinition)) {
// Operators such as control-flow and shape ops do not have an operand type
// restriction. When the profile compliance information of operation is not
// found, confirm if the target have enabled the profile required from the
// specification.
int mode_count = 0;
int modeCount = 0;
for (const auto &cands : specRequiredModeSet) {
if (targetEnv.allowsAnyOf(cands))
return success();
mode_count += cands.size();
modeCount += cands.size();
}

op->emitOpError() << "illegal: requires"
<< (mode_count > 1 ? " any of " : " ") << "["
<< (modeCount > 1 ? " any of " : " ") << "["
<< llvm::join(stringifyProfile<T>(specRequiredModeSet),
", ")
<< "] but not enabled in target\n";
Expand All @@ -381,7 +379,10 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(

// Find the required profiles or extensions according to the operand type
// combination.
const auto opRequiredMode = maybeOpRequiredMode.value();
const auto opDefinition = maybeOpDefinition.value();
const SmallVector<T> opRequiredMode = opDefinition.mode;
const CheckCondition condition = opDefinition.condition;

if (opRequiredMode.size() == 0) {
// No matched restriction found.
return success();
Expand Down Expand Up @@ -437,6 +438,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
}
}

// Ensure the matched op compliance version does not exceed the target
// specification version.
const VersionedTypeInfo versionedTypeInfo =
opDefinition.operandTypeInfoSet[0];
const TosaSpecificationVersion complianceVersion{versionedTypeInfo.second};
const TosaSpecificationVersion targetVersion{targetEnv.getSpecVersion()};
if (!targetVersion.isBackwardsCompatibleWith(complianceVersion)) {
op->emitOpError() << "illegal: requires the target specification version ("
<< stringifyVersion(targetVersion)
<< ") be backwards compatible with the op compliance "
"specification version ("
<< stringifyVersion(complianceVersion) << ")\n";
return failure();
}

return success();
}

Expand All @@ -461,14 +477,14 @@ TosaProfileCompliance::checkExtension(Operation *op,
}

LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
CheckCondition condition = CheckCondition::invalid;
const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
const auto maybeProfDef = getOperatorDefinition<Profile>(op);
const auto maybeExtDef = getOperatorDefinition<Extension>(op);
if (failed(maybeProfDef) && failed(maybeExtDef))
return success();

const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
(succeeded(maybeExtDef) && !maybeExtDef->empty());
const bool hasEntry =
(succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) ||
(succeeded(maybeExtDef) && !maybeExtDef->mode.empty());
if (!hasEntry) {
std::string message;
llvm::raw_string_ostream os(message);
Expand All @@ -488,7 +504,9 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
SmallVector<TypeInfo> bestTypeInfo;
const auto searchBestMatch = [&](auto map) {
for (const auto &complianceInfos : map[opName]) {
for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
for (const auto &versionedTypeInfos :
complianceInfos.operandTypeInfoSet) {
const SmallVector<TypeInfo> typeInfos = versionedTypeInfos.first;
const int matches = llvm::count_if(
llvm::zip_equal(current, typeInfos), [&](const auto zipType) {
return isSameTypeInfo(std::get<0>(zipType),
Expand Down Expand Up @@ -520,9 +538,8 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
// Find the profiles or extensions requirement according to the signature of
// type of the operand list.
template <typename T>
SmallVector<T> TosaProfileCompliance::findMatchedProfile(
Operation *op, SmallVector<OpComplianceInfo<T>> compInfo,
CheckCondition &condition) {
OpComplianceInfo<T> TosaProfileCompliance::findMatchedEntry(
Operation *op, SmallVector<OpComplianceInfo<T>> compInfo) {
assert(compInfo.size() != 0 &&
"profile-based compliance information is empty");

Expand All @@ -533,27 +550,30 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile(
return {};

for (size_t i = 0; i < compInfo.size(); i++) {
SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
for (SmallVector<TypeInfo> expected : sets) {
SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet;
for (const auto &set : sets) {
SmallVector<TypeInfo> expected = set.first;
assert(present.size() == expected.size() &&
"the entries for profile-based compliance do not match between "
"the generated metadata and the type definition retrieved from "
" the operation");

bool is_found = true;
bool isFound = true;
// Compare the type signature between the given operation and the
// compliance metadata.
for (size_t j = 0; j < expected.size(); j++) {
if (!isSameTypeInfo(present[j], expected[j])) {
// Verify the next mode set from the list.
is_found = false;
isFound = false;
break;
}
}

if (is_found == true) {
condition = compInfo[i].condition;
return compInfo[i].mode;
if (isFound == true) {
SmallVector<VersionedTypeInfo> typeInfoSet{set};
OpComplianceInfo<T> info{compInfo[i].mode, typeInfoSet,
compInfo[i].condition};
return info;
}
}
}
Expand Down
Loading
Loading