@@ -68,6 +68,13 @@ bool isDefaultDomain(std::string_view domain) {
6868 return domain.empty () || (domain == " ai.onnx" );
6969}
7070
71+ std::string canonicalizeDomain (std::string_view domain) {
72+ // Handle aliasing of "ai.onnx" and "". According to the onnx documentation,
73+ // the default domain is "ai.onnx", but in practice it seems like the
74+ // empty-string domain is used by default.
75+ return isDefaultDomain (domain) ? " " : std::string (domain);
76+ }
77+
7178// / We consider opset < 6 is old. Users will see a warning if their model
7279// / contains ops of old opset.
7380constexpr int32_t MINIMUM_SUPPORTED_OPSET = 6 ;
@@ -86,7 +93,13 @@ template <class T>
8693OpsetImportsMap GetOpsetImportsFromProto (const T &proto) {
8794 OpsetImportsMap opset_imports;
8895 for (const auto &opset_import : proto.opset_import ()) {
89- opset_imports[opset_import.domain ()] = opset_import.version ();
96+ const auto domain = canonicalizeDomain (opset_import.domain ());
97+ const auto [iter_, inserted] =
98+ opset_imports.emplace (domain, opset_import.version ());
99+ if (!inserted) {
100+ llvm::errs () << " Warning: Domain " << domain
101+ << " found multiple times in opset imports.\n " ;
102+ }
90103 }
91104 return opset_imports;
92105}
@@ -101,7 +114,8 @@ ModelLocalFunctionsMap GetModelLocalFunctions(const onnx::ModelProto &m) {
101114 for (const auto &function_proto : m.functions ()) {
102115 model_local_functions_by_id.insert (
103116 {GetModelLocalFunctionsMapIdentifier (
104- function_proto.domain (), function_proto.name ()),
117+ canonicalizeDomain (function_proto.domain ()),
118+ function_proto.name ()),
105119 &function_proto});
106120 }
107121 return model_local_functions_by_id;
@@ -1298,18 +1312,19 @@ class FrontendGenImpl {
12981312 return onnx::OpSchemaRegistry::Schema (node.op_type (), version, domain);
12991313 }
13001314
1301- std::string GetImportVersionOfNode (const onnx::NodeProto &node) {
1302- auto current_opset_it = opset_map_.find (node.domain ());
1315+ std::string GetImportVersionOfNode (
1316+ const onnx::NodeProto &node, const std::string &domain) {
1317+ auto current_opset_it = opset_map_.find (domain);
13031318 if (current_opset_it == opset_map_.end ())
13041319 return " " ;
13051320
13061321 const int current_opset = current_opset_it->second ;
13071322
1308- const auto op_version_map = dialect_op_version_map_.find (node. domain () );
1323+ const auto op_version_map = dialect_op_version_map_.find (domain);
13091324 if (op_version_map == dialect_op_version_map_.end ())
13101325 return " " ;
13111326
1312- const auto op_opsets_map = dialect_op_opsets_map_.find (node. domain () );
1327+ const auto op_opsets_map = dialect_op_opsets_map_.find (domain);
13131328 if (op_opsets_map == dialect_op_opsets_map_.end ())
13141329 return " " ;
13151330
@@ -1363,7 +1378,7 @@ class FrontendGenImpl {
13631378
13641379 // A new opset is added to onnx-mlir when it becomes incompatible.
13651380 // All opset newest than the last opset should use the last opset(version)
1366- if (isDefaultDomain (node. domain () ) &&
1381+ if (isDefaultDomain (domain) &&
13671382 upperRangeOfNewestValidOpsetVersion < supported_opset_list.back () &&
13681383 upperRangeOfNewestValidOpsetVersion < MINIMUM_SUPPORTED_OPSET)
13691384 llvm::errs () << " \n Warning: ONNX " << node.op_type ()
@@ -1592,8 +1607,8 @@ class FrontendGenImpl {
15921607 auto mlirAttr = builder_.getStringAttr (funcName);
15931608 auto funcAttr = builder_.getNamedAttr (" function_name" , mlirAttr);
15941609 attributes.push_back (funcAttr);
1595- auto domainAttr = builder_.getNamedAttr (
1596- " domain_name " , builder_.getStringAttr (node.domain ()));
1610+ auto domainAttr = builder_.getNamedAttr (" domain_name " ,
1611+ builder_.getStringAttr (canonicalizeDomain ( node.domain () )));
15971612 attributes.push_back (domainAttr);
15981613 const int nIn = ONNXCustomOp::getNumberOfOperands ();
15991614 const int nOut = ONNXCustomOp::getNumberOfResults ();
@@ -1635,9 +1650,11 @@ class FrontendGenImpl {
16351650
16361651 [[nodiscard]] std::error_code ImportNode (
16371652 const onnx::NodeProto &node, std::string &errorMessage) {
1653+ const std::string domain = canonicalizeDomain (node.domain ());
16381654
1639- std::string opName = node.op_type () + GetImportVersionOfNode (node);
1640- auto domainIt = import_handler_map_.find (node.domain ());
1655+ const std::string opName =
1656+ node.op_type () + GetImportVersionOfNode (node, domain);
1657+ auto domainIt = import_handler_map_.find (domain);
16411658 if (domainIt != import_handler_map_.end ()) {
16421659 auto handler = domainIt->second .find (opName);
16431660 std::vector<std::string> funcs = options_.functionsToDecompose ;
@@ -1658,7 +1675,7 @@ class FrontendGenImpl {
16581675 }
16591676
16601677 auto model_function = in_model_functions_.find (
1661- GetModelLocalFunctionsMapIdentifier (node. domain () , node.op_type ()));
1678+ GetModelLocalFunctionsMapIdentifier (domain, node.op_type ()));
16621679 if (model_function != in_model_functions_.end ()) {
16631680 return ImportFunctionCallNode (
16641681 node, /* schema=*/ nullptr , model_function->second , errorMessage);
0 commit comments