Skip to content

Commit 1d07510

Browse files
Added code for parsing tflite customops (#1281)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 36a3f70 commit 1d07510

File tree

2 files changed

+147
-0
lines changed

2 files changed

+147
-0
lines changed

tf2onnx/flexbuffers.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
4+
"""
5+
tf2onnx.flexbuffers - Code for parsing flexbuffers
6+
"""
7+
8+
import struct
9+
10+
11+
class FlexbufferParseException(Exception):
12+
pass
13+
14+
15+
def read_int(buffer, offset, bit_size):
16+
size = 1 << bit_size
17+
format_char = 'bhiq'[bit_size]
18+
return struct.unpack('<' + format_char, buffer[offset:offset+size])[0]
19+
20+
21+
def read_uint(buffer, offset, bit_size):
22+
size = 1 << bit_size
23+
format_char = 'BHIQ'[bit_size]
24+
return struct.unpack('<' + format_char, buffer[offset:offset+size])[0]
25+
26+
27+
def read_float(buffer, offset, bit_size):
28+
if bit_size == 2:
29+
return struct.unpack('<f', buffer[offset:offset+4])[0]
30+
if bit_size == 3:
31+
return struct.unpack('<d', buffer[offset:offset+8])[0]
32+
raise FlexbufferParseException("Invalid bit size for flexbuffer float: %d" % bit_size)
33+
34+
35+
def read_string(buffer, offset, size):
36+
return buffer[offset:offset+size].decode('utf-8')
37+
38+
39+
def read_indirect(buffer, offset, bit_size):
40+
return offset - read_uint(buffer, offset, bit_size)
41+
42+
43+
def read_bytes(buffer, offset, size):
44+
return buffer[offset:offset+size]
45+
46+
47+
def read_array(buffer, offset, length, bit_size, packed_type):
48+
byte_size = 1 << bit_size
49+
arr = []
50+
for i in range(length):
51+
item_offset = offset + (i * byte_size)
52+
arr.append(read_buffer(buffer, item_offset, bit_size, packed_type))
53+
return arr
54+
55+
56+
def read_buffer(buffer, offset, parent_bit_size, packed_type):
57+
"""Recursively decode flatbuffer object into python representation"""
58+
bit_size = packed_type & 3
59+
value_type = packed_type >> 2
60+
byte_size = 1 << bit_size
61+
62+
if value_type == 0x0:
63+
return None
64+
if value_type in [0x1, 0x2, 0x3]:
65+
read_fn = {0x1: read_int, 0x2: read_uint, 0x3: read_float}[value_type]
66+
return read_fn(buffer, offset, parent_bit_size)
67+
if value_type in [0x4, 0x5]:
68+
str_offset = read_indirect(buffer, offset, parent_bit_size)
69+
size = 0
70+
while read_int(buffer, str_offset + size, 0) != 0:
71+
size += 1
72+
return read_string(buffer, str_offset, size)
73+
if value_type == 0x5:
74+
str_offset = read_indirect(buffer, offset, parent_bit_size)
75+
size_byte_size = 1 << bit_size
76+
size = read_uint(buffer, str_offset - size_byte_size, bit_size)
77+
while read_int(buffer, str_offset + size, 0) != 0:
78+
size_byte_size <<= 1
79+
size = read_uint(buffer, str_offset - size_byte_size, bit_size)
80+
return read_string(buffer, str_offset, size)
81+
if value_type in [0x6, 0x7, 0x8]:
82+
read_fn = {0x6: read_int, 0x7: read_uint, 0x8: read_float}[value_type]
83+
data_offset = read_indirect(buffer, offset, parent_bit_size)
84+
return read_fn(buffer, data_offset, bit_size)
85+
if value_type == 0x9:
86+
length = read_uint(buffer, read_indirect(buffer, offset, parent_bit_size) - byte_size, bit_size)
87+
keys_offset = read_indirect(buffer, offset, parent_bit_size) - (byte_size * 3)
88+
keys_vector_offset = read_indirect(buffer, keys_offset, bit_size)
89+
key_byte_size = read_uint(buffer, keys_offset + byte_size, bit_size)
90+
key_bit_size = {1: 0, 2: 1, 4: 2, 8: 3, 16: 4}[key_byte_size]
91+
values_offset = read_indirect(buffer, offset, parent_bit_size)
92+
packed_types_offset = values_offset + length * byte_size
93+
obj = {}
94+
for i in range(length):
95+
key_offset = keys_vector_offset + i * key_byte_size
96+
key = read_buffer(buffer, key_offset, key_bit_size, (0x4 << 2) | key_bit_size)
97+
value_offset = values_offset + i * byte_size
98+
value_packed_type = read_uint(buffer, packed_types_offset + i, 0)
99+
value = read_buffer(buffer, value_offset, bit_size, value_packed_type)
100+
obj[key] = value
101+
return obj
102+
if value_type == 0xa:
103+
length = read_uint(buffer, read_indirect(buffer, offset, parent_bit_size) - byte_size, bit_size)
104+
arr = []
105+
items_offset = read_indirect(buffer, offset, parent_bit_size)
106+
packed_types_offset = items_offset + (length * byte_size)
107+
for i in range(length):
108+
item_offset = items_offset + (i * byte_size)
109+
packed_type = read_uint(buffer, packed_types_offset + i, 0)
110+
arr.append(read_buffer(buffer, item_offset, bit_size, packed_type))
111+
return arr
112+
if value_type in [0xb, 0xc, 0xd, 0xe, 0xf, 0x24]:
113+
length_offset = read_indirect(buffer, offset, parent_bit_size) - byte_size
114+
length = read_uint(buffer, length_offset, bit_size)
115+
item_value_type = value_type - 0xb + 0x1
116+
packed_type = item_value_type << 2
117+
items_offset = read_indirect(buffer, offset, parent_bit_size)
118+
return read_array(buffer, items_offset, length, bit_size, packed_type)
119+
if 0x10 <= value_type <= 0x18:
120+
length = (value_type - 0x10) // 3 + 2
121+
value_type = ((value_type - 0x10) % 3) + 1
122+
packed_type = value_type << 2
123+
items_offset = read_indirect(buffer, offset, parent_bit_size)
124+
return read_array(buffer, items_offset, length, bit_size, packed_type)
125+
if value_type == 0x19:
126+
data_offset = read_indirect(buffer, offset, parent_bit_size)
127+
size_offset = data_offset - byte_size
128+
size = read_uint(buffer, size_offset, bit_size)
129+
return read_bytes(buffer, data_offset, size)
130+
if value_type == 0x1a:
131+
return read_uint(buffer, offset, parent_bit_size) > 0
132+
raise FlexbufferParseException("Invalid flexbuffer value type %r" % value_type)
133+
134+
135+
def read_flexbuffer(buffer):
136+
byte_size = read_uint(buffer, len(buffer) - 1, 0)
137+
bit_size = {1: 0, 2: 1, 4: 2, 8: 3, 16: 4}[byte_size]
138+
packed_type = read_uint(buffer, len(buffer) - 2, 0)
139+
offset = len(buffer) - 2 - byte_size
140+
return read_buffer(buffer, offset, bit_size, packed_type)

tf2onnx/tflite_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from tensorflow.python.framework import tensor_util
1414
from tf2onnx.tflite.TensorType import TensorType as TFLiteTensorType
1515
from tf2onnx.tflite.Model import Model
16+
from tf2onnx.flexbuffers import read_flexbuffer
1617

1718

1819
TFLITE_TO_ONNX_DTYPE = {
@@ -133,6 +134,8 @@ def read_tflite_model(tflite_path):
133134
code = lookup_enum(op_code.DeprecatedBuiltinCode(), 'BuiltinOperator')
134135
if code == 'PLACEHOLDER_FOR_GREATER_OP_CODES':
135136
code = lookup_enum(op_code.BuiltinCode(), 'BuiltinOperator')
137+
if code == 'CUSTOM':
138+
code = op_code.CustomCode().decode()
136139
opcodes_map[i] = code
137140
tflite_graphs = [model.Subgraphs(i) for i in range(model.SubgraphsLength())]
138141
return tflite_graphs, opcodes_map, model
@@ -257,6 +260,10 @@ def get_prequant(tensor_name):
257260
attr['scale'] = quant.ScaleAsNumpy().tolist()
258261
attr['zero_point'] = quant.ZeroPointAsNumpy().tolist()
259262
attr['quantized_dimension'] = quant.QuantizedDimension()
263+
if not op.CustomOptionsIsNone():
264+
custom_ops_format = lookup_enum(op.CustomOptionsFormat(), 'CustomOptionsFormat')
265+
if custom_ops_format == 'FLEXBUFFERS':
266+
attr.update(read_flexbuffer(op.CustomOptionsAsNumpy().tobytes()))
260267
if option_class is not None:
261268
options = option_class()
262269
options.Init(op.BuiltinOptions().Bytes, op.BuiltinOptions().Pos)

0 commit comments

Comments
 (0)