Skip to content

Commit 51f12f4

Browse files
committed
refine custom opsets loading
1 parent 09cd571 commit 51f12f4

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

tf2onnx/custom_opsets/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,8 @@
33
""" custom tf2onnx mapping functions. """
44

55
from . import ms
6+
from tf2onnx import constants
7+
8+
DOMAIN_OPSETS = {
9+
constants.MICROSOFT_DOMAIN: ms.OPSETS
10+
}

tf2onnx/tfonnx.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2334,9 +2334,9 @@ def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
23342334
# order matters, later mapping overrides earlier's
23352335
if g.extra_opset is not None:
23362336
for extra_opset in g.extra_opset:
2337-
if extra_opset.domain == constants.MICROSOFT_DOMAIN:
2338-
# microsoft domain
2339-
for target_opset, op_map in custom_opsets.ms.OPSETS:
2337+
opsets = custom_opsets.DOMAIN_OPSETS.get(extra_opset.domain, None)
2338+
if opsets is not None:
2339+
for target_opset, op_map in opsets:
23402340
if target_opset <= extra_opset.version:
23412341
ops_mapping.update(op_map)
23422342
else:

0 commit comments

Comments
 (0)