Commit d9b7b22

Merge remote-tracking branch 'origin/feature/onnx-to-tosa' into matthias.adjust_for_llvm_bump_to_a58e774f

2 parents: f2d62fd + ae99b99

16 files changed: +1351 -981 lines

docs/Dialects/onnx.md (41 additions, 0 deletions)
```diff
@@ -1,4 +1,45 @@
 <!-- Autogenerated by mlir-tblgen; don't manually edit -->
+### `onnx.AMDQuarkBFPQuantizeDequantizeOp` (AMDQuarkBFPQuantizeDequantizeOp)
+
+_BFPQuantizeDequantize_
+
+Block Floating Point (BFP) groups numbers (e.g., tensors, arrays) into blocks, where each block shares a common exponent and the values in the block are represented with individual mantissas (plus a sign bit). This approach offers the performance and speed of 8-bit operations while bringing the precision closer to 16-bit operations.
+
+MicroeXponents (MX) extends the concept of BFP by introducing two levels of exponents: shared exponents for entire blocks and micro exponents for finer-grained sub-blocks. This two-level approach enables more precise scaling of individual elements within a block, reducing quantization error and improving the representational range. The paper https://arxiv.org/abs/2302.08007 introduces three specific formats, MX4, MX6, and MX9, which differ in the number of mantissa bits.
+
+This operator converts floating-point values (typically 32-bit floating-point numbers) into BFP or MX values, then converts them back. It approximates the Quantize-Dequantize process and introduces quantization error.
+
+Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>`, `SameOperandsAndResultElementType`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`
+
+Effects: `MemoryEffects::Effect{}`
+
+#### Attributes:
+
+<table>
+<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
+<tr><td><code>bfp_method</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
+<tr><td><code>axis</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
+<tr><td><code>bit_width</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
+<tr><td><code>block_size</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
+<tr><td><code>rounding_mode</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
+<tr><td><code>sub_block_size</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
+<tr><td><code>sub_block_shift_bits</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
+</table>
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `X` | tensor of 32-bit float values |
+
+#### Results:
+
+| Result | Description |
+| :----: | ----------- |
+| `Y` | tensor of 32-bit float values |
+
 ### `onnx.Abs` (ONNXAbsOp)
 
 _ONNX Abs operation_
```
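
To make the operator's semantics concrete, here is a minimal NumPy sketch of a plain BFP quantize-dequantize round trip. It is an illustration under assumed semantics (shared exponent taken from each block's largest magnitude, round-to-nearest integer mantissas), not the onnx-mlir kernel, and it ignores the `bfp_method`, `rounding_mode`, and MX sub-block attributes:

```python
import numpy as np

def bfp_quantize_dequantize(x, block_size=8, bit_width=8, axis=-1):
    # Hypothetical reference semantics, not the onnx-mlir kernel: each block
    # of `block_size` consecutive elements along `axis` shares the exponent
    # of its largest-magnitude element, and every element keeps an integer
    # mantissa scaled by that shared exponent.
    x = np.asarray(x, dtype=np.float32)
    moved = np.moveaxis(x, axis, -1)
    assert moved.shape[-1] % block_size == 0, "axis not divisible by block_size"
    blocks = moved.reshape(-1, block_size)

    # Shared exponent per block, from the largest magnitude in the block.
    max_mag = np.abs(blocks).max(axis=-1, keepdims=True)
    shared_exp = np.floor(np.log2(np.where(max_mag == 0, 1.0, max_mag)))

    # Scale so the largest value in the block lands near the top of the
    # mantissa range; mantissas behave like `bit_width`-bit signed integers.
    mantissa_bits = bit_width - 1
    scale = np.exp2(shared_exp - (mantissa_bits - 1))
    lo, hi = -(2**mantissa_bits), 2**mantissa_bits - 1

    # Quantize to integer mantissas (round to nearest), then dequantize.
    mantissa = np.clip(np.rint(blocks / scale), lo, hi)
    out = (mantissa * scale).astype(np.float32).reshape(moved.shape)
    return np.moveaxis(out, -1, axis)

# The round trip is lossy; the residual is the quantization error.
x = np.linspace(-1.0, 1.0, 16, dtype=np.float32)
print(np.max(np.abs(x - bfp_quantize_dequantize(x))))
```

The MX formats add a second, per-sub-block micro exponent (controlled by the `sub_block_size` and `sub_block_shift_bits` attributes), which this sketch leaves out.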

docs/ImportONNXDefs.md (6 additions, 0 deletions)
```diff
@@ -155,3 +155,9 @@ necessary.
 It is not always needed to keep the code for an older version, which may be rewritten into the new
 operation. Thus, we just need to have the dialect definition, but not the code for inference or
 lowering.
+
+# Adding Operations from non-builtin domains
+To add an operation from a non-builtin domain, add it to the `additional_op_version_dict` in gen_onnx_mlir.py. The key is the domain name and the value is the per-operation version dictionary.
+The new domain also needs to be added to the `domain_abrv_dict` in gen_onnx_mlir.py. The key is the domain name and the value is the abbreviation/prefix used in ONNX-MLIR for this domain.
+For operations from non-builtin domains, the operation definition must be provided manually.
+This can be done via custom TableGen records for the operations; see [/src/Dialect/ONNX/AdditionalONNXOps.td](../src/Dialect/ONNX/AdditionalONNXOps.td) for examples.
```
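
As a concrete illustration of the two dictionaries described above, a registration for a custom domain might look like the sketch below. The domain name `my.custom.domain`, the op name, the abbreviation, and the version list are hypothetical; the authoritative entries live in gen_onnx_mlir.py.

```python
# Hypothetical entries for gen_onnx_mlir.py; the domain, op name,
# abbreviation, and versions below are illustrative only.

# Domain name -> per-operation version dictionary.
additional_op_version_dict = {
    "my.custom.domain": {
        "MyQuantizeOp": [1],  # op name -> list of supported opset versions
    },
}

# Domain name -> abbreviation/prefix used in ONNX-MLIR for this domain,
# i.e. the prefix prepended to the generated op names.
domain_abrv_dict = {
    "my.custom.domain": "MyCustom",
}
```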

src/Builder/FrontendDialectTransformer.cpp (76 additions, 34 deletions)
```diff
@@ -68,6 +68,13 @@ bool isDefaultDomain(std::string_view domain) {
   return domain.empty() || (domain == "ai.onnx");
 }
 
+std::string canonicalizeDomain(std::string_view domain) {
+  // Handle aliasing of "ai.onnx" and "". According to the onnx documentation,
+  // the default domain is "ai.onnx", but in practice it seems like the
+  // empty-string domain is used by default.
+  return isDefaultDomain(domain) ? "" : std::string(domain);
+}
+
 /// We consider opset < 6 is old. Users will see a warning if their model
 /// contains ops of old opset.
 constexpr int32_t MINIMUM_SUPPORTED_OPSET = 6;
@@ -86,7 +93,13 @@ template <class T>
 OpsetImportsMap GetOpsetImportsFromProto(const T &proto) {
   OpsetImportsMap opset_imports;
   for (const auto &opset_import : proto.opset_import()) {
-    opset_imports[opset_import.domain()] = opset_import.version();
+    const auto domain = canonicalizeDomain(opset_import.domain());
+    const auto [iter_, inserted] =
+        opset_imports.emplace(domain, opset_import.version());
+    if (!inserted) {
+      llvm::errs() << "Warning: Domain " << domain
+                   << " found multiple times in opset imports.\n";
+    }
   }
   return opset_imports;
 }
@@ -101,7 +114,8 @@ ModelLocalFunctionsMap GetModelLocalFunctions(const onnx::ModelProto &m) {
   for (const auto &function_proto : m.functions()) {
     model_local_functions_by_id.insert(
         {GetModelLocalFunctionsMapIdentifier(
-             function_proto.domain(), function_proto.name()),
+             canonicalizeDomain(function_proto.domain()),
+             function_proto.name()),
          &function_proto});
   }
   return model_local_functions_by_id;
@@ -197,12 +211,12 @@ class FrontendGenImpl {
   ModuleOp module_;
   OpBuilder builder_;
 
-  // onnxop: list of versions supported by onnx-mlir for dialect
-  std::map<std::string, std::vector<int>> op_dialect_version_map_;
-  // onnxop: list of versions for dialect
-  std::map<std::string, std::vector<int>> op_opsets_map_;
-  // onnxop: the top version in third_part/onnx
-  std::map<std::string, int> op_dialect_top_version_map_;
+  // onnxop: list of versions supported by onnx-mlir for dialect, op
+  std::map<std::string, std::map<std::string, std::vector<int>>>
+      dialect_op_version_map_;
+  // onnxop: list of versions for dialect, op
+  std::map<std::string, std::map<std::string, std::vector<int>>>
+      dialect_op_opsets_map_;
 
   // mapping between string name and symbol
   ValueSymbolMapping frontend_symbols_;
@@ -214,7 +228,8 @@ class FrontendGenImpl {
       onnx_mlir::detail::FrontendGenImpl::*)(
       const onnx::NodeProto &, std::string & /*errorMessage*/);
 
-  std::map<std::string, ImportHandlerType> import_handler_map_;
+  std::map<std::string, std::map<std::string, ImportHandlerType>>
+      import_handler_map_;
 
   // The total number of elements in all initializers. This value is a rough
   // counter of the number of parameters in a model.
@@ -848,23 +863,33 @@
     // Use the type map or types in input model to determine the data type of
     // output.
     std::vector<int> outputMap = T::getTypeMap();
+    const bool shouldTakeShapeFromModelForCustomOp =
+        isCustomOp &&
+        (options_.useOnnxModelTypesForCustomOps || givenOutputTypes.empty());
     for (unsigned int i = 0; i < (unsigned int)node.output().size(); i++) {
       // Optional outputs using empty string.
       if (node.output()[i].empty()) {
         outputTypes.emplace_back(builder_.getNoneType());
       } else {
-        if (options_.useOnnxModelTypes ||
-            (isCustomOp && options_.useOnnxModelTypesForCustomOps)) {
+        if (options_.useOnnxModelTypes || shouldTakeShapeFromModelForCustomOp) {
           auto onnxModelType = ConvertOnnxType(node.output(i), errorMessage);
           if (onnxModelType) {
             const auto ec = onnxModelType->getError();
             if (!ec) {
-              outputTypes.emplace_back(*onnxModelType.value());
+              Type outputType = *onnxModelType.value();
+              if (!options_.useOnnxModelTypesForCustomOps &&
+                  !options_.useOnnxModelTypes) {
+                if (auto shapedType = mlir::dyn_cast<ShapedType>(outputType)) {
+                  Type elementType = shapedType.getElementType();
+                  outputType = UnrankedTensorType::get(elementType);
+                }
+              }
+              outputTypes.emplace_back(outputType);
               continue;
             }
             if (!options_.allowMissingOutputTypes || ec != InvalidOnnxFormat) {
               errorMessage +=
-                  "Failed to get type for '" + node.output(i) + "\n";
+                  "Failed to get type for '" + node.output(i) + "'\n";
               return ec;
             }
             llvm::errs() << "Warning: "
@@ -874,13 +899,19 @@
       }
       unsigned int j = i;
       // Variadic output is a single ODS result.
-      if (variadicOut)
+      if (variadicOut) {
         j = 0;
+      }
       if (!givenOutputTypes.empty()) {
+        assert(givenOutputTypes.size() > i &&
+               "givenOutputTypes size is less than number of outputs");
        outputTypes.emplace_back(
            UnrankedTensorType::get(givenOutputTypes[i]));
      } else if (j < outputMap.size() && outputMap[j] >= MAX_NUM_TYPES) {
        // Mapping gives a connection with an input.
+        assert(
+            outputMap[j] - MAX_NUM_TYPES < static_cast<int>(inputs.size()) &&
+            "output type mapping to input is out of range");
        Type inputType = inputs[outputMap[j] - MAX_NUM_TYPES].getType();
        if (mlir::isa<TensorType>(inputType)) {
          Type elementType =
@@ -1281,25 +1312,34 @@
     return onnx::OpSchemaRegistry::Schema(node.op_type(), version, domain);
   }
 
-  std::string GetImportVersionOfNode(const onnx::NodeProto &node) {
-    auto current_opset_it = opset_map_.find(node.domain());
+  std::string GetImportVersionOfNode(
+      const onnx::NodeProto &node, const std::string &domain) {
+    auto current_opset_it = opset_map_.find(domain);
     if (current_opset_it == opset_map_.end())
       return "";
 
     const int current_opset = current_opset_it->second;
 
+    const auto op_version_map = dialect_op_version_map_.find(domain);
+    if (op_version_map == dialect_op_version_map_.end())
+      return "";
+
+    const auto op_opsets_map = dialect_op_opsets_map_.find(domain);
+    if (op_opsets_map == dialect_op_opsets_map_.end())
+      return "";
+
     LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Importing ONNX"
                             << node.op_type() << " (" << node.name() << ")"
                             << ", Opset: " << current_opset << "\n");
 
     const auto supported_opset_list_it =
-        op_dialect_version_map_.find(node.op_type());
-    const auto opset_list_it = op_opsets_map_.find(node.op_type());
+        op_version_map->second.find(node.op_type());
+    const auto opset_list_it = op_opsets_map->second.find(node.op_type());
 
-    // Custom ops may not be present in op_dialect_version_map_. If no version
+    // Custom ops may not be present in op_version_map. If no version
     // info is found, treat as unversioned (no renaming).
-    if (supported_opset_list_it == op_dialect_version_map_.end() ||
-        opset_list_it == op_opsets_map_.end())
+    if (supported_opset_list_it == op_version_map->second.end() ||
+        opset_list_it == op_opsets_map->second.end())
      return "";
 
     // To determine the opset version for a node/op:
@@ -1338,7 +1378,7 @@
 
     // A new opset is added to onnx-mlir when it becomes incompatible.
     // All opset newest than the last opset should use the last opset(version)
-    if (isDefaultDomain(node.domain()) &&
+    if (isDefaultDomain(domain) &&
         upperRangeOfNewestValidOpsetVersion < supported_opset_list.back() &&
         upperRangeOfNewestValidOpsetVersion < MINIMUM_SUPPORTED_OPSET)
       llvm::errs() << "\nWarning: ONNX " << node.op_type()
@@ -1567,13 +1607,12 @@
       auto mlirAttr = builder_.getStringAttr(funcName);
       auto funcAttr = builder_.getNamedAttr("function_name", mlirAttr);
       attributes.push_back(funcAttr);
-      auto domainAttr = builder_.getNamedAttr(
-          "domain_name", builder_.getStringAttr(node.domain()));
+      auto domainAttr = builder_.getNamedAttr("domain_name",
+          builder_.getStringAttr(canonicalizeDomain(node.domain())));
       attributes.push_back(domainAttr);
-      int nIn = 0;
-      int nOut = 0;
+      const int nIn = ONNXCustomOp::getNumberOfOperands();
+      const int nOut = ONNXCustomOp::getNumberOfResults();
       getNodeInputs(node, inputs);
-      nOut = node.output().size();
       std::vector<Type> givenOutputTypes;
 
       // We lack a way of specifying import behavior for custom domains. For now
@@ -1611,13 +1650,16 @@
 
   [[nodiscard]] std::error_code ImportNode(
       const onnx::NodeProto &node, std::string &errorMessage) {
-    if (isDefaultDomain(node.domain()) || (node.domain() == "ai.onnx.ml") ||
-        (node.domain() == "ai.onnx.preview.training")) {
-      std::string opName = node.op_type() + GetImportVersionOfNode(node);
-      auto handler = import_handler_map_.find(opName);
+    const std::string domain = canonicalizeDomain(node.domain());
+
+    const std::string opName =
+        node.op_type() + GetImportVersionOfNode(node, domain);
+    auto domainIt = import_handler_map_.find(domain);
+    if (domainIt != import_handler_map_.end()) {
+      auto handler = domainIt->second.find(opName);
       std::vector<std::string> funcs = options_.functionsToDecompose;
-      if (!(std::find(funcs.begin(), funcs.end(), opName) != funcs.end())) {
-        if (handler != import_handler_map_.end()) {
+      if (!llvm::is_contained(funcs, opName)) {
+        if (handler != domainIt->second.end()) {
           // It's a regular op with a registered handler.
           return (this->*(handler->second))(node, errorMessage);
         }
@@ -1633,7 +1675,7 @@
     }
 
     auto model_function = in_model_functions_.find(
-        GetModelLocalFunctionsMapIdentifier(node.domain(), node.op_type()));
+        GetModelLocalFunctionsMapIdentifier(domain, node.op_type()));
     if (model_function != in_model_functions_.end()) {
       return ImportFunctionCallNode(
           node, /*schema=*/nullptr, model_function->second, errorMessage);
```
