21
21
from tensorflow .tools .graph_transforms import TransformGraph
22
22
23
23
import tf2onnx
24
- from tf2onnx import constants , schemas , utils
24
+ from tf2onnx import constants , custom , schemas , utils
25
25
from tf2onnx .function import * # pylint: disable=wildcard-import
26
26
from tf2onnx .graph import Graph
27
27
from tf2onnx .graph_matcher import OpTypePattern , GraphMatcher
@@ -2322,13 +2322,28 @@ def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
2322
2322
mapped_op = collections .Counter ()
2323
2323
unmapped_op = collections .Counter ()
2324
2324
2325
- # create ops mapping for the desired opset
2325
+ # create ops mapping for the desired opsets
2326
2326
ops_mapping = {}
2327
+
2328
+ # load mapping for onnx domain
2327
2329
for target_opset , op_map in _OPSETS :
2328
2330
if target_opset <= g .opset :
2329
2331
ops_mapping .update (op_map )
2330
2332
2331
- # apply custom ops on top of the assembled opset. We can either completment the opset
2333
+ # load mapping for known extra opsets
2334
+ # order matters, later mapping overrides earlier's
2335
+ if g .extra_opset is not None :
2336
+ for extra_opset in g .extra_opset :
2337
+ if extra_opset .domain == constants .MICROSOFT_DOMAIN :
2338
+ # microsoft domain
2339
+ for target_opset , op_map in custom .ms .OPSETS :
2340
+ if target_opset <= extra_opset .version :
2341
+ ops_mapping .update (op_map )
2342
+ else :
2343
+ # unknown opset, assume used in custom_op_handlers, skip it
2344
+ pass
2345
+
2346
+ # apply custom ops on top of the assembled opset. We can either complement the opset
2332
2347
# or override existing ops with a custom op.
2333
2348
if custom_op_handlers is not None :
2334
2349
custom_opset = {k : v for k , v in custom_op_handlers .items ()}
@@ -2337,7 +2352,7 @@ def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
2337
2352
ops = [n for n in g .get_nodes ()]
2338
2353
for node in ops :
2339
2354
if node .need_skip ():
2340
- log .debug ("explictly skip node " + node .name )
2355
+ log .debug ("explicitly skip node " + node .name )
2341
2356
continue
2342
2357
2343
2358
op = node .type
0 commit comments