@@ -291,19 +291,65 @@ def tensor_names_from_structed(concrete_func, input_names, output_names):
291
291
return tensors_to_rename
292
292
293
293
294
+ def _rename_duplicate_keras_model_names (model ):
295
+ """
296
+ In very rare cases, keras has a bug where it will give multiple outputs the same name.
297
+ We must edit the model or the TF trace will fail. Returns old_out_names (or None if no edit was made).
298
+ IMPORTANT: model may be edited. Assign model.output_names to old_out_names to restore.
299
+ """
300
+ old_out_names = None
301
+ if model .output_names and len (set (model .output_names )) != len (model .output_names ):
302
+ # In very rare cases, keras has a bug where it will give multiple outputs the same name
303
+ # We must edit the model or the TF trace will fail
304
+ old_out_names = model .output_names
305
+ used_names = set ()
306
+ new_out_names = []
307
+ for name in model .output_names :
308
+ new_name = name
309
+ i = 0
310
+ while new_name in used_names :
311
+ i += 1
312
+ new_name = name + "_" + str (i )
313
+ used_names .add (new_name )
314
+ new_out_names .append (new_name )
315
+ model .output_names = new_out_names
316
+ return old_out_names
317
+
318
+
319
+ def _is_legacy_keras_model (model ):
320
+ """Inspects model class to determine if it is from tf or legacy keras"""
321
+
322
+ logger = logging .getLogger (constants .TF2ONNX_PACKAGE_NAME )
323
+ unknown_type_err = "model is not instance of tf.keras.Model or keras.Model"
324
+ if isinstance (model , tf .keras .Model ):
325
+ return False
326
+ try :
327
+ import keras # pylint: disable=import-outside-toplevel
328
+ if isinstance (model , keras .Model ):
329
+ return True
330
+ logger .warning (unknown_type_err )
331
+ except ImportError :
332
+ logger .warning (unknown_type_err )
333
+ return False
334
+
335
+
294
336
def _from_keras_tf1 (model , input_signature = None , opset = None , custom_ops = None , custom_op_handlers = None ,
295
337
custom_rewriter = None , inputs_as_nchw = None , extra_opset = None , shape_override = None ,
296
338
target = None , large_model = False , output_path = None ):
297
339
"""from_keras for tf 1.15"""
298
-
299
340
input_names = [t .name for t in model .inputs ]
300
341
output_names = [t .name for t in model .outputs ]
342
+ old_out_names = _rename_duplicate_keras_model_names (model )
301
343
tensors_to_rename = dict (zip (input_names , model .input_names ))
302
- if len ( set ( model . output_names )) == len ( model .output_names ):
303
- # In very rare cases, keras has a bug where it will give multiple outputs the same name
304
- tensors_to_rename . update ( zip ( output_names , model .output_names ))
344
+ tensors_to_rename . update ( zip ( output_names , model .output_names ))
345
+ if old_out_names is not None :
346
+ model .output_names = old_out_names
305
347
306
- sess = tf .keras .backend .get_session (model .outputs )
348
+ if _is_legacy_keras_model (model ):
349
+ import keras # pylint: disable=import-outside-toplevel
350
+ sess = keras .backend .get_session ()
351
+ else :
352
+ sess = tf .keras .backend .get_session (model .outputs )
307
353
308
354
with tf .device ("/cpu:0" ):
309
355
frozen_graph , initialized_tables = tf_loader .freeze_session (sess , input_names , output_names , get_tables = True )
@@ -351,6 +397,7 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_
351
397
Returns:
352
398
An ONNX model_proto and an external_tensor_storage dict.
353
399
"""
400
+ old_out_names = _rename_duplicate_keras_model_names (model )
354
401
if LooseVersion (tf .__version__ ) < "2.0" :
355
402
return _from_keras_tf1 (model , input_signature , opset , custom_ops , custom_op_handlers , custom_rewriter ,
356
403
inputs_as_nchw , extra_opset , shape_override , target , large_model , output_path )
@@ -370,9 +417,21 @@ def wrap_call(*args, training=False, **kwargs):
370
417
return model_call (* args , ** kwargs )
371
418
model .call = wrap_call
372
419
function = _saving_utils .trace_model_call (model , input_signature )
373
- concrete_func = function .get_concrete_function ()
374
- # Put it back
375
- model .call = model_call
420
+ try :
421
+ # Legacy keras get make TF erroneously enter eager mode when it should be making symbolic tensors
422
+ import tensorflow_core # pylint: disable=import-outside-toplevel
423
+ old_get_learning_phase = tensorflow_core .python .keras .backend .learning_phase
424
+ tensorflow_core .python .keras .backend .learning_phase = \
425
+ tensorflow_core .python .keras .backend .symbolic_learning_phase
426
+ except ImportError :
427
+ old_get_learning_phase = None
428
+ try :
429
+ concrete_func = function .get_concrete_function ()
430
+ finally :
431
+ # Put everything back
432
+ model .call = model_call
433
+ if old_get_learning_phase is not None :
434
+ tensorflow_core .python .keras .backend .learning_phase = old_get_learning_phase
376
435
377
436
# These inputs will be removed during freezing (includes resources, etc.)
378
437
graph_captures = concrete_func .graph ._captures # pylint: disable=protected-access
@@ -392,6 +451,9 @@ def wrap_call(*args, training=False, **kwargs):
392
451
# Other models specify output order using the key order of structured_outputs
393
452
output_names = [reverse_lookup [out ] for out in concrete_func .structured_outputs .keys ()]
394
453
454
+ if old_out_names is not None :
455
+ model .output_names = old_out_names
456
+
395
457
with tf .device ("/cpu:0" ):
396
458
frozen_graph , initialized_tables = \
397
459
tf_loader .from_trackable (model , concrete_func , input_names , output_names , large_model )
0 commit comments