@@ -68,6 +68,13 @@ bool isDefaultDomain(std::string_view domain) {
6868 return domain.empty () || (domain == " ai.onnx" );
6969}
7070
71+ std::string canonicalizeDomain (std::string_view domain) {
72+ // Handle aliasing of "ai.onnx" and "". According to the onnx documentation,
73+ // the default domain is "ai.onnx", but in practice it seems like the
74+ // empty-string domain is used by default.
75+ return isDefaultDomain (domain) ? " " : std::string (domain);
76+ }
77+
7178// / We consider opset < 6 is old. Users will see a warning if their model
7279// / contains ops of old opset.
7380constexpr int32_t MINIMUM_SUPPORTED_OPSET = 6 ;
@@ -86,7 +93,13 @@ template <class T>
8693OpsetImportsMap GetOpsetImportsFromProto (const T &proto) {
8794 OpsetImportsMap opset_imports;
8895 for (const auto &opset_import : proto.opset_import ()) {
89- opset_imports[opset_import.domain ()] = opset_import.version ();
96+ const auto domain = canonicalizeDomain (opset_import.domain ());
97+ const auto [iter_, inserted] =
98+ opset_imports.emplace (domain, opset_import.version ());
99+ if (!inserted) {
100+ llvm::errs () << " Warning: Domain " << domain
101+ << " found multiple times in opset imports.\n " ;
102+ }
90103 }
91104 return opset_imports;
92105}
@@ -101,7 +114,8 @@ ModelLocalFunctionsMap GetModelLocalFunctions(const onnx::ModelProto &m) {
101114 for (const auto &function_proto : m.functions ()) {
102115 model_local_functions_by_id.insert (
103116 {GetModelLocalFunctionsMapIdentifier (
104- function_proto.domain (), function_proto.name ()),
117+ canonicalizeDomain (function_proto.domain ()),
118+ function_proto.name ()),
105119 &function_proto});
106120 }
107121 return model_local_functions_by_id;
@@ -197,12 +211,12 @@ class FrontendGenImpl {
197211 ModuleOp module_;
198212 OpBuilder builder_;
199213
200- // onnxop: list of versions supported by onnx-mlir for dialect
201- std::map<std::string, std::vector<int >> op_dialect_version_map_;
202- // onnxop: list of versions for dialect
203- std::map<std::string, std::vector< int >> op_opsets_map_;
204- // onnxop: the top version in third_part/onnx
205- std::map<std::string, int > op_dialect_top_version_map_ ;
214+ // onnxop: list of versions supported by onnx-mlir for dialect, op
215+ std::map<std::string, std::map<std::string, std:: vector<int >>>
216+ dialect_op_version_map_;
217+ // onnxop: list of versions for dialect, op
218+ std::map<std::string, std::map<std::string, std::vector< int >>>
219+ dialect_op_opsets_map_ ;
206220
207221 // mapping between string name and symbol
208222 ValueSymbolMapping frontend_symbols_;
@@ -214,7 +228,8 @@ class FrontendGenImpl {
214228 onnx_mlir::detail::FrontendGenImpl::*)(
215229 const onnx::NodeProto &, std::string & /* errorMessage*/ );
216230
217- std::map<std::string, ImportHandlerType> import_handler_map_;
231+ std::map<std::string, std::map<std::string, ImportHandlerType>>
232+ import_handler_map_;
218233
219234 // The total number of elements in all initializers. This value is a rough
220235 // counter of the number of parameters in a model.
@@ -848,23 +863,33 @@ class FrontendGenImpl {
848863 // Use the type map or types in input model to determine the data type of
849864 // output.
850865 std::vector<int > outputMap = T::getTypeMap ();
866+ const bool shouldTakeShapeFromModelForCustomOp =
867+ isCustomOp &&
868+ (options_.useOnnxModelTypesForCustomOps || givenOutputTypes.empty ());
851869 for (unsigned int i = 0 ; i < (unsigned int )node.output ().size (); i++) {
852870 // Optional outputs using empty string.
853871 if (node.output ()[i].empty ()) {
854872 outputTypes.emplace_back (builder_.getNoneType ());
855873 } else {
856- if (options_.useOnnxModelTypes ||
857- (isCustomOp && options_.useOnnxModelTypesForCustomOps )) {
874+ if (options_.useOnnxModelTypes || shouldTakeShapeFromModelForCustomOp) {
858875 auto onnxModelType = ConvertOnnxType (node.output (i), errorMessage);
859876 if (onnxModelType) {
860877 const auto ec = onnxModelType->getError ();
861878 if (!ec) {
862- outputTypes.emplace_back (*onnxModelType.value ());
879+ Type outputType = *onnxModelType.value ();
880+ if (!options_.useOnnxModelTypesForCustomOps &&
881+ !options_.useOnnxModelTypes ) {
882+ if (auto shapedType = mlir::dyn_cast<ShapedType>(outputType)) {
883+ Type elementType = shapedType.getElementType ();
884+ outputType = UnrankedTensorType::get (elementType);
885+ }
886+ }
887+ outputTypes.emplace_back (outputType);
863888 continue ;
864889 }
865890 if (!options_.allowMissingOutputTypes || ec != InvalidOnnxFormat) {
866891 errorMessage +=
867- " Failed to get type for '" + node.output (i) + " \n " ;
892+ " Failed to get type for '" + node.output (i) + " ' \n " ;
868893 return ec;
869894 }
870895 llvm::errs () << " Warning: "
@@ -874,13 +899,19 @@ class FrontendGenImpl {
874899 }
875900 unsigned int j = i;
876901 // Variadic output is a single ODS result.
877- if (variadicOut)
902+ if (variadicOut) {
878903 j = 0 ;
904+ }
879905 if (!givenOutputTypes.empty ()) {
906+ assert (givenOutputTypes.size () > i &&
907+ " givenOutputTypes size is less than number of outputs" );
880908 outputTypes.emplace_back (
881909 UnrankedTensorType::get (givenOutputTypes[i]));
882910 } else if (j < outputMap.size () && outputMap[j] >= MAX_NUM_TYPES) {
883911 // Mapping gives a connection with an input.
912+ assert (
913+ outputMap[j] - MAX_NUM_TYPES < static_cast <int >(inputs.size ()) &&
914+ " output type mapping to input is out of range" );
884915 Type inputType = inputs[outputMap[j] - MAX_NUM_TYPES].getType ();
885916 if (mlir::isa<TensorType>(inputType)) {
886917 Type elementType =
@@ -1281,25 +1312,34 @@ class FrontendGenImpl {
12811312 return onnx::OpSchemaRegistry::Schema (node.op_type (), version, domain);
12821313 }
12831314
1284- std::string GetImportVersionOfNode (const onnx::NodeProto &node) {
1285- auto current_opset_it = opset_map_.find (node.domain ());
1315+ std::string GetImportVersionOfNode (
1316+ const onnx::NodeProto &node, const std::string &domain) {
1317+ auto current_opset_it = opset_map_.find (domain);
12861318 if (current_opset_it == opset_map_.end ())
12871319 return " " ;
12881320
12891321 const int current_opset = current_opset_it->second ;
12901322
1323+ const auto op_version_map = dialect_op_version_map_.find (domain);
1324+ if (op_version_map == dialect_op_version_map_.end ())
1325+ return " " ;
1326+
1327+ const auto op_opsets_map = dialect_op_opsets_map_.find (domain);
1328+ if (op_opsets_map == dialect_op_opsets_map_.end ())
1329+ return " " ;
1330+
12911331 LLVM_DEBUG (llvm::dbgs () << DEBUG_TYPE << " : Importing ONNX"
12921332 << node.op_type () << " (" << node.name () << " )"
12931333 << " , Opset: " << current_opset << " \n " );
12941334
12951335 const auto supported_opset_list_it =
1296- op_dialect_version_map_ .find (node.op_type ());
1297- const auto opset_list_it = op_opsets_map_ .find (node.op_type ());
1336+ op_version_map-> second .find (node.op_type ());
1337+ const auto opset_list_it = op_opsets_map-> second .find (node.op_type ());
12981338
1299- // Custom ops may not be present in op_dialect_version_map_ . If no version
1339+ // Custom ops may not be present in op_version_map . If no version
13001340 // info is found, treat as unversioned (no renaming).
1301- if (supported_opset_list_it == op_dialect_version_map_ .end () ||
1302- opset_list_it == op_opsets_map_ .end ())
1341+ if (supported_opset_list_it == op_version_map-> second .end () ||
1342+ opset_list_it == op_opsets_map-> second .end ())
13031343 return " " ;
13041344
13051345 // To determine the opset version for a node/op:
@@ -1338,7 +1378,7 @@ class FrontendGenImpl {
13381378
13391379 // A new opset is added to onnx-mlir when it becomes incompatible.
13401380 // All opset newest than the last opset should use the last opset(version)
1341- if (isDefaultDomain (node. domain () ) &&
1381+ if (isDefaultDomain (domain) &&
13421382 upperRangeOfNewestValidOpsetVersion < supported_opset_list.back () &&
13431383 upperRangeOfNewestValidOpsetVersion < MINIMUM_SUPPORTED_OPSET)
13441384 llvm::errs () << " \n Warning: ONNX " << node.op_type ()
@@ -1567,13 +1607,12 @@ class FrontendGenImpl {
15671607 auto mlirAttr = builder_.getStringAttr (funcName);
15681608 auto funcAttr = builder_.getNamedAttr (" function_name" , mlirAttr);
15691609 attributes.push_back (funcAttr);
1570- auto domainAttr = builder_.getNamedAttr (
1571- " domain_name " , builder_.getStringAttr (node.domain ()));
1610+ auto domainAttr = builder_.getNamedAttr (" domain_name " ,
1611+ builder_.getStringAttr (canonicalizeDomain ( node.domain () )));
15721612 attributes.push_back (domainAttr);
1573- int nIn = 0 ;
1574- int nOut = 0 ;
1613+ const int nIn = ONNXCustomOp::getNumberOfOperands () ;
1614+ const int nOut = ONNXCustomOp::getNumberOfResults () ;
15751615 getNodeInputs (node, inputs);
1576- nOut = node.output ().size ();
15771616 std::vector<Type> givenOutputTypes;
15781617
15791618 // We lack a way of specifying import behavior for custom domains. For now
@@ -1611,13 +1650,16 @@ class FrontendGenImpl {
16111650
16121651 [[nodiscard]] std::error_code ImportNode (
16131652 const onnx::NodeProto &node, std::string &errorMessage) {
1614- if (isDefaultDomain (node.domain ()) || (node.domain () == " ai.onnx.ml" ) ||
1615- (node.domain () == " ai.onnx.preview.training" )) {
1616- std::string opName = node.op_type () + GetImportVersionOfNode (node);
1617- auto handler = import_handler_map_.find (opName);
1653+ const std::string domain = canonicalizeDomain (node.domain ());
1654+
1655+ const std::string opName =
1656+ node.op_type () + GetImportVersionOfNode (node, domain);
1657+ auto domainIt = import_handler_map_.find (domain);
1658+ if (domainIt != import_handler_map_.end ()) {
1659+ auto handler = domainIt->second .find (opName);
16181660 std::vector<std::string> funcs = options_.functionsToDecompose ;
1619- if (!( std::find (funcs. begin (), funcs. end (), opName) != funcs. end () )) {
1620- if (handler != import_handler_map_ .end ()) {
1661+ if (!llvm::is_contained (funcs, opName)) {
1662+ if (handler != domainIt-> second .end ()) {
16211663 // It's a regular op with a registered handler.
16221664 return (this ->*(handler->second ))(node, errorMessage);
16231665 }
@@ -1633,7 +1675,7 @@ class FrontendGenImpl {
16331675 }
16341676
16351677 auto model_function = in_model_functions_.find (
1636- GetModelLocalFunctionsMapIdentifier (node. domain () , node.op_type ()));
1678+ GetModelLocalFunctionsMapIdentifier (domain, node.op_type ()));
16371679 if (model_function != in_model_functions_.end ()) {
16381680 return ImportFunctionCallNode (
16391681 node, /* schema=*/ nullptr , model_function->second , errorMessage);
0 commit comments