Skip to content

Commit 17c0b81

Browse files
authored
Merge pull request #475 from Xilinx/jrickert.mx6
Allow the import of not-builtin domains and add support for AMD Quarks BFPQuantizeDequantizeOp
2 parents ac10142 + cdd9f56 commit 17c0b81

File tree

13 files changed

+1253
-964
lines changed

13 files changed

+1253
-964
lines changed

docs/Dialects/onnx.md

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,45 @@
11
<!-- Autogenerated by mlir-tblgen; don't manually edit -->
2+
### `onnx.AMDQuarkBFPQuantizeDequantizeOp` (AMDQuarkBFPQuantizeDequantizeOp)
3+
4+
_BFPQuantizeDequantize_
5+
6+
Block Floating Point (BFP) groups numbers (e.g., tensors, arrays) into blocks, where each block shares a common exponent, and the values in the block are represented with individual mantissas (and the sign bit). This approach offers the performance and speed of 8-bit operations while bringing the precision closer to 16-bit operations.
7+
8+
MicroeXponents (MX) extends the concept of BFP by introducing two levels of exponents: shared exponents for entire blocks and micro exponents for finer-grained sub-blocks. This two-level approach enables more precise scaling of individual elements within a block, reducing quantization error and improving the representational range. The paper https://arxiv.org/abs/2302.08007 introduces three specific formats: MX4, MX6 and MX9, which have different bits of mantissa.
9+
10+
This operator converts floating-point values (typically 32-bit floating-point numbers) into BFP or MX values, then convert them back. It approximates the Quantize-Dequantize process and introduces quantization errors.
11+
12+
Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>`, `SameOperandsAndResultElementType`
13+
14+
Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`
15+
16+
Effects: `MemoryEffects::Effect{}`
17+
18+
#### Attributes:
19+
20+
<table>
21+
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
22+
<tr><td><code>bfp_method</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
23+
<tr><td><code>axis</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
24+
<tr><td><code>bit_width</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
25+
<tr><td><code>block_size</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
26+
<tr><td><code>rounding_mode</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
27+
<tr><td><code>sub_block_size</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
28+
<tr><td><code>sub_block_shift_bits</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
29+
</table>
30+
31+
#### Operands:
32+
33+
| Operand | Description |
34+
| :-----: | ----------- |
35+
| `X` | tensor of 32-bit float values
36+
37+
#### Results:
38+
39+
| Result | Description |
40+
| :----: | ----------- |
41+
| `Y` | tensor of 32-bit float values
42+
243
### `onnx.Abs` (ONNXAbsOp)
344

445
_ONNX Abs operation_

docs/ImportONNXDefs.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,9 @@ necessary.
155155
It is not always needed to keep the code for an older version, which may be rewritten into the new
156156
operation. Thus, we just need to have the dialect definition, but not the code for inference or
157157
lowering.
158+
159+
# Adding Operations from not-builtin domains
160+
To add an operation from a not-builtin domain, it needs to be added to the `additional_op_version_dict` in gen_onnx_mlir.py. The key is the domain name and the value is the per-operation version dictionary.
161+
The new domain also needs to be added to the `domain_abrv_dict` in gen_onnx_mlir.py. The key is the domain name and the value is the abbreviation/prefix used in ONNX-MLIR for this domain.
162+
For operations from not-builtin domains, the operation definition specification needs to be manually provided.
163+
This can be done via custom TableGen records for the operations. See [/src/Dialect/ONNX/AdditionalONNXOps.td](../src/Dialect/ONNX/AdditionalONNXOps.td) for examples.

src/Builder/FrontendDialectTransformer.cpp

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,12 @@ class FrontendGenImpl {
197197
ModuleOp module_;
198198
OpBuilder builder_;
199199

200-
// onnxop: list of versions supported by onnx-mlir for dialect
201-
std::map<std::string, std::vector<int>> op_dialect_version_map_;
202-
// onnxop: list of versions for dialect
203-
std::map<std::string, std::vector<int>> op_opsets_map_;
204-
// onnxop: the top version in third_part/onnx
205-
std::map<std::string, int> op_dialect_top_version_map_;
200+
// onnxop: list of versions supported by onnx-mlir for dialect, op
201+
std::map<std::string, std::map<std::string, std::vector<int>>>
202+
dialect_op_version_map_;
203+
// onnxop: list of versions for dialect, op
204+
std::map<std::string, std::map<std::string, std::vector<int>>>
205+
dialect_op_opsets_map_;
206206

207207
// mapping between string name and symbol
208208
ValueSymbolMapping frontend_symbols_;
@@ -214,7 +214,8 @@ class FrontendGenImpl {
214214
onnx_mlir::detail::FrontendGenImpl::*)(
215215
const onnx::NodeProto &, std::string & /*errorMessage*/);
216216

217-
std::map<std::string, ImportHandlerType> import_handler_map_;
217+
std::map<std::string, std::map<std::string, ImportHandlerType>>
218+
import_handler_map_;
218219

219220
// The total number of elements in all initializers. This value is a rough
220221
// counter of the number of parameters in a model.
@@ -1304,18 +1305,26 @@ class FrontendGenImpl {
13041305

13051306
const int current_opset = current_opset_it->second;
13061307

1308+
const auto op_version_map = dialect_op_version_map_.find(node.domain());
1309+
if (op_version_map == dialect_op_version_map_.end())
1310+
return "";
1311+
1312+
const auto op_opsets_map = dialect_op_opsets_map_.find(node.domain());
1313+
if (op_opsets_map == dialect_op_opsets_map_.end())
1314+
return "";
1315+
13071316
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Importing ONNX"
13081317
<< node.op_type() << " (" << node.name() << ")"
13091318
<< ", Opset: " << current_opset << "\n");
13101319

13111320
const auto supported_opset_list_it =
1312-
op_dialect_version_map_.find(node.op_type());
1313-
const auto opset_list_it = op_opsets_map_.find(node.op_type());
1321+
op_version_map->second.find(node.op_type());
1322+
const auto opset_list_it = op_opsets_map->second.find(node.op_type());
13141323

1315-
// Custom ops may not be present in op_dialect_version_map_. If no version
1324+
// Custom ops may not be present in op_version_map. If no version
13161325
// info is found, treat as unversioned (no renaming).
1317-
if (supported_opset_list_it == op_dialect_version_map_.end() ||
1318-
opset_list_it == op_opsets_map_.end())
1326+
if (supported_opset_list_it == op_version_map->second.end() ||
1327+
opset_list_it == op_opsets_map->second.end())
13191328
return "";
13201329

13211330
// To determine the opset version for a node/op:
@@ -1626,13 +1635,14 @@ class FrontendGenImpl {
16261635

16271636
[[nodiscard]] std::error_code ImportNode(
16281637
const onnx::NodeProto &node, std::string &errorMessage) {
1629-
if (isDefaultDomain(node.domain()) || (node.domain() == "ai.onnx.ml") ||
1630-
(node.domain() == "ai.onnx.preview.training")) {
1631-
std::string opName = node.op_type() + GetImportVersionOfNode(node);
1632-
auto handler = import_handler_map_.find(opName);
1638+
1639+
std::string opName = node.op_type() + GetImportVersionOfNode(node);
1640+
auto domainIt = import_handler_map_.find(node.domain());
1641+
if (domainIt != import_handler_map_.end()) {
1642+
auto handler = domainIt->second.find(opName);
16331643
std::vector<std::string> funcs = options_.functionsToDecompose;
1634-
if (!(std::find(funcs.begin(), funcs.end(), opName) != funcs.end())) {
1635-
if (handler != import_handler_map_.end()) {
1644+
if (!llvm::is_contained(funcs, opName)) {
1645+
if (handler != domainIt->second.end()) {
16361646
// It's a regular op with a registered handler.
16371647
return (this->*(handler->second))(node, errorMessage);
16381648
}

0 commit comments

Comments
 (0)