Skip to content

Commit 2be7757

Browse files
committed
Allow the import of not-builtin domains and add POC support for com.amd.quark
Signed-off-by: Rickert, Jonas <[email protected]>
1 parent bc9ebf3 commit 2be7757

File tree

4 files changed

+1022
-963
lines changed

4 files changed

+1022
-963
lines changed

src/Builder/FrontendDialectTransformer.cpp

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,12 @@ class FrontendGenImpl {
197197
ModuleOp module_;
198198
OpBuilder builder_;
199199

200-
// onnxop: list of versions supported by onnx-mlir for dialect
201-
std::map<std::string, std::vector<int>> op_dialect_version_map_;
202-
// onnxop: list of versions for dialect
203-
std::map<std::string, std::vector<int>> op_opsets_map_;
204-
// onnxop: the top version in third_part/onnx
205-
std::map<std::string, int> op_dialect_top_version_map_;
200+
// onnxop: list of versions supported by onnx-mlir for dialect, op
201+
std::map<std::string, std::map<std::string, std::vector<int>>>
202+
dialect_op_version_map_;
203+
// onnxop: list of versions for dialect, op
204+
std::map<std::string, std::map<std::string, std::vector<int>>>
205+
dialect_op_opsets_map_;
206206

207207
// mapping between string name and symbol
208208
ValueSymbolMapping frontend_symbols_;
@@ -214,7 +214,8 @@ class FrontendGenImpl {
214214
onnx_mlir::detail::FrontendGenImpl::*)(
215215
const onnx::NodeProto &, std::string & /*errorMessage*/);
216216

217-
std::map<std::string, ImportHandlerType> import_handler_map_;
217+
std::map<std::string, std::map<std::string, ImportHandlerType>>
218+
import_handler_map_;
218219

219220
// The total number of elements in all initializers. This value is a rough
220221
// counter of the number of parameters in a model.
@@ -1288,18 +1289,26 @@ class FrontendGenImpl {
12881289

12891290
const int current_opset = current_opset_it->second;
12901291

1292+
const auto op_version_map = dialect_op_version_map_.find(node.domain());
1293+
if (op_version_map == dialect_op_version_map_.end())
1294+
return "";
1295+
1296+
const auto op_opsets_map = dialect_op_opsets_map_.find(node.domain());
1297+
if (op_opsets_map == dialect_op_opsets_map_.end())
1298+
return "";
1299+
12911300
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Importing ONNX"
12921301
<< node.op_type() << " (" << node.name() << ")"
12931302
<< ", Opset: " << current_opset << "\n");
12941303

12951304
const auto supported_opset_list_it =
1296-
op_dialect_version_map_.find(node.op_type());
1297-
const auto opset_list_it = op_opsets_map_.find(node.op_type());
1305+
op_version_map->second.find(node.op_type());
1306+
const auto opset_list_it = op_opsets_map->second.find(node.op_type());
12981307

1299-
// Custom ops may not be present in op_dialect_version_map_. If no version
1308+
// Custom ops may not be present in op_version_map. If no version
13001309
// info is found, treat as unversioned (no renaming).
1301-
if (supported_opset_list_it == op_dialect_version_map_.end() ||
1302-
opset_list_it == op_opsets_map_.end())
1310+
if (supported_opset_list_it == op_version_map->second.end() ||
1311+
opset_list_it == op_opsets_map->second.end())
13031312
return "";
13041313

13051314
// To determine the opset version for a node/op:
@@ -1611,13 +1620,14 @@ class FrontendGenImpl {
16111620

16121621
[[nodiscard]] std::error_code ImportNode(
16131622
const onnx::NodeProto &node, std::string &errorMessage) {
1614-
if (isDefaultDomain(node.domain()) || (node.domain() == "ai.onnx.ml") ||
1615-
(node.domain() == "ai.onnx.preview.training")) {
1616-
std::string opName = node.op_type() + GetImportVersionOfNode(node);
1617-
auto handler = import_handler_map_.find(opName);
1623+
1624+
std::string opName = node.op_type() + GetImportVersionOfNode(node);
1625+
auto domainIt = import_handler_map_.find(node.domain());
1626+
if (domainIt != import_handler_map_.end()) {
1627+
auto handler = domainIt->second.find(opName);
16181628
std::vector<std::string> funcs = options_.functionsToDecompose;
1619-
if (!(std::find(funcs.begin(), funcs.end(), opName) != funcs.end())) {
1620-
if (handler != import_handler_map_.end()) {
1629+
if (!llvm::is_contained(funcs, opName)) {
1630+
if (handler != domainIt->second.end()) {
16211631
// It's a regular op with a registered handler.
16221632
return (this->*(handler->second))(node, errorMessage);
16231633
}

0 commit comments

Comments
 (0)