Skip to content

Commit 4fc6620

Browse files
Improve keras output order detection (#1489)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 591a99b commit 4fc6620

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tf2onnx/convert.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,11 +334,14 @@ def wrap_call(*args, training=False, **kwargs):
334334

335335
initialized_tables = None
336336
tensors_to_rename = tensor_names_from_structed(concrete_func, input_names, output_names)
337+
reverse_lookup = {v: k for k, v in tensors_to_rename.items()}
337338

338339
if model.output_names:
339340
# model.output_names is an optional field of Keras models indicating output order. It is None if unused.
340-
reverse_lookup = {v: k for k, v in tensors_to_rename.items()}
341341
output_names = [reverse_lookup[out] for out in model.output_names]
342+
elif isinstance(concrete_func.structured_outputs, dict):
343+
# Other models specify output order using the key order of structured_outputs
344+
output_names = [reverse_lookup[out] for out in concrete_func.structured_outputs.keys()]
342345

343346
with tf.device("/cpu:0"):
344347
frozen_graph = tf_loader.from_function(concrete_func, input_names, output_names, large_model=large_model)

0 commit comments

Comments
 (0)