
Commit c4a2142

Added handlers for tflite (#1270)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent c353349 commit c4a2142

File tree

6 files changed: +560 -0 lines changed

tf2onnx/tflite_handlers/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
"""tf2onnx.tflite_handlers module"""

from . import (
    tfl_math,
    tfl_nn,
    tfl_controlflow,
    tfl_direct,
    tfl_tensor
)
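Importing these submodules is what wires the handlers up: each class in them is decorated with @tfl_op, which registers it for the listed TFLite op names so the converter can dispatch on node type. A minimal sketch of how such a decorator registry works (illustrative only; the real tf2onnx.handler.tfl_op also tracks opset versions and extra metadata):

# Illustrative decorator registry, not the actual tf2onnx.handler implementation.
_TFL_HANDLERS = {}  # hypothetical table: TFLite op name -> (handler class, metadata)

def tfl_op(op_names, **metadata):
    """Register a handler class for one or more TFLite op names."""
    if isinstance(op_names, str):
        op_names = [op_names]
    def register(cls):
        for name in op_names:
            _TFL_HANDLERS[name] = (cls, metadata)
        return cls  # returning cls lets several decorators stack on one class
    return register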

tf2onnx/tflite_handlers/tfl_controlflow.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

"""
tfl_controlflow
"""

import copy
import numpy as np
from onnx.onnx_pb import TensorProto

from tf2onnx.handler import tfl_op
from tf2onnx import utils
from tf2onnx.tf_loader import find_function
from tf2onnx.onnx_opset.controlflow import parameter_binding, inline_subgraph


# pylint: disable=unused-argument,missing-docstring,unused-variable,pointless-string-statement,invalid-name


@tfl_op(["TFL_WHILE"])
class TflWhile:
    @classmethod
    def version_7(cls, ctx, node, **kwargs):
        tfl_while_inputs = node.input
        output_shapes = node.output_shapes
        output_dtypes = node.output_dtypes
        output_names = node.output

        cond_name = node.get_attr_str("cond_subgraph_index")
        cond_graph = find_function(cond_name)
        cond_graph.parent_graph = ctx

        body_name = node.get_attr_str("body_subgraph_index")
        body = find_function(body_name)
        body.parent_graph = ctx

        ctx.remove_node(node.name)

        cond_binding = parameter_binding(cond_graph, tfl_while_inputs)
        cond_outputs = inline_subgraph(ctx, cond_graph, cond_name, cond_binding)

        max_iterations = ctx.make_const(utils.make_name("max_iterations"), np.array(np.iinfo(np.int64).max))

        loop_node = ctx.make_node("Loop", [max_iterations.output[0], cond_outputs[0]] + tfl_while_inputs,
                                  output_count=len(output_shapes), name=node.name + "_loop",
                                  shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True)

        output_map = dict(zip(output_names, loop_node.output))

        # shift output consumers
        for k, v in output_map.items():
            ctx.replace_all_inputs(k, v)  # ops=ctx.get_nodes()

        body = wire_tfl_while_body(body, loop_node.inputs, output_shapes, output_dtypes, cond_graph)

        loop_node.set_body_graph_as_attr("body", body)

def wire_tfl_while_body(g, loop_node_inputs, output_shapes,
                        output_dtypes, cond_graph):
    """Wire subgraph graph into main."""

    g = copy.deepcopy(g)

    # onnx will pass in cond as argument
    iter_node = g.make_node("Placeholder", [], name=utils.make_name("iteration_num"),
                            output_count=1, dtypes=[TensorProto.INT64], shapes=[[]])
    cond_node = g.make_node("Placeholder", [], name=utils.make_name("cond"),
                            output_count=1, dtypes=[TensorProto.BOOL], shapes=[[]])
    cond_binding = parameter_binding(cond_graph, g.outputs)

    # in onnx the body inputs are: index, cond, [loop_vars]
    g.func_inputs = [iter_node.output[0], cond_node.output[0]] + g.func_inputs
    # tell graph lib to keep inputs in order
    g._order_sensitive_inputs = \
        [g.get_node_by_output(name) for name in g.func_inputs]  # pylint: disable=protected-access

    for p, c in zip(loop_node_inputs, g.func_inputs):
        shape = p.output_shapes[0]
        g.set_shape(c, shape)

    cond_outputs = inline_subgraph(g, cond_graph, "cond__", cond_binding)

    g.outputs = [cond_outputs[0]] + g.outputs
    return g

@tfl_op(["TFL_IF"], tf_op="If")
class TflIfOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        node.attr["then_branch"] = node.attr["then_subgraph_index"]
        del node.attr["then_subgraph_index"]
        node.attr["else_branch"] = node.attr["else_subgraph_index"]
        del node.attr["else_subgraph_index"]
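
The structural gap the TFL_WHILE handler bridges: a TFLite WHILE evaluates a separate condition subgraph over the loop variables, while an ONNX Loop body must itself return the next condition value. That is why the condition graph is inlined once before the Loop (to compute the initial condition) and again at the end of the wired body. A rough sketch of the two loop semantics (illustrative Python, not converter code):

# TFLite WHILE: condition and body are separate subgraphs over the loop vars.
def tflite_while(cond, body, loop_vars):
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars

# ONNX Loop: the body receives (iteration_num, cond, *loop_vars) and must
# return (next_cond, *next_loop_vars); a max trip count bounds the loop.
def onnx_loop(max_trip_count, initial_cond, body, loop_vars):
    cond = initial_cond
    for i in range(max_trip_count):
        if not cond:
            break
        cond, *loop_vars = body(i, cond, *loop_vars)
    return loop_vars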

tf2onnx/tflite_handlers/tfl_direct.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

"""
tfl_direct
"""

from tf2onnx.handler import tfl_op


# pylint: disable=unused-argument,missing-docstring,unused-variable,pointless-string-statement,invalid-name


@tfl_op("TFL_ABS", tf_op="Abs")
@tfl_op("TFL_CEIL", tf_op="Ceil")
@tfl_op("TFL_COS", tf_op="Cos")
@tfl_op("TFL_ELU", tf_op="Elu")
@tfl_op("TFL_EQUAL", tf_op="Equal")
@tfl_op("TFL_EXP", tf_op="Exp")
@tfl_op("TFL_FLOOR", tf_op="Floor")
@tfl_op("TFL_FLOOR_DIV", tf_op="FloorDiv")
@tfl_op("TFL_FLOOR_MOD", tf_op="FloorMod")
@tfl_op("TFL_GREATER", tf_op="Greater")
@tfl_op("TFL_GREATER_EQUAL", tf_op="GreaterEqual")
@tfl_op("TFL_LESS", tf_op="Less")
@tfl_op("TFL_LESS_EQUAL", tf_op="LessEqual")
@tfl_op("TFL_LOG", tf_op="Log")
@tfl_op("TFL_LOG_SOFTMAX", tf_op="LogSoftmax")
@tfl_op("TFL_LOGICAL_AND", tf_op="LogicalAnd")
@tfl_op("TFL_LOGICAL_NOT", tf_op="LogicalNot")
@tfl_op("TFL_LOGICAL_OR", tf_op="LogicalOr")
@tfl_op("TFL_MATRIX_DIAG", tf_op="MatrixDiag")
@tfl_op("TFL_MATRIX_SET_DIAG", tf_op="MatrixSetDiag")
@tfl_op("TFL_MAXIMUM", tf_op="Maximum")
@tfl_op("TFL_MINIMUM", tf_op="Minimum")
@tfl_op("TFL_NEG", tf_op="Neg")
@tfl_op("TFL_NOT_EQUAL", tf_op="NotEqual")
@tfl_op("TFL_POW", tf_op="Pow")
@tfl_op("TFL_RANK", tf_op="Rank")
@tfl_op("TFL_RELU", tf_op="Relu")
@tfl_op("TFL_RELU6", tf_op="Relu6")
@tfl_op("TFL_ROUND", tf_op="Round")
@tfl_op("TFL_RSQRT", tf_op="Rsqrt")
@tfl_op("TFL_SELECT", tf_op="Select")
@tfl_op("TFL_SELECT_V2", tf_op="SelectV2")
@tfl_op("TFL_SIN", tf_op="Sin")
@tfl_op("TFL_SQRT", tf_op="Sqrt")
@tfl_op("TFL_SQUARE", tf_op="Square")
@tfl_op("TFL_SQUARED_DIFFERENCE", tf_op="SquaredDifference")
@tfl_op("TFL_TANH", tf_op="Tanh")
@tfl_op("TFL_WHERE", tf_op="Where")
@tfl_op("TFL_ZEROS_LIKE", tf_op="ZerosLike")
@tfl_op("TFL_FILL", tf_op="Fill")
@tfl_op("TFL_GATHER_ND", tf_op="GatherNd")
@tfl_op("TFL_PAD", tf_op="Pad")
@tfl_op("TFL_REVERSE_V2", tf_op="ReverseV2")
@tfl_op("TFL_SCATTER_ND", tf_op="ScatterNd")
@tfl_op("TFL_SEGMENT_SUM", tf_op="SegmentSum")
@tfl_op("TFL_SHAPE", tf_op="Shape")
@tfl_op("TFL_SLICE", tf_op="Slice")
@tfl_op("TFL_SQUEEZE", tf_op="Squeeze")
@tfl_op("TFL_TILE", tf_op="Tile")
@tfl_op("TFL_EXPAND_DIMS", tf_op="ExpandDims")
@tfl_op("TFL_TRANSPOSE", tf_op="Transpose")
@tfl_op("TFL_UNPACK", tf_op="Unpack")
@tfl_op("TFL_ADD_N", tf_op="AddN")
@tfl_op("TFL_ONE_HOT", tf_op="OneHot")
@tfl_op("TFL_DEPTH_TO_SPACE", tf_op="DepthToSpace")
@tfl_op("TFL_ARG_MIN", tf_op="ArgMin")
@tfl_op("TFL_ARG_MAX", tf_op="ArgMax")
@tfl_op("TFL_NON_MAX_SUPPRESSION_V5", tf_op="NonMaxSuppressionV5")
@tfl_op("TFL_RESIZE_NEAREST_NEIGHBOR", tf_op="ResizeNearestNeighbor")
@tfl_op("TFL_LEAKY_RELU", tf_op="LeakyRelu")
@tfl_op("TFL_STRIDED_SLICE", tf_op="StridedSlice")
@tfl_op("TFL_MEAN", tf_op="Mean")
@tfl_op("TFL_SUM", tf_op="Sum")
@tfl_op("TFL_MIRROR_PAD", tf_op="MirrorPad")
@tfl_op("TFL_RESIZE_BILINEAR", tf_op="ResizeBilinear")
@tfl_op("TFL_REVERSE_SEQUENCE", tf_op="ReverseSequence")
@tfl_op("TFL_SPARSE_TO_DENSE", tf_op="SparseToDense")
@tfl_op("TFL_CUMSUM", tf_op="Cumsum")
class TflDirectOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        pass
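
Every decorator in this file pairs a TFLite builtin with the TensorFlow op it already matches, and the empty to_tf body means no rewriting is needed: the node is effectively just relabeled with the tf_op name, after which the existing TensorFlow-to-ONNX handler performs the actual conversion. A toy illustration of that relabeling step (hypothetical helper and mapping table, not tf2onnx internals):

# Hypothetical sketch of what a "direct" mapping amounts to.
DIRECT_MAP = {"TFL_ABS": "Abs", "TFL_COS": "Cos", "TFL_TRANSPOSE": "Transpose"}

def relabel_direct_op(tfl_op_name: str) -> str:
    """Return the TensorFlow op name that the standard handlers then convert."""
    return DIRECT_MAP[tfl_op_name]

assert relabel_direct_op("TFL_ABS") == "Abs"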

tf2onnx/tflite_handlers/tfl_math.py

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

"""
tfl_math
"""

import logging
import numpy as np
from tf2onnx.handler import tfl_op
from tf2onnx import utils

logger = logging.getLogger(__name__)


# pylint: disable=unused-argument,missing-docstring,unused-variable,pointless-string-statement,invalid-name


def separate_fused_activation_function(ctx, node):
    activation_fn = node.attr['fused_activation_function'].s
    del node.attr['fused_activation_function']
    if activation_fn == b'RELU':
        ctx.insert_new_node_on_output("Relu", node.output[0])
    elif activation_fn == b'RELU6':
        new_node = ctx.insert_new_node_on_output("Relu6", node.output[0])
        new_node.skip_conversion = False
    elif activation_fn == b'TANH':
        ctx.insert_new_node_on_output("Tanh", node.output[0])
    else:
        # TODO: SIGN_BIT and RELU_N1_TO_1 not supported yet
        utils.make_sure(activation_fn == b'NONE', "Unsupported fused activation function %s on node %s",
                        activation_fn, node.name)

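TFLite fuses an optional activation into many binary ops, so a TFL_ADD with fused_activation_function=RELU computes relu(a + b) in a single node; the helper above splits that back into a plain op followed by a separate activation node, which the generic handlers can then convert individually. A small numpy check of the equivalence (illustrative only):

import numpy as np

a = np.array([-2.0, 0.5, 3.0], dtype=np.float32)
b = np.array([1.0, -1.5, 2.0], dtype=np.float32)

fused = np.maximum(a + b, 0.0)            # one TFLite ADD with fused RELU
unfused = np.maximum(np.add(a, b), 0.0)   # separate Add node, then Relu node
assert np.allclose(fused, unfused)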
@tfl_op(["TFL_ADD"], tf_op="Add")
class TflAdd:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        separate_fused_activation_function(ctx, node)

@tfl_op(["TFL_SUB"], tf_op="Sub")
class TflSub:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        separate_fused_activation_function(ctx, node)

@tfl_op(["TFL_MUL"], tf_op="Mul")
class TflMul:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        separate_fused_activation_function(ctx, node)

@tfl_op(["TFL_DIV"], tf_op="Div")
class TflDiv:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        separate_fused_activation_function(ctx, node)

@tfl_op(["TFL_LOGISTIC"], tf_op="Sigmoid")
class TflLogistic:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        pass

@tfl_op(["TFL_REDUCE_MAX"], tf_op="Max")
@tfl_op(["TFL_REDUCE_ANY"], tf_op="Any")
@tfl_op(["TFL_REDUCE_PROD"], tf_op="Prod")
class TflReduceOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        pass

@tfl_op(["TFL_LOCAL_RESPONSE_NORMALIZATION"], tf_op="LRN")
class TFlLocalResponseNormalizationOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        node.attr["depth_radius"] = node.attr["radius"]
        del node.attr["radius"]

@tfl_op(["TFL_RANGE"], tf_op="Range")
class TflRangeOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        node.set_attr("Tidx", ctx.get_dtype(node.output[0]))

@tfl_op(["TFL_QUANTIZE"], onnx_op="QuantizeLinear")
class TflQuantizeOp:
    @classmethod
    def version_10(cls, ctx, node, **kwargs):
        scale = node.get_attr_value('scale')
        zero_point = node.get_attr_value('zero_point')
        axis = node.get_attr_value('quantized_dimension')
        np_q_type = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.output[0]))
        if len(scale) > 1 or len(zero_point) > 1:
            node.set_attr("axis", axis)
        scale_node = ctx.make_const(utils.make_name("scale"), np.array(scale[0], dtype=np.float32))
        zero_point_node = ctx.make_const(utils.make_name("zero_point"), np.array(zero_point[0], dtype=np_q_type))
        ctx.replace_inputs(node, [node.input[0], scale_node.output[0], zero_point_node.output[0]])
        del node.attr["scale"]
        del node.attr["zero_point"]
        del node.attr["quantized_dimension"]

@tfl_op(["TFL_DEQUANTIZE"], onnx_op="DequantizeLinear")
class TflDequantizeOp:
    @classmethod
    def version_10(cls, ctx, node, **kwargs):
        scale = node.get_attr_value('scale')
        zero_point = node.get_attr_value('zero_point')
        axis = node.get_attr_value('quantized_dimension')
        np_q_type = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[0]))
        if len(scale) > 1 or len(zero_point) > 1:
            utils.make_sure(ctx.opset >= 13, "Opset 13 is required for per-axis quantization")
            node.set_attr("axis", axis)
            scale_node = ctx.make_const(utils.make_name("scale"), np.array(scale, dtype=np.float32))
            zero_point_node = ctx.make_const(utils.make_name("zero_point"), np.array(zero_point, dtype=np_q_type))
        else:
            scale_node = ctx.make_const(utils.make_name("scale"), np.array(scale[0], dtype=np.float32))
            zero_point_node = ctx.make_const(utils.make_name("zero_point"), np.array(zero_point[0], dtype=np_q_type))
        ctx.replace_inputs(node, [node.input[0], scale_node.output[0], zero_point_node.output[0]])
        del node.attr["scale"]
        del node.attr["zero_point"]
        del node.attr["quantized_dimension"]

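For reference, the per-tensor arithmetic behind the ONNX QuantizeLinear and DequantizeLinear ops these handlers emit: quantization divides by the scale and shifts by the zero point, dequantization reverses it. A small numpy sketch (scale and zero point values are made up for illustration):

import numpy as np

x = np.array([-1.0, 0.0, 0.5, 1.0], dtype=np.float32)
scale = np.float32(0.0078125)   # illustrative TFLite quantization scale
zero_point = 128                # illustrative uint8 zero point

q = np.clip(np.rint(x / scale) + zero_point, 0, 255).astype(np.uint8)  # QuantizeLinear
x_hat = (q.astype(np.float32) - zero_point) * scale                    # DequantizeLinear
assert np.allclose(x, x_hat, atol=float(scale))  # round-trip error is at most one step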
def dynamic_quantize_inputs(ctx, node):
    if ctx.opset < 11:
        logger.warning("Opset 11 is required for asymmetric_quantize_inputs of node %s", node.name)
        return
    for i in range(len(node.input)):
        # Don't quantize inputs that are already quantized
        if node.inputs[i].type in ["DequantizeLinear", "TFL_DEQUANTIZE"]:
            continue
        dyn_quant = ctx.make_node("DynamicQuantizeLinear", [node.input[i]], output_count=3, op_name_scope=node.name)
        dyn_quant.skip_conversion = True
        dequant = ctx.make_node("DequantizeLinear", dyn_quant.output, op_name_scope=node.name)
        dequant.skip_conversion = True
        ctx.replace_input(node, node.input[i], dequant.output[0], input_index=i)

@tfl_op(["TFL_FULLY_CONNECTED"])
class TflFullyConnectedOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        separate_fused_activation_function(ctx, node)
        utils.make_sure(node.attr['weights_format'].s == b'DEFAULT',
                        "Only default weights format supported for fully connected op")
        utils.make_sure(node.attr['keep_num_dims'].i == 0,
                        "Only keep_num_dims=False supported for fully connected op")
        if node.attr['asymmetric_quantize_inputs'].i == 1:
            dynamic_quantize_inputs(ctx, node)

        transpose_node = ctx.insert_new_node_on_input(node, "Transpose", node.input[1],
                                                      name=None, input_index=1, perm=[1, 0])
        transpose_node.skip_conversion = True
        node.set_attr("transpose_a", 0)
        node.set_attr("transpose_b", 0)
        node.type = "MatMul"

        if len(node.input) == 3:
            # FIXME: Add a test for this
            bias_inp = node.input[2]
            ctx.replace_inputs(node, node.input[:2])
            add_node = ctx.insert_new_node_on_output("Add", node.output[0], inputs=[node.output[0], bias_inp])
            add_node.skip_conversion = True

        del node.attr["weights_format"]
        del node.attr["keep_num_dims"]
        del node.attr["asymmetric_quantize_inputs"]

@tfl_op(["TFL_SOFTMAX"], tf_op="Softmax")
class TFlSoftmaxOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        beta = node.get_attr_value("beta")
        beta_node = ctx.make_const(utils.make_name("beta"), np.array(beta, dtype=np.float32))
        mul_node = ctx.insert_new_node_on_output("Mul", node.output[0], name=utils.make_name(node.name))
        ctx.replace_inputs(mul_node, [node.output[0], beta_node.output[0]])
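
The TFL_FULLY_CONNECTED handler above relies on TFLite storing dense weights as [out_features, in_features]: it transposes that second input, rewrites the node as a MatMul, and moves the optional bias into a trailing Add. A quick numpy check of that equivalence (shapes here are illustrative):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3)).astype(np.float32)   # [batch, in_features]
w = rng.standard_normal((4, 3)).astype(np.float32)   # [out_features, in_features]
b = rng.standard_normal(4).astype(np.float32)

fully_connected = x @ w.T + b                             # TFLite semantics
as_converted = np.matmul(x, np.transpose(w, (1, 0))) + b  # Transpose + MatMul + Add
assert np.allclose(fully_connected, as_converted)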
