19
19
class tf_op :
20
20
"""Class to implement the decorator to register handlers that map tf to onnx."""
21
21
22
+ # Maps domains (string) to lists (idx represents opset) of dicts (key = op to handle, value = handler)
22
23
_OPSETS = collections .OrderedDict ()
24
+ # Cache of mapping for current domain and opset. Maps op names to handlers [(func, kwargs) tuple]
23
25
_MAPPING = None
26
+ # Cache of mapping from domain to map of op name to handlers. Used to fetch handlers from different domains
24
27
_DOMAIN_MAPPING = None
25
28
26
29
def __init__ (self , name , domain = constants .ONNX_DOMAIN , ** kwargs ):
27
30
"""Called decorator from decorator.
28
31
29
- :param name: The name of the tensorflow operator.
30
- :param domain: The domain the operator belongs to , defaults to onnx.
32
+ :param name: The name (or list of names) of the tensorflow operator.
33
+ :param domain: The domain the handler requires , defaults to onnx.
31
34
:param kwargs: Dictionary that are passed to the handler. A key 'onnx_op' will change the operator name.
32
35
"""
33
36
if not isinstance (name , list ):
34
37
name = [name ]
35
- self .name = name
38
+ self .names = name
36
39
self .domain = domain
37
40
self .kwargs = kwargs
38
41
39
42
def __call__ (self , func ):
40
- opset = tf_op ._OPSETS .get (self .domain )
41
- if not opset :
42
- opset = []
43
- tf_op ._OPSETS [self .domain ] = opset
44
43
for k , v in inspect .getmembers (func , inspect .ismethod ):
45
44
if k .startswith ("version_" ):
46
45
version = int (k .replace ("version_" , "" ))
47
- while version >= len (opset ):
48
- opset .append ({})
49
- opset_dict = opset [version ]
50
- for name in self .name :
51
- opset_dict [name ] = (v , self .kwargs )
46
+ tf_op .register_handler (v , version , self .names , self .domain , self .kwargs )
52
47
return func
53
48
54
49
def register_compat_handler (self , func , version ):
55
50
"""Register old style custom handler.
56
51
57
52
:param func: The handler.
58
- :param version: The domain the operator belongs to, defaults to onnx.
59
53
:param version: The version of the handler.
60
54
"""
61
- opset = tf_op ._OPSETS .get (self .domain )
55
+ tf_op .register_handler (func , version , self .names , self .domain , self .kwargs )
56
+
57
+ @staticmethod
58
+ def register_handler (func , version , names , domain , kwargs ):
59
+ """Register handler.
60
+
61
+ :param func: The handler.
62
+ :param version: (int) The opset of onnx (or other domain) required for the handler.
63
+ :param names: List of names of the operators to convert.
64
+ :param domain: The domain the handler requires, defaults to onnx.
65
+
66
+ """
67
+ opset = tf_op ._OPSETS .get (domain )
62
68
if not opset :
63
69
opset = []
64
- tf_op ._OPSETS [self .domain ] = opset
65
- while version >= len (opset ):
66
- opset .append ({})
67
- opset_dict = opset [version ]
68
- opset_dict [self .name [0 ]] = (func , self .kwargs )
70
+ tf_op ._OPSETS [domain ] = opset
71
+ while version >= len (opset ):
72
+ opset .append ({})
73
+ opset_dict = opset [version ]
74
+ for name in names :
75
+ opset_dict [name ] = (func , kwargs )
69
76
70
77
@staticmethod
71
78
def get_opsets ():
@@ -100,7 +107,7 @@ def create_mapping(max_onnx_opset_version, extra_opsets):
100
107
def find_effective_op (name , domain = None ):
101
108
"""Find the effective version of an op create_mapping.
102
109
This is used if we need to compose ops from other ops where we'd need to find the
103
- op that is doing to be used in the final graph, for example there is a custom op
110
+ op that is going to be used in the final graph, for example there is a custom op
104
111
that overrides a onnx op ...
105
112
106
113
:param name: The operator name.
@@ -113,3 +120,33 @@ def find_effective_op(name, domain=None):
113
120
if map_info is None :
114
121
return None
115
122
return map_info
123
+
124
+
125
+ class tfl_op :
126
+ """Class to implement the decorator to register handlers that map tflite to tf or onnx."""
127
+
128
+ def __init__ (self , name , domain = constants .ONNX_DOMAIN , ** kwargs ):
129
+ """Called decorator from decorator.
130
+
131
+ :param name: The name (or list of names) of the tflite operator.
132
+ :param domain: The domain the operator belongs to, defaults to onnx. Use 'com.google.tensorflow' for tflite->tf
133
+ :param kwargs: Dictionary that are passed to the handler. A key 'onnx_op' will change the operator name.
134
+ 'tf_op' will convert the op to tf during a tflite to tf conversion pass.
135
+ """
136
+ if not isinstance (name , list ):
137
+ name = [name ]
138
+ self .names = name
139
+ self .domain = domain
140
+ self .kwargs = kwargs
141
+
142
+ def __call__ (self , func ):
143
+ # Register any handlers of the form 'version_#'
144
+ tf_op (self .names , self .domain , ** self .kwargs )(func )
145
+ # TFLite to TF handlers have the function name 'to_tf' which takes the optional 'tf_op' kwarg
146
+ if hasattr (func , 'to_tf' ):
147
+ tf_op .register_handler (func .to_tf , 0 , self .names , 'com.google.tensorflow' , self .kwargs )
148
+ return func
149
+
150
+ @staticmethod
151
+ def create_tfl_to_tf_mapping ():
152
+ return tf_op .get_opsets ()['com.google.tensorflow' ][0 ]
0 commit comments