|
| 1 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +# Licensed under the MIT license. |
| 3 | + |
| 4 | +""" |
| 5 | +tf2onnx.tflite_utils - utilities for parsing tflite files into onnx graph |
| 6 | +""" |
| 7 | + |
| 8 | +import collections |
| 9 | +import importlib |
| 10 | + |
| 11 | +from onnx import helper, onnx_pb, numpy_helper |
| 12 | +from tensorflow.core.framework import types_pb2, tensor_pb2 |
| 13 | +from tensorflow.python.framework import tensor_util |
| 14 | +from tflite.TensorType import TensorType as TFLiteTensorType |
| 15 | +from tflite.Model import Model |
| 16 | + |
| 17 | + |
| 18 | +TFLITE_TO_ONNX_DTYPE = { |
| 19 | + TFLiteTensorType.FLOAT32: onnx_pb.TensorProto.FLOAT, |
| 20 | + TFLiteTensorType.FLOAT16: onnx_pb.TensorProto.FLOAT16, |
| 21 | + TFLiteTensorType.INT32: onnx_pb.TensorProto.INT32, |
| 22 | + TFLiteTensorType.UINT8: onnx_pb.TensorProto.UINT8, |
| 23 | + TFLiteTensorType.INT64: onnx_pb.TensorProto.INT64, |
| 24 | + TFLiteTensorType.STRING: onnx_pb.TensorProto.STRING, |
| 25 | + TFLiteTensorType.BOOL: onnx_pb.TensorProto.BOOL, |
| 26 | + TFLiteTensorType.INT16: onnx_pb.TensorProto.INT16, |
| 27 | + TFLiteTensorType.COMPLEX64: onnx_pb.TensorProto.COMPLEX64, |
| 28 | + TFLiteTensorType.INT8: onnx_pb.TensorProto.INT8, |
| 29 | + TFLiteTensorType.FLOAT64: onnx_pb.TensorProto.DOUBLE, |
| 30 | + TFLiteTensorType.COMPLEX128: onnx_pb.TensorProto.COMPLEX128, |
| 31 | + TFLiteTensorType.UINT64: onnx_pb.TensorProto.UINT64, |
| 32 | +} |
| 33 | + |
| 34 | + |
| 35 | +TFLITE_TO_TF_DTYPE = { |
| 36 | + TFLiteTensorType.FLOAT32: types_pb2.DT_FLOAT, |
| 37 | + TFLiteTensorType.FLOAT16: types_pb2.DT_HALF, |
| 38 | + TFLiteTensorType.INT32: types_pb2.DT_INT32, |
| 39 | + TFLiteTensorType.UINT8: types_pb2.DT_UINT8, |
| 40 | + TFLiteTensorType.INT64: types_pb2.DT_INT64, |
| 41 | + TFLiteTensorType.STRING: types_pb2.DT_STRING, |
| 42 | + TFLiteTensorType.BOOL: types_pb2.DT_BOOL, |
| 43 | + TFLiteTensorType.INT16: types_pb2.DT_INT16, |
| 44 | + TFLiteTensorType.COMPLEX64: types_pb2.DT_COMPLEX64, |
| 45 | + TFLiteTensorType.INT8: types_pb2.DT_INT8, |
| 46 | + TFLiteTensorType.FLOAT64: types_pb2.DT_DOUBLE, |
| 47 | + TFLiteTensorType.COMPLEX128: types_pb2.DT_COMPLEX128, |
| 48 | + TFLiteTensorType.UINT64: types_pb2.DT_UINT64, |
| 49 | +} |
| 50 | + |
| 51 | + |
| 52 | +def map_tflite_dtype_to_onnx(dtype): |
| 53 | + return TFLITE_TO_ONNX_DTYPE[dtype] |
| 54 | + |
| 55 | + |
| 56 | +def map_tflite_dtype_to_tf(dtype): |
| 57 | + return TFLITE_TO_TF_DTYPE[dtype] |
| 58 | + |
| 59 | + |
| 60 | +# The tflite schema uses snake case, but the python bindings use proper case |
| 61 | +def snake_to_proper_case(name): |
| 62 | + return ''.join(n.capitalize() for n in name.split('_')) |
| 63 | + |
| 64 | + |
| 65 | +def proper_to_snake_case(name): |
| 66 | + res = '' |
| 67 | + for c in name: |
| 68 | + if c.isupper() and res: |
| 69 | + res += '_' |
| 70 | + res += c.lower() |
| 71 | + return res |
| 72 | + |
| 73 | +# Pulled from the tflite schema.fbs file. Needed to decode enum numbers into strings. |
| 74 | +NODE_ATTR_NAME_TO_ENUM_TYPE = { |
| 75 | + 'fused_activation_function': 'ActivationFunctionType', |
| 76 | + 'padding': 'Padding', |
| 77 | + 'type': 'LSHProjectionType', |
| 78 | + 'weights_format': 'FullyConnectedOptionsWeightsFormat', |
| 79 | + 'kernel_type': 'LSTMKernelType', |
| 80 | + 'combiner': 'CombinerType', |
| 81 | + 'in_data_type': 'TensorType', |
| 82 | + 'out_data_type': 'TensorType', |
| 83 | + 'output_type': 'TensorType', |
| 84 | + 'out_type': 'TensorType', |
| 85 | + 'mode': 'MirrorPadMode', |
| 86 | + 'idx_out_type': 'TensorType', |
| 87 | +} |
| 88 | +NODE_ATTR_NAME_TO_ENUM_TYPE = {snake_to_proper_case(key): value for key, value in NODE_ATTR_NAME_TO_ENUM_TYPE.items()} |
| 89 | + |
| 90 | +# Pulled from the tflite schema.fbs file. |
| 91 | +FUNCTION_ATTRS = ['then_subgraph_index', 'else_subgraph_index', 'cond_subgraph_index', |
| 92 | + 'body_subgraph_index', 'subgraph'] |
| 93 | +FUNCTION_ATTRS = [snake_to_proper_case(attr) for attr in FUNCTION_ATTRS] |
| 94 | + |
| 95 | + |
| 96 | +enum_cache = {} |
| 97 | +def lookup_enum(idx, enum_name): |
| 98 | + """Given the name of a tflite enum class and an index, return a string with the name of the enum value""" |
| 99 | + if enum_name == 'TensorType': |
| 100 | + return map_tflite_dtype_to_onnx(idx) |
| 101 | + if enum_name in enum_cache: |
| 102 | + return enum_cache[enum_name][idx] |
| 103 | + module = importlib.import_module('tflite.' + enum_name) |
| 104 | + enum_class = getattr(module, enum_name) |
| 105 | + idx_to_name = {value: key for key, value in enum_class.__dict__.items() if not key.startswith('_')} |
| 106 | + enum_cache[enum_name] = idx_to_name |
| 107 | + return idx_to_name[idx] |
| 108 | + |
| 109 | + |
| 110 | +def get_options_class(name): |
| 111 | + """Each tflite optype has a flatbuffer Options class (ex: AddOptions). Returns the options class given its name.""" |
| 112 | + if name == "NONE": |
| 113 | + return None |
| 114 | + module = importlib.import_module('tflite.' + name) |
| 115 | + return getattr(module, name) |
| 116 | + |
| 117 | + |
| 118 | +def read_tflite_model(tflite_path): |
| 119 | + """ |
| 120 | + Given the path to a tflite model, returns tuple (tflite_graphs, opcodes_map, model) |
| 121 | + Pass these to parse_tflite_graph |
| 122 | + """ |
| 123 | + with open(tflite_path, 'rb') as f: |
| 124 | + buf = f.read() |
| 125 | + buf = bytearray(buf) |
| 126 | + model = Model.GetRootAsModel(buf, 0) |
| 127 | + # To save space, each op in the model indicates its opcode as an index into the model's opcode map. |
| 128 | + opcodes_map = {} |
| 129 | + for i in range(model.OperatorCodesLength()): |
| 130 | + op_code = model.OperatorCodes(i) |
| 131 | + # TFlite ran out of opcodes since they only used a byte. Old models store opcodes in DeprecatedBuiltinCode. |
| 132 | + # New models put PLACEHOLDER_FOR_GREATER_OP_CODES in this field to signify that BuiltinCode should be used. |
| 133 | + code = lookup_enum(op_code.DeprecatedBuiltinCode(), 'BuiltinOperator') |
| 134 | + if code == 'PLACEHOLDER_FOR_GREATER_OP_CODES': |
| 135 | + code = lookup_enum(op_code.BuiltinCode(), 'BuiltinOperator') |
| 136 | + opcodes_map[i] = code |
| 137 | + tflite_graphs = [model.Subgraphs(i) for i in range(model.SubgraphsLength())] |
| 138 | + return tflite_graphs, opcodes_map, model |
| 139 | + |
| 140 | + |
| 141 | +def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix=''): |
| 142 | + """ |
| 143 | + Returns a Graph object along with some op count stats. All tflite op types are prefixed with "TFL_". |
| 144 | + Names of graph inputs are optionally prefixed with a string to prevent name conflicts in subgraphs. |
| 145 | + Quantizatized tensors are surrounded with quantize/dequantize ops |
| 146 | + """ |
| 147 | + op_cnt = collections.Counter() |
| 148 | + attr_cnt = collections.Counter() |
| 149 | + onnx_nodes = [] |
| 150 | + output_shapes = {} |
| 151 | + dtypes = {} |
| 152 | + tensor_names = {} |
| 153 | + # Map tensor name to tflite Tensor object so we can fetch quantization info as needed |
| 154 | + name_to_tensor = {} |
| 155 | + # If a node takes a quantized tensor as input, we must add a dequantize op after it. |
| 156 | + # Store a mapping so we only need to make at most one dequantize op per tensor. |
| 157 | + tensor_name_to_dequant_output = {} |
| 158 | + |
| 159 | + # tflite uses generic names (arg0, arg1, etc.) for inputs but full names for other tensors, so |
| 160 | + # prefixing just the inputs should be fine. Other tensors are prefixed when we do inlining. |
| 161 | + input_indices = {tflite_g.Inputs(i) for i in range(tflite_g.InputsLength())} |
| 162 | + |
| 163 | + for i in range(tflite_g.TensorsLength()): |
| 164 | + tensor = tflite_g.Tensors(i) |
| 165 | + name = tensor.Name().decode() |
| 166 | + if i in input_indices: |
| 167 | + name = input_prefix + name |
| 168 | + tensor_names[i] = name |
| 169 | + name_to_tensor[name] = tensor |
| 170 | + |
| 171 | + if tensor.ShapeIsNone(): |
| 172 | + output_shapes[name] = None |
| 173 | + elif tensor.ShapeSignatureIsNone(): |
| 174 | + # The shape signature uses -1 to signify unknown dims. Old models don't have this and use Shape instead. |
| 175 | + output_shapes[name] = tensor.ShapeAsNumpy().tolist() |
| 176 | + else: |
| 177 | + output_shapes[name] = tensor.ShapeSignatureAsNumpy().tolist() |
| 178 | + buf = model.Buffers(tensor.Buffer()) |
| 179 | + dtypes[name] = map_tflite_dtype_to_onnx(tensor.Type()) |
| 180 | + if not buf.DataIsNone(): |
| 181 | + # For const values we use TF to decode the binary data from the buffer |
| 182 | + t = tensor_pb2.TensorProto() |
| 183 | + t.tensor_content = buf.DataAsNumpy().tobytes() |
| 184 | + if output_shapes[name] is None: |
| 185 | + output_shapes[name] = [] |
| 186 | + for d in output_shapes[name]: |
| 187 | + t.tensor_shape.dim.add().size = d |
| 188 | + t.dtype = map_tflite_dtype_to_tf(tensor.Type()) |
| 189 | + np_data = tensor_util.MakeNdarray(t) |
| 190 | + onnx_tensor = numpy_helper.from_array(np_data, name=name) |
| 191 | + onnx_node = helper.make_node("Const", [], outputs=[name], name=name, value=onnx_tensor) |
| 192 | + onnx_nodes.append(onnx_node) |
| 193 | + op_cnt["Const"] += 1 |
| 194 | + |
| 195 | + def get_dequant(tensor_name): |
| 196 | + """Creates a dequantize op for the provided tensor if needed and returns the output of the op, or |
| 197 | + the original tensor name if no dequantization is needed""" |
| 198 | + quant = name_to_tensor[tensor_name].Quantization() |
| 199 | + if quant is None or quant.ScaleIsNone() or quant.ZeroPointIsNone(): |
| 200 | + return tensor_name |
| 201 | + if tensor_name in tensor_name_to_dequant_output: |
| 202 | + return tensor_name_to_dequant_output[tensor_name] |
| 203 | + dequant_name = tensor_name + "_dequant" |
| 204 | + attr = {} |
| 205 | + attr['scale'] = quant.ScaleAsNumpy().tolist() |
| 206 | + attr['zero_point'] = quant.ZeroPointAsNumpy().tolist() |
| 207 | + attr['quantized_dimension'] = quant.QuantizedDimension() |
| 208 | + onnx_node = helper.make_node("TFL_DEQUANTIZE", [tensor_name], [dequant_name], name=dequant_name, **attr) |
| 209 | + onnx_nodes.append(onnx_node) |
| 210 | + tensor_name_to_dequant_output[tensor_name] = dequant_name |
| 211 | + output_shapes[dequant_name] = output_shapes[tensor_name].copy() |
| 212 | + dtypes[dequant_name] = onnx_pb.TensorProto.FLOAT |
| 213 | + return dequant_name |
| 214 | + |
| 215 | + def get_prequant(tensor_name): |
| 216 | + """Called by nodes with the name of the tensor they must output. |
| 217 | + If the output is supposed to be quantized, creates a Quantize op outputting the tensor. |
| 218 | + Returns the name that should be used for the "prequantized" tensor, or the original tensor if no quantization |
| 219 | + is needed""" |
| 220 | + quant = name_to_tensor[tensor_name].Quantization() |
| 221 | + if quant is None or quant.ScaleIsNone() or quant.ZeroPointIsNone(): |
| 222 | + return tensor_name |
| 223 | + prequant_name = tensor_name + "_prequant" |
| 224 | + quantize_name = tensor_name + "_quantize" |
| 225 | + attr = {} |
| 226 | + attr['scale'] = quant.ScaleAsNumpy().tolist() |
| 227 | + attr['zero_point'] = quant.ZeroPointAsNumpy().tolist() |
| 228 | + attr['quantized_dimension'] = quant.QuantizedDimension() |
| 229 | + onnx_node = helper.make_node("TFL_QUANTIZE", [prequant_name], [tensor_name], name=quantize_name, **attr) |
| 230 | + onnx_nodes.append(onnx_node) |
| 231 | + output_shapes[prequant_name] = output_shapes[tensor_name].copy() |
| 232 | + dtypes[prequant_name] = onnx_pb.TensorProto.FLOAT |
| 233 | + return prequant_name |
| 234 | + |
| 235 | + for i in range(tflite_g.OperatorsLength()): |
| 236 | + op = tflite_g.Operators(i) |
| 237 | + optype = opcodes_map[op.OpcodeIndex()] |
| 238 | + op_cnt[optype] += 1 |
| 239 | + attr = {} |
| 240 | + options_type_name = lookup_enum(op.BuiltinOptionsType(), 'BuiltinOptions') |
| 241 | + option_class = get_options_class(options_type_name) |
| 242 | + wants_dequantized_input = True |
| 243 | + has_prequantized_output = True |
| 244 | + if optype == 'QUANTIZE': |
| 245 | + out_tensor = tflite_g.Tensors(op.Outputs(0)) |
| 246 | + quant = out_tensor.Quantization() |
| 247 | + has_prequantized_output = False |
| 248 | + if quant is not None and not quant.ScaleIsNone() and not quant.ZeroPointIsNone(): |
| 249 | + attr['scale'] = quant.ScaleAsNumpy().tolist() |
| 250 | + attr['zero_point'] = quant.ZeroPointAsNumpy().tolist() |
| 251 | + attr['quantized_dimension'] = quant.QuantizedDimension() |
| 252 | + elif optype == 'DEQUANTIZE': |
| 253 | + in_tensor = tflite_g.Tensors(op.Inputs(0)) |
| 254 | + quant = in_tensor.Quantization() |
| 255 | + wants_dequantized_input = False |
| 256 | + if quant is not None and not quant.ScaleIsNone() and not quant.ZeroPointIsNone(): |
| 257 | + attr['scale'] = quant.ScaleAsNumpy().tolist() |
| 258 | + attr['zero_point'] = quant.ZeroPointAsNumpy().tolist() |
| 259 | + attr['quantized_dimension'] = quant.QuantizedDimension() |
| 260 | + if option_class is not None: |
| 261 | + options = option_class() |
| 262 | + options.Init(op.BuiltinOptions().Bytes, op.BuiltinOptions().Pos) |
| 263 | + # All flatbuffer objects have these properties. |
| 264 | + block_list = [options_type_name + 'BufferHasIdentifier', 'Init', 'GetRootAs' + options_type_name] |
| 265 | + # The rest of the properties of the options class provide its attribute names |
| 266 | + attr_names = {opt for opt in dir(options) if not opt.startswith('_') and opt not in block_list} |
| 267 | + for a in list(attr_names): |
| 268 | + # Flatbufffer list properties have 3 functions: *Length, *IsNone, and *AsNumpy |
| 269 | + if a + 'Length' in attr_names: |
| 270 | + attr_names.remove(a + 'Length') |
| 271 | + attr_names.remove(a + 'IsNone') |
| 272 | + attr_names.remove(a) |
| 273 | + for a in attr_names: |
| 274 | + if a.endswith('AsNumpy'): |
| 275 | + value = getattr(options, a)().tolist() |
| 276 | + a = a[:-len('AsNumpy')] |
| 277 | + else: |
| 278 | + # For enums we use a string with the value name, not enum index |
| 279 | + value = getattr(options, a)() |
| 280 | + if a in NODE_ATTR_NAME_TO_ENUM_TYPE: |
| 281 | + value = lookup_enum(value, NODE_ATTR_NAME_TO_ENUM_TYPE[a]) |
| 282 | + elif a in FUNCTION_ATTRS: |
| 283 | + value = model.Subgraphs(value).Name().decode() |
| 284 | + attr_cnt[a] += 1 |
| 285 | + attr[proper_to_snake_case(a)] = value |
| 286 | + input_names = [tensor_names[op.Inputs(i)] for i in range(op.InputsLength()) if op.Inputs(i) != -1] |
| 287 | + if wants_dequantized_input: |
| 288 | + input_names = [get_dequant(inp) for inp in input_names] |
| 289 | + output_names = [tensor_names[op.Outputs(i)] for i in range(op.OutputsLength()) if op.Outputs(i) != -1] |
| 290 | + if has_prequantized_output: |
| 291 | + output_names = [get_prequant(out) for out in output_names] |
| 292 | + onnx_node = helper.make_node("TFL_" + optype, input_names, output_names, name=output_names[0], **attr) |
| 293 | + onnx_nodes.append(onnx_node) |
| 294 | + |
| 295 | + inputs = [tensor_names[tflite_g.Inputs(i)] for i in range(tflite_g.InputsLength())] |
| 296 | + outputs = [tensor_names[tflite_g.Outputs(i)] for i in range(tflite_g.OutputsLength())] |
| 297 | + # TODO: Allow input/outputs to be overridden |
| 298 | + |
| 299 | + for inp in inputs: |
| 300 | + onnx_node = helper.make_node("Placeholder", [], outputs=[inp], name=inp) |
| 301 | + onnx_nodes.append(onnx_node) |
| 302 | + |
| 303 | + graph_name = (tflite_g.Name() or b'tflite graph').decode() |
| 304 | + return onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, inputs, outputs, graph_name |
0 commit comments