Skip to content

Commit 9bdd009

Browse files
Created tflite utils for parsing tflite graphs (#1265)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent df6fdd3 commit 9bdd009

File tree

2 files changed

+378
-0
lines changed

2 files changed

+378
-0
lines changed

tests/test_tflite_utils.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""Unit Tests for TFLite utils."""
5+
6+
import os
7+
import tensorflow as tf
8+
9+
from common import * # pylint: disable=wildcard-import,unused-wildcard-import
10+
from backend_test_base import Tf2OnnxBackendTestBase
11+
from tf2onnx.tf_loader import from_function, tf_session
12+
from tf2onnx.tflite_utils import read_tflite_model, parse_tflite_graph
13+
14+
# pylint: disable=missing-docstring
15+
16+
17+
class TFListUtilsTests(Tf2OnnxBackendTestBase):
    """Round-trips a small TF function through the TFLite converter and the tflite parsing utils."""

    @check_tf_min_version("2.0")
    def test_parse_tflite_graph(self):
        """Convert `alpha * matmul(a, b) + beta * c` to tflite, parse it back and check the collected stats."""

        def func(a, b, c):
            alpha = tf.constant(1.1, dtype=tf.float32)
            beta = tf.constant(2.3, dtype=tf.float32)
            mul1 = tf.multiply(alpha, tf.matmul(a, b))
            mul2 = tf.multiply(beta, c)
            x_ = mul1 + mul2
            return tf.identity(x_, name="output")

        inp_shapes = [[2, 3], [3, 1], [2, 1]]
        inp_dtypes = [tf.float32, tf.float32, tf.float32]
        names = ['a', 'b', 'c']
        names_with_port = ['a:0', 'b:0', 'c:0']
        output_names = ['output']
        output_names_with_port = ['output:0']

        input_tensors = [tf.TensorSpec(shape=s, dtype=d, name=n) for s, d, n in zip(inp_shapes, inp_dtypes, names)]

        concrete_func = tf.function(func, input_signature=tuple(input_tensors))
        concrete_func = concrete_func.get_concrete_function()
        graph_def = from_function(concrete_func,
                                  input_names=names_with_port,
                                  output_names=output_names_with_port)
        with tf_session() as sess:
            tf.import_graph_def(graph_def, name='')
            sess_inputs = [sess.graph.get_tensor_by_name(k) for k in names_with_port]
            sess_outputs = [sess.graph.get_tensor_by_name(n) for n in output_names_with_port]
            converter = tf.compat.v1.lite.TFLiteConverter.from_session(sess, sess_inputs, sess_outputs)

            # Convert exactly once; the conversion result is the serialized flatbuffer that gets written to disk.
            tflite_model = converter.convert()

        tflite_path = os.path.join(self.test_data_directory, self._testMethodName + ".tflite")
        dir_name = os.path.dirname(tflite_path)
        os.makedirs(dir_name, exist_ok=True)
        with open(tflite_path, 'wb') as f:
            f.write(tflite_model)

        tflite_graphs, opcodes_map, model = read_tflite_model(tflite_path)
        self.assertEqual(1, len(tflite_graphs))
        onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, inputs, outputs, _ = \
            parse_tflite_graph(tflite_graphs[0], opcodes_map, model)
        self.assertEqual(2, op_cnt['MUL'])
        self.assertEqual(1, op_cnt['ADD'])
        self.assertEqual(1, op_cnt['FULLY_CONNECTED'])

        self.assertEqual(1, attr_cnt['WeightsFormat'])
        self.assertEqual(names, inputs)
        self.assertEqual(output_names, outputs)

        for name, shape, dtype in zip(names, inp_shapes, inp_dtypes):
            self.assertEqual(shape, output_shapes[name])
            self.assertEqual(dtype, dtypes[name])

        self.assertGreaterEqual(len(onnx_nodes), 4)

tf2onnx/tflite_utils.py

Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.tflite_utils - utilities for parsing tflite files into onnx graph
6+
"""
7+
8+
import collections
9+
import importlib
10+
11+
from onnx import helper, onnx_pb, numpy_helper
12+
from tensorflow.core.framework import types_pb2, tensor_pb2
13+
from tensorflow.python.framework import tensor_util
14+
from tflite.TensorType import TensorType as TFLiteTensorType
15+
from tflite.Model import Model
16+
17+
18+
# Maps each tflite TensorType enum value to the corresponding ONNX TensorProto data type.
TFLITE_TO_ONNX_DTYPE = {
    TFLiteTensorType.FLOAT32: onnx_pb.TensorProto.FLOAT,
    TFLiteTensorType.FLOAT16: onnx_pb.TensorProto.FLOAT16,
    TFLiteTensorType.INT32: onnx_pb.TensorProto.INT32,
    TFLiteTensorType.UINT8: onnx_pb.TensorProto.UINT8,
    TFLiteTensorType.INT64: onnx_pb.TensorProto.INT64,
    TFLiteTensorType.STRING: onnx_pb.TensorProto.STRING,
    TFLiteTensorType.BOOL: onnx_pb.TensorProto.BOOL,
    TFLiteTensorType.INT16: onnx_pb.TensorProto.INT16,
    TFLiteTensorType.COMPLEX64: onnx_pb.TensorProto.COMPLEX64,
    TFLiteTensorType.INT8: onnx_pb.TensorProto.INT8,
    TFLiteTensorType.FLOAT64: onnx_pb.TensorProto.DOUBLE,
    TFLiteTensorType.COMPLEX128: onnx_pb.TensorProto.COMPLEX128,
    TFLiteTensorType.UINT64: onnx_pb.TensorProto.UINT64,
}
33+
34+
35+
# Maps each tflite TensorType enum value to the corresponding TensorFlow DataType proto enum value.
TFLITE_TO_TF_DTYPE = {
    TFLiteTensorType.FLOAT32: types_pb2.DT_FLOAT,
    TFLiteTensorType.FLOAT16: types_pb2.DT_HALF,
    TFLiteTensorType.INT32: types_pb2.DT_INT32,
    TFLiteTensorType.UINT8: types_pb2.DT_UINT8,
    TFLiteTensorType.INT64: types_pb2.DT_INT64,
    TFLiteTensorType.STRING: types_pb2.DT_STRING,
    TFLiteTensorType.BOOL: types_pb2.DT_BOOL,
    TFLiteTensorType.INT16: types_pb2.DT_INT16,
    TFLiteTensorType.COMPLEX64: types_pb2.DT_COMPLEX64,
    TFLiteTensorType.INT8: types_pb2.DT_INT8,
    TFLiteTensorType.FLOAT64: types_pb2.DT_DOUBLE,
    TFLiteTensorType.COMPLEX128: types_pb2.DT_COMPLEX128,
    TFLiteTensorType.UINT64: types_pb2.DT_UINT64,
}
50+
51+
52+
def map_tflite_dtype_to_onnx(dtype):
    """Return the ONNX dtype for a tflite TensorType value; raises KeyError if it is not in TFLITE_TO_ONNX_DTYPE."""
    return TFLITE_TO_ONNX_DTYPE[dtype]
54+
55+
56+
def map_tflite_dtype_to_tf(dtype):
    """Return the TF dtype enum for a tflite TensorType value; raises KeyError if it is not in TFLITE_TO_TF_DTYPE."""
    return TFLITE_TO_TF_DTYPE[dtype]
58+
59+
60+
# The tflite schema uses snake case, but the python bindings use proper case
def snake_to_proper_case(name):
    """Convert a snake_case name to ProperCase, e.g. 'weights_format' -> 'WeightsFormat'."""
    return ''.join(n.capitalize() for n in name.split('_'))
63+
64+
65+
def proper_to_snake_case(name):
    """Convert a ProperCase name to snake_case, e.g. 'WeightsFormat' -> 'weights_format'.

    Every uppercase character except the first character of the name starts a new '_'-separated word.
    """
    return ''.join(
        '_' + c.lower() if c.isupper() and i > 0 else c.lower()
        for i, c in enumerate(name)
    )
72+
73+
# Pulled from the tflite schema.fbs file. Needed to decode enum numbers into strings.
NODE_ATTR_NAME_TO_ENUM_TYPE = {
    'fused_activation_function': 'ActivationFunctionType',
    'padding': 'Padding',
    'type': 'LSHProjectionType',
    'weights_format': 'FullyConnectedOptionsWeightsFormat',
    'kernel_type': 'LSTMKernelType',
    'combiner': 'CombinerType',
    'in_data_type': 'TensorType',
    'out_data_type': 'TensorType',
    'output_type': 'TensorType',
    'out_type': 'TensorType',
    'mode': 'MirrorPadMode',
    'idx_out_type': 'TensorType',
}
# Re-key by the proper-case names used by the python bindings' accessor methods.
NODE_ATTR_NAME_TO_ENUM_TYPE = {snake_to_proper_case(key): value for key, value in NODE_ATTR_NAME_TO_ENUM_TYPE.items()}
89+
90+
# Pulled from the tflite schema.fbs file.
FUNCTION_ATTRS = ['then_subgraph_index', 'else_subgraph_index', 'cond_subgraph_index',
                  'body_subgraph_index', 'subgraph']
# Convert to the proper-case names used by the python bindings' accessor methods.
FUNCTION_ATTRS = [snake_to_proper_case(attr) for attr in FUNCTION_ATTRS]
94+
95+
96+
# Maps a tflite enum class name to its {enum value: enum member name} table, filled lazily by lookup_enum.
enum_cache = {}


def lookup_enum(idx, enum_name):
    """Given the name of a tflite enum class and an index, return the name of the enum value as a string.

    'TensorType' is a special case: the index is translated to the matching ONNX dtype value instead of a name.
    Other enums are loaded from the module 'tflite.<enum_name>' on first use and cached in enum_cache.
    Raises KeyError if idx is not a value of the enum.
    """
    if enum_name == 'TensorType':
        return map_tflite_dtype_to_onnx(idx)
    if enum_name not in enum_cache:
        module = importlib.import_module('tflite.' + enum_name)
        enum_class = getattr(module, enum_name)
        # Public class attributes are the enum members; invert them to look up names by value.
        enum_cache[enum_name] = {
            value: key for key, value in enum_class.__dict__.items() if not key.startswith('_')
        }
    return enum_cache[enum_name][idx]
108+
109+
110+
def get_options_class(name):
    """Each tflite optype has a flatbuffer Options class (ex: AddOptions). Returns the options class given its name.

    Returns None for the name "NONE" (ops without options). Otherwise the class is loaded from the module
    'tflite.<name>'.
    """
    if name == "NONE":
        return None
    module = importlib.import_module('tflite.' + name)
    return getattr(module, name)
116+
117+
118+
def read_tflite_model(tflite_path):
    """
    Given the path to a tflite model, returns tuple (tflite_graphs, opcodes_map, model)
    Pass these to parse_tflite_graph

    tflite_graphs is the list of the model's subgraphs, opcodes_map maps each index of the model's operator code
    table to the opcode name, and model is the flatbuffer Model root object.
    """
    with open(tflite_path, 'rb') as f:
        buf = f.read()
    buf = bytearray(buf)
    model = Model.GetRootAsModel(buf, 0)
    # To save space, each op in the model indicates its opcode as an index into the model's opcode map.
    opcodes_map = {}
    for i in range(model.OperatorCodesLength()):
        op_code = model.OperatorCodes(i)
        # TFlite ran out of opcodes since they only used a byte. Old models store opcodes in DeprecatedBuiltinCode.
        # New models put PLACEHOLDER_FOR_GREATER_OP_CODES in this field to signify that BuiltinCode should be used.
        code = lookup_enum(op_code.DeprecatedBuiltinCode(), 'BuiltinOperator')
        if code == 'PLACEHOLDER_FOR_GREATER_OP_CODES':
            code = lookup_enum(op_code.BuiltinCode(), 'BuiltinOperator')
        opcodes_map[i] = code
    tflite_graphs = [model.Subgraphs(i) for i in range(model.SubgraphsLength())]
    return tflite_graphs, opcodes_map, model
139+
140+
141+
def _quantization_attrs(tensor):
    """Return the scale/zero_point/quantized_dimension attrs of a tflite tensor, or None if it is not quantized."""
    quant = tensor.Quantization()
    if quant is None or quant.ScaleIsNone() or quant.ZeroPointIsNone():
        return None
    return {
        'scale': quant.ScaleAsNumpy().tolist(),
        'zero_point': quant.ZeroPointAsNumpy().tolist(),
        'quantized_dimension': quant.QuantizedDimension(),
    }


def _copy_shape(shape):
    """Return an independent copy of a shape list; an unknown shape (None) stays None."""
    return None if shape is None else shape.copy()


def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix=''):
    """
    Converts a tflite subgraph into a list of onnx nodes along with some op count stats.
    All tflite op types are prefixed with "TFL_". Tensors backed by buffer data become "Const" nodes and
    graph inputs become "Placeholder" nodes.
    Names of graph inputs are optionally prefixed with a string to prevent name conflicts in subgraphs.
    Quantized tensors are surrounded with quantize/dequantize ops

    Args:
        tflite_g: the tflite subgraph to parse (an element of the list returned by read_tflite_model)
        opcodes_map: maps an operator's opcode index to its opcode name (from read_tflite_model)
        model: the tflite model owning the subgraph; used to fetch buffers and subgraph names
        input_prefix: string prepended to the names of the subgraph's input tensors

    Returns:
        tuple (onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, inputs, outputs, graph_name) where
        op_cnt/attr_cnt are Counters of op types and attribute names, output_shapes/dtypes map tensor names to
        shapes (None if unknown) and onnx dtypes, and inputs/outputs are lists of tensor names.
    """
    op_cnt = collections.Counter()
    attr_cnt = collections.Counter()
    onnx_nodes = []
    output_shapes = {}
    dtypes = {}
    tensor_names = {}
    # Map tensor name to tflite Tensor object so we can fetch quantization info as needed
    name_to_tensor = {}
    # If a node takes a quantized tensor as input, we must add a dequantize op after it.
    # Store a mapping so we only need to make at most one dequantize op per tensor.
    tensor_name_to_dequant_output = {}

    # tflite uses generic names (arg0, arg1, etc.) for inputs but full names for other tensors, so
    # prefixing just the inputs should be fine. Other tensors are prefixed when we do inlining.
    input_indices = {tflite_g.Inputs(i) for i in range(tflite_g.InputsLength())}

    for i in range(tflite_g.TensorsLength()):
        tensor = tflite_g.Tensors(i)
        name = tensor.Name().decode()
        if i in input_indices:
            name = input_prefix + name
        tensor_names[i] = name
        name_to_tensor[name] = tensor

        if tensor.ShapeIsNone():
            output_shapes[name] = None
        elif tensor.ShapeSignatureIsNone():
            # The shape signature uses -1 to signify unknown dims. Old models don't have this and use Shape instead.
            output_shapes[name] = tensor.ShapeAsNumpy().tolist()
        else:
            output_shapes[name] = tensor.ShapeSignatureAsNumpy().tolist()
        buf = model.Buffers(tensor.Buffer())
        dtypes[name] = map_tflite_dtype_to_onnx(tensor.Type())
        if not buf.DataIsNone():
            # For const values we use TF to decode the binary data from the buffer
            t = tensor_pb2.TensorProto()
            t.tensor_content = buf.DataAsNumpy().tobytes()
            if output_shapes[name] is None:
                output_shapes[name] = []
            for d in output_shapes[name]:
                t.tensor_shape.dim.add().size = d
            t.dtype = map_tflite_dtype_to_tf(tensor.Type())
            np_data = tensor_util.MakeNdarray(t)
            onnx_tensor = numpy_helper.from_array(np_data, name=name)
            onnx_node = helper.make_node("Const", [], outputs=[name], name=name, value=onnx_tensor)
            onnx_nodes.append(onnx_node)
            op_cnt["Const"] += 1

    def get_dequant(tensor_name):
        """Creates a dequantize op for the provided tensor if needed and returns the output of the op, or
        the original tensor name if no dequantization is needed"""
        # Only quantized tensors are ever stored in the mapping, so checking it first is safe.
        if tensor_name in tensor_name_to_dequant_output:
            return tensor_name_to_dequant_output[tensor_name]
        quant_attr = _quantization_attrs(name_to_tensor[tensor_name])
        if quant_attr is None:
            return tensor_name
        dequant_name = tensor_name + "_dequant"
        onnx_node = helper.make_node("TFL_DEQUANTIZE", [tensor_name], [dequant_name], name=dequant_name,
                                     **quant_attr)
        onnx_nodes.append(onnx_node)
        tensor_name_to_dequant_output[tensor_name] = dequant_name
        # The shape may be unknown (None) for tensors without shape info, so it cannot be copied blindly.
        output_shapes[dequant_name] = _copy_shape(output_shapes[tensor_name])
        dtypes[dequant_name] = onnx_pb.TensorProto.FLOAT
        return dequant_name

    def get_prequant(tensor_name):
        """Called by nodes with the name of the tensor they must output.
        If the output is supposed to be quantized, creates a Quantize op outputting the tensor.
        Returns the name that should be used for the "prequantized" tensor, or the original tensor if no quantization
        is needed"""
        quant_attr = _quantization_attrs(name_to_tensor[tensor_name])
        if quant_attr is None:
            return tensor_name
        prequant_name = tensor_name + "_prequant"
        quantize_name = tensor_name + "_quantize"
        onnx_node = helper.make_node("TFL_QUANTIZE", [prequant_name], [tensor_name], name=quantize_name,
                                     **quant_attr)
        onnx_nodes.append(onnx_node)
        # The shape may be unknown (None) for tensors without shape info, so it cannot be copied blindly.
        output_shapes[prequant_name] = _copy_shape(output_shapes[tensor_name])
        dtypes[prequant_name] = onnx_pb.TensorProto.FLOAT
        return prequant_name

    for i in range(tflite_g.OperatorsLength()):
        op = tflite_g.Operators(i)
        optype = opcodes_map[op.OpcodeIndex()]
        op_cnt[optype] += 1
        attr = {}
        options_type_name = lookup_enum(op.BuiltinOptionsType(), 'BuiltinOptions')
        option_class = get_options_class(options_type_name)
        wants_dequantized_input = True
        has_prequantized_output = True
        if optype == 'QUANTIZE':
            # An explicit QUANTIZE op already produces the quantized tensor; take the params from its output.
            has_prequantized_output = False
            attr.update(_quantization_attrs(tflite_g.Tensors(op.Outputs(0))) or {})
        elif optype == 'DEQUANTIZE':
            # An explicit DEQUANTIZE op consumes the quantized tensor directly; take the params from its input.
            wants_dequantized_input = False
            attr.update(_quantization_attrs(tflite_g.Tensors(op.Inputs(0))) or {})
        if option_class is not None:
            options = option_class()
            options.Init(op.BuiltinOptions().Bytes, op.BuiltinOptions().Pos)
            # All flatbuffer objects have these properties.
            block_list = [options_type_name + 'BufferHasIdentifier', 'Init', 'GetRootAs' + options_type_name]
            # The rest of the properties of the options class provide its attribute names
            attr_names = {opt for opt in dir(options) if not opt.startswith('_') and opt not in block_list}
            for a in list(attr_names):
                # Flatbuffer list properties have 3 functions: *Length, *IsNone, and *AsNumpy
                if a + 'Length' in attr_names:
                    attr_names.remove(a + 'Length')
                    attr_names.remove(a + 'IsNone')
                    attr_names.remove(a)
            for a in attr_names:
                if a.endswith('AsNumpy'):
                    value = getattr(options, a)().tolist()
                    a = a[:-len('AsNumpy')]
                else:
                    # For enums we use a string with the value name, not enum index
                    value = getattr(options, a)()
                    if a in NODE_ATTR_NAME_TO_ENUM_TYPE:
                        value = lookup_enum(value, NODE_ATTR_NAME_TO_ENUM_TYPE[a])
                    elif a in FUNCTION_ATTRS:
                        value = model.Subgraphs(value).Name().decode()
                attr_cnt[a] += 1
                attr[proper_to_snake_case(a)] = value
        # An index of -1 marks an omitted optional tensor and is skipped.
        input_names = [tensor_names[op.Inputs(i)] for i in range(op.InputsLength()) if op.Inputs(i) != -1]
        if wants_dequantized_input:
            input_names = [get_dequant(inp) for inp in input_names]
        output_names = [tensor_names[op.Outputs(i)] for i in range(op.OutputsLength()) if op.Outputs(i) != -1]
        if has_prequantized_output:
            output_names = [get_prequant(out) for out in output_names]
        onnx_node = helper.make_node("TFL_" + optype, input_names, output_names, name=output_names[0], **attr)
        onnx_nodes.append(onnx_node)

    inputs = [tensor_names[tflite_g.Inputs(i)] for i in range(tflite_g.InputsLength())]
    outputs = [tensor_names[tflite_g.Outputs(i)] for i in range(tflite_g.OutputsLength())]
    # TODO: Allow input/outputs to be overridden

    for inp in inputs:
        onnx_node = helper.make_node("Placeholder", [], outputs=[inp], name=inp)
        onnx_nodes.append(onnx_node)

    graph_name = (tflite_g.Name() or b'tflite graph').decode()
    return onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, inputs, outputs, graph_name

0 commit comments

Comments
 (0)