Skip to content

Commit ccc6dfd

Browse files
committed
review feedback
1 parent 896dbb2 commit ccc6dfd

File tree

9 files changed

+55
-31
lines changed

9 files changed

+55
-31
lines changed

tests/test_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def print_handler(ctx, node, name, args):
365365
def test_custom_op(self):
366366
"""Custom op test."""
367367

368-
@tf_op("Print", type_map={"Print": "Identity"})
368+
@tf_op("Print", onnx_op="Identity")
369369
class Print:
370370
@classmethod
371371
def version_1(cls, ctx, node, **kwargs):

tf2onnx/custom_opsets/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,3 @@
33
""" custom tf2onnx mapping functions. """
44

55
from . import ms
6-
from .. import constants

tf2onnx/handler.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,18 @@
1616
# pylint: disable=unused-argument,missing-docstring,invalid-name
1717

1818
class tf_op:
19+
"""Class to implement the decorator to register handlers that map tf to onnx."""
20+
1921
_OPSETS = collections.OrderedDict()
2022
_MAPPING = None
2123

2224
def __init__(self, name, domain=constants.ONNX_DOMAIN, **kwargs):
25+
"""Called when the decorator is applied to a handler class.
26+
27+
:param name: The name of the tensorflow operator.
28+
:param domain: The domain the operator belongs to, defaults to onnx.
29+
:param kwargs: Dictionary that are passed to the handler. A key 'onnx_op' will change the operator name.
30+
"""
2331
if not isinstance(name, list):
2432
name = [name]
2533
self.name = name
@@ -39,8 +47,15 @@ def __call__(self, func):
3947
opset_dict = opset[version]
4048
for name in self.name:
4149
opset_dict[name] = (v, self.kwargs)
50+
return func
4251

4352
def register_compat_handler(self, func, version):
53+
"""Register old style custom handler.
54+
55+
:param func: The handler.
56+
:param version: The version of the handler.
58+
"""
4459
opset = tf_op._OPSETS.get(self.domain)
4560
if not opset:
4661
opset = []
@@ -55,8 +70,13 @@ def get_opsets():
5570
return tf_op._OPSETS
5671

5772
@staticmethod
58-
def create_mapping(max_opset, extra_opsets):
59-
mapping = {constants.ONNX_DOMAIN: max_opset}
73+
def create_mapping(max_onnx_opset_version, extra_opsets):
74+
"""Create the final mapping dictionary by stacking domains and opset versions.
75+
76+
:param max_onnx_opset_version: The highest onnx opset the resulting graph may use.
77+
:param extra_opsets: Extra opsets the resulting graph may use.
78+
"""
79+
mapping = {constants.ONNX_DOMAIN: max_onnx_opset_version}
6080
if extra_opsets:
6181
for extra_opset in extra_opsets:
6282
mapping[extra_opset.domain] = extra_opset.version
@@ -72,7 +92,14 @@ def create_mapping(max_opset, extra_opsets):
7292
return ops_mapping
7393

7494
@staticmethod
75-
def find_op(name):
95+
def find_effective_op(name):
96+
"""Find the effective version of an op after create_mapping has run.
97+
This is used if we need to compose ops from other ops where we'd need to find the
98+
op that is going to be used in the final graph, for example there is a custom op
99+
that overrides an onnx op ...
100+
101+
:param name: The operator name.
102+
"""
76103
map_info = tf_op._MAPPING.get(name)
77104
if map_info is None:
78105
return None

tf2onnx/onnx_opset/logical.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,16 @@ def logical_compare_op(ctx, node, **kwargs):
3939
ctx.set_dtype(inp_cast.output[0], target_dtype)
4040

4141

42-
@tf_op(["LogicalNot", "NotEqual"],
43-
type_map={"LogicalNot": "Not"})
42+
@tf_op(["LogicalNot", "NotEqual"], onnx_op="Not")
4443
class DirectOp:
4544
@classmethod
4645
def version_4(cls, ctx, node, **kwargs):
4746
pass
4847

4948

50-
@tf_op(["Equal", "Greater", "Less", "LogicalAnd", "LogicalOr", "LogicalAnd", "LogicalOr"],
51-
type_map={"LogicalAnd": "And", "LogicalOr": "Or"})
49+
@tf_op(["Equal", "Greater", "Less"])
50+
@tf_op("LogicalAnd", onnx_op="And")
51+
@tf_op("LogicalOr", onnx_op="Or")
5252
class BroadcastOp(common.BroadcastOp):
5353
pass
5454

@@ -64,8 +64,8 @@ def version_7(cls, ctx, node, **kwargs):
6464
logical_compare_op(ctx, node, **kwargs)
6565

6666

67-
@tf_op(["GreaterEqual", "LessEqual"],
68-
type_map={"GreaterEqual": "Less", "LessEqual": "Greater"})
67+
@tf_op("GreaterEqual", onnx_op="Less")
68+
@tf_op("LessEqual", onnx_op="Greater")
6969
class GreaterLessEqual:
7070
@classmethod
7171
def version_7(cls, ctx, node, **kwargs):

tf2onnx/onnx_opset/math.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ class BroadcastOp(common.BroadcastOp):
2727
pass
2828

2929

30-
@tf_op(["RealDiv", "TruncateDiv"],
31-
type_map={"RealDiv": "Div", "TruncateDiv": "Div"})
30+
@tf_op(["RealDiv", "TruncateDiv"], onnx_op="Div")
3231
class RealDiv(common.BroadcastOp):
3332
pass
3433

@@ -55,8 +54,8 @@ def version_9(cls, ctx, node, **kwargs):
5554
pass
5655

5756

58-
@tf_op(["Minimum", "Maximum"],
59-
type_map={"Minimum": "Min", "Maximum": "Max"})
57+
@tf_op("Minimum", onnx_op="Min")
58+
@tf_op("Maximum", onnx_op="Max")
6059
class MinMaxOp:
6160
@classmethod
6261
def version_4(cls, ctx, node, **kwargs):

tf2onnx/onnx_opset/nn.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -287,9 +287,8 @@ def version_4(cls, ctx, node, **kwargs):
287287
conv_convert_inputs(ctx, node, with_kernel=True, new_kernel_shape=new_kernel_shape)
288288

289289

290-
@tf_op(["AvgPool", "AvgPool3D", "MaxPool", "MaxPoolV2"],
291-
type_map={"AvgPool": "AveragePool", "AvgPool3D": "AveragePool",
292-
"MaxPool": "MaxPool", "MaxPoolV2": "MaxPool"})
290+
@tf_op(["AvgPool", "AvgPool3D"], onnx_op="AveragePool")
291+
@tf_op(["MaxPool", "MaxPoolV2"], onnx_op="MaxPool")
293292
class PoolOp:
294293
@classmethod
295294
def version_4(cls, ctx, node, **kwargs):

tf2onnx/onnx_opset/reduction.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@
2222

2323
# pylint: disable=unused-argument,missing-docstring
2424

25-
@tf_op(["Min", "Max", "Mean", "Sum", "Prod"],
26-
type_map={"Min": "ReduceMin", "Max": "ReduceMax", "Mean": "ReduceMean",
27-
"Sum": "ReduceSum", "Prod": "ReduceProd"})
25+
@tf_op("Min", onnx_op="ReduceMin")
26+
@tf_op("Max", onnx_op="ReduceMax")
27+
@tf_op("Mean", onnx_op="ReduceMean")
28+
@tf_op("Sum", onnx_op="ReduceSum")
29+
@tf_op("Prod", onnx_op="ReduceProd")
2830
class ReduceOpBase:
2931
@classmethod
3032
def version_4(cls, ctx, node, **kwargs):

tf2onnx/onnx_opset/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,7 @@ def version_4(cls, ctx, node, **kwargs):
787787
ctx.copy_shape(node.output[0], output_cast.output[0])
788788

789789

790-
@tf_op("IsNan", type_map={"IsNan": "IsNaN"})
790+
@tf_op("IsNan", onnx_op="IsNaN")
791791
class IsNan:
792792
@classmethod
793793
def version_9(cls, ctx, node, **kwargs):

tf2onnx/tfonnx.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ def rewrite_conv2d_with_pad(g, ops):
512512
g.replace_input(conv, conv.input[0], pad.input[0])
513513
# convert Conv2D
514514
conv.type = "Conv"
515-
func, info = handler.tf_op.find_op("Conv2D")
515+
func, info = handler.tf_op.find_effective_op("Conv2D")
516516
func(g, conv)
517517
conv.skip_conversion = True
518518
conv.set_attr("auto_pad", "NOTSET")
@@ -541,12 +541,10 @@ def tensorflow_onnx_mapping(g, continue_on_error, ops_mapping):
541541
mapped_op[op] += 1
542542
func, kwargs = map_info
543543
if kwargs:
544-
# if there is a type_map key we'll map the old type to a new type
545-
type_map = kwargs.get("type_map")
546-
if type_map:
547-
new_type = type_map.get(node.type)
548-
if new_type:
549-
node.type = new_type
544+
# if there is an onnx_op key we'll map the old type to a new type
545+
onnx_op = kwargs.get("onnx_op")
546+
if onnx_op:
547+
node.type = onnx_op
550548
try:
551549
body_graphs = node.get_body_graphs()
552550
if body_graphs:
@@ -736,8 +734,8 @@ def compat_handler(ctx, node, **kwargs):
736734
args = v[1]
737735
kwargs = {"func": v[0]}
738736
if args:
739-
type_map = {k: args[0]}
740-
kwargs["type_map"] = type_map
737+
onnx_op = args[0]
738+
kwargs["onnx_op"] = onnx_op
741739
args = args[1:]
742740
kwargs["args"] = args
743741
new_handler = handler.tf_op(k,

0 commit comments

Comments
 (0)