@@ -197,12 +197,12 @@ class FrontendGenImpl {
197197 ModuleOp module_;
198198 OpBuilder builder_;
199199
200- // onnxop: list of versions supported by onnx-mlir for dialect
201- std::map<std::string, std::vector<int >> op_dialect_version_map_;
202- // onnxop: list of versions for dialect
203- std::map<std::string, std::vector< int >> op_opsets_map_;
204- // onnxop: the top version in third_part/onnx
205- std::map<std::string, int > op_dialect_top_version_map_ ;
200+ // onnxop: list of versions supported by onnx-mlir for dialect, op
201+ std::map<std::string, std::map<std::string, std:: vector<int >>>
202+ dialect_op_version_map_;
203+ // onnxop: list of versions for dialect, op
204+ std::map<std::string, std::map<std::string, std::vector< int >>>
205+ dialect_op_opsets_map_ ;
206206
207207 // mapping between string name and symbol
208208 ValueSymbolMapping frontend_symbols_;
@@ -214,7 +214,8 @@ class FrontendGenImpl {
214214 onnx_mlir::detail::FrontendGenImpl::*)(
215215 const onnx::NodeProto &, std::string & /* errorMessage*/ );
216216
217- std::map<std::string, ImportHandlerType> import_handler_map_;
217+ std::map<std::string, std::map<std::string, ImportHandlerType>>
218+ import_handler_map_;
218219
219220 // The total number of elements in all initializers. This value is a rough
220221 // counter of the number of parameters in a model.
@@ -1288,18 +1289,26 @@ class FrontendGenImpl {
12881289
12891290 const int current_opset = current_opset_it->second ;
12901291
1292+ const auto op_version_map = dialect_op_version_map_.find (node.domain ());
1293+ if (op_version_map == dialect_op_version_map_.end ())
1294+ return " " ;
1295+
1296+ const auto op_opsets_map = dialect_op_opsets_map_.find (node.domain ());
1297+ if (op_opsets_map == dialect_op_opsets_map_.end ())
1298+ return " " ;
1299+
12911300 LLVM_DEBUG (llvm::dbgs () << DEBUG_TYPE << " : Importing ONNX"
12921301 << node.op_type () << " (" << node.name () << " )"
12931302 << " , Opset: " << current_opset << " \n " );
12941303
12951304 const auto supported_opset_list_it =
1296- op_dialect_version_map_ .find (node.op_type ());
1297- const auto opset_list_it = op_opsets_map_ .find (node.op_type ());
1305+ op_version_map-> second .find (node.op_type ());
1306+ const auto opset_list_it = op_opsets_map-> second .find (node.op_type ());
12981307
1299- // Custom ops may not be present in op_dialect_version_map_ . If no version
1308+ // Custom ops may not be present in op_version_map . If no version
13001309 // info is found, treat as unversioned (no renaming).
1301- if (supported_opset_list_it == op_dialect_version_map_ .end () ||
1302- opset_list_it == op_opsets_map_ .end ())
1310+ if (supported_opset_list_it == op_version_map-> second .end () ||
1311+ opset_list_it == op_opsets_map-> second .end ())
13031312 return " " ;
13041313
13051314 // To determine the opset version for a node/op:
@@ -1611,13 +1620,14 @@ class FrontendGenImpl {
16111620
16121621 [[nodiscard]] std::error_code ImportNode (
16131622 const onnx::NodeProto &node, std::string &errorMessage) {
1614- if (isDefaultDomain (node.domain ()) || (node.domain () == " ai.onnx.ml" ) ||
1615- (node.domain () == " ai.onnx.preview.training" )) {
1616- std::string opName = node.op_type () + GetImportVersionOfNode (node);
1617- auto handler = import_handler_map_.find (opName);
1623+
1624+ std::string opName = node.op_type () + GetImportVersionOfNode (node);
1625+ auto domainIt = import_handler_map_.find (node.domain ());
1626+ if (domainIt != import_handler_map_.end ()) {
1627+ auto handler = domainIt->second .find (opName);
16181628 std::vector<std::string> funcs = options_.functionsToDecompose ;
1619- if (!( std::find (funcs. begin (), funcs. end (), opName) != funcs. end () )) {
1620- if (handler != import_handler_map_ .end ()) {
1629+ if (!llvm::is_contained (funcs, opName)) {
1630+ if (handler != domainIt-> second .end ()) {
16211631 // It's a regular op with a registered handler.
16221632 return (this ->*(handler->second ))(node, errorMessage);
16231633 }
0 commit comments