Skip to content

Commit 99eb959

Browse files
Added @tfl_op decorator to handlers (#1267)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent fa16f33 commit 99eb959

File tree

1 file changed

+57
-20
lines changed

1 file changed

+57
-20
lines changed

tf2onnx/handler.py

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,53 +19,60 @@
1919
class tf_op:
2020
"""Class to implement the decorator to register handlers that map tf to onnx."""
2121

22+
# Maps domains (string) to lists (idx represents opset) of dicts (key = op to handle, value = handler)
2223
_OPSETS = collections.OrderedDict()
24+
# Cache of mapping for current domain and opset. Maps op names to handlers [(func, kwargs) tuple]
2325
_MAPPING = None
26+
# Cache of mapping from domain to map of op name to handlers. Used to fetch handlers from different domains
2427
_DOMAIN_MAPPING = None
2528

2629
def __init__(self, name, domain=constants.ONNX_DOMAIN, **kwargs):
2730
"""Called decorator from decorator.
2831
29-
:param name: The name of the tensorflow operator.
30-
:param domain: The domain the operator belongs to, defaults to onnx.
32+
:param name: The name (or list of names) of the tensorflow operator.
33+
:param domain: The domain the handler requires, defaults to onnx.
3134
:param kwargs: Dictionary that are passed to the handler. A key 'onnx_op' will change the operator name.
3235
"""
3336
if not isinstance(name, list):
3437
name = [name]
35-
self.name = name
38+
self.names = name
3639
self.domain = domain
3740
self.kwargs = kwargs
3841

3942
def __call__(self, func):
40-
opset = tf_op._OPSETS.get(self.domain)
41-
if not opset:
42-
opset = []
43-
tf_op._OPSETS[self.domain] = opset
4443
for k, v in inspect.getmembers(func, inspect.ismethod):
4544
if k.startswith("version_"):
4645
version = int(k.replace("version_", ""))
47-
while version >= len(opset):
48-
opset.append({})
49-
opset_dict = opset[version]
50-
for name in self.name:
51-
opset_dict[name] = (v, self.kwargs)
46+
tf_op.register_handler(v, version, self.names, self.domain, self.kwargs)
5247
return func
5348

5449
def register_compat_handler(self, func, version):
5550
"""Register old style custom handler.
5651
5752
:param func: The handler.
58-
:param version: The domain the operator belongs to, defaults to onnx.
5953
:param version: The version of the handler.
6054
"""
61-
opset = tf_op._OPSETS.get(self.domain)
55+
tf_op.register_handler(func, version, self.names, self.domain, self.kwargs)
56+
57+
@staticmethod
58+
def register_handler(func, version, names, domain, kwargs):
59+
"""Register handler.
60+
61+
:param func: The handler.
62+
:param version: (int) The opset of onnx (or other domain) required for the handler.
63+
:param names: List of names of the operators to convert.
64+
:param domain: The domain the handler requires, defaults to onnx.
65+
66+
"""
67+
opset = tf_op._OPSETS.get(domain)
6268
if not opset:
6369
opset = []
64-
tf_op._OPSETS[self.domain] = opset
65-
while version >= len(opset):
66-
opset.append({})
67-
opset_dict = opset[version]
68-
opset_dict[self.name[0]] = (func, self.kwargs)
70+
tf_op._OPSETS[domain] = opset
71+
while version >= len(opset):
72+
opset.append({})
73+
opset_dict = opset[version]
74+
for name in names:
75+
opset_dict[name] = (func, kwargs)
6976

7077
@staticmethod
7178
def get_opsets():
@@ -100,7 +107,7 @@ def create_mapping(max_onnx_opset_version, extra_opsets):
100107
def find_effective_op(name, domain=None):
101108
"""Find the effective version of an op create_mapping.
102109
This is used if we need to compose ops from other ops where we'd need to find the
103-
op that is doing to be used in the final graph, for example there is a custom op
110+
op that is going to be used in the final graph, for example there is a custom op
104111
that overrides a onnx op ...
105112
106113
:param name: The operator name.
@@ -113,3 +120,33 @@ def find_effective_op(name, domain=None):
113120
if map_info is None:
114121
return None
115122
return map_info
123+
124+
125+
class tfl_op:
126+
"""Class to implement the decorator to register handlers that map tflite to tf or onnx."""
127+
128+
def __init__(self, name, domain=constants.ONNX_DOMAIN, **kwargs):
129+
"""Called decorator from decorator.
130+
131+
:param name: The name (or list of names) of the tflite operator.
132+
:param domain: The domain the operator belongs to, defaults to onnx. Use 'com.google.tensorflow' for tflite->tf
133+
:param kwargs: Dictionary that are passed to the handler. A key 'onnx_op' will change the operator name.
134+
'tf_op' will convert the op to tf during a tflite to tf conversion pass.
135+
"""
136+
if not isinstance(name, list):
137+
name = [name]
138+
self.names = name
139+
self.domain = domain
140+
self.kwargs = kwargs
141+
142+
def __call__(self, func):
143+
# Register any handlers of the form 'version_#'
144+
tf_op(self.names, self.domain, **self.kwargs)(func)
145+
# TFLite to TF handlers have the function name 'to_tf' which takes the optional 'tf_op' kwarg
146+
if hasattr(func, 'to_tf'):
147+
tf_op.register_handler(func.to_tf, 0, self.names, 'com.google.tensorflow', self.kwargs)
148+
return func
149+
150+
@staticmethod
151+
def create_tfl_to_tf_mapping():
152+
return tf_op.get_opsets()['com.google.tensorflow'][0]

0 commit comments

Comments
 (0)