Skip to content

Commit 43d21a7

Browse files
Fixed bug in shapes for tfltie postprocess (#1312)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent d480f11 commit 43d21a7

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

tf2onnx/tflite_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,11 @@ def get_prequant(tensor_name):
294294
if wants_dequantized_input:
295295
input_names = [get_dequant(inp) for inp in input_names]
296296
output_names = [tensor_names[op.Outputs(i)] for i in range(op.OutputsLength()) if op.Outputs(i) != -1]
297+
if optype == "TFLite_Detection_PostProcess":
298+
# There's a bug in tflite for the output shapes of this op
299+
for out, shape in zip(output_names, [[-1, -1, 4], [-1, -1], [-1, -1], [-1]]):
300+
if len(output_shapes[out]) != len(shape):
301+
output_shapes[out] = shape
297302
if has_prequantized_output:
298303
output_names = [get_prequant(out) for out in output_names]
299304
onnx_node = helper.make_node("TFL_" + optype, input_names, output_names, name=output_names[0], **attr)

0 commit comments

Comments
 (0)