@@ -310,7 +310,7 @@ def from_checkpoint(model_path, input_names, output_names):
310
310
return frozen_graph , input_names , output_names
311
311
312
312
313
- def _from_saved_model_v1 (sess , model_path , input_names , output_names , tag , signature_names ):
313
+ def _from_saved_model_v1 (sess , model_path , input_names , output_names , tag , signature_names , use_graph_names ):
314
314
"""Load tensorflow graph from saved_model."""
315
315
316
316
wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
@@ -345,22 +345,25 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
345
345
# TF1.12 changed the api
346
346
get_signature_def = lambda meta_graph_def , k : meta_graph_def .signature_def [k ]
347
347
348
+ tensors_to_rename = {}
348
349
if input_names is None :
349
350
input_names = []
350
351
for k in signatures :
351
352
inputs_tensor_info = get_signature_def (imported , k ).inputs
352
- for _ , input_tensor in inputs_tensor_info .items ():
353
+ for structured_name , input_tensor in inputs_tensor_info .items ():
353
354
if input_tensor .name not in input_names :
354
355
input_names .append (input_tensor .name )
355
- tensors_to_rename = {}
356
+ if not use_graph_names :
357
+ tensors_to_rename [input_tensor .name ] = structured_name
356
358
if output_names is None :
357
359
output_names = []
358
360
for k in signatures :
359
361
outputs_tensor_info = get_signature_def (imported , k ).outputs
360
362
for structured_name , output_tensor in outputs_tensor_info .items ():
361
363
if output_tensor .name not in output_names :
362
364
output_names .append (output_tensor .name )
363
- tensors_to_rename [output_tensor .name ] = structured_name
365
+ if not use_graph_names :
366
+ tensors_to_rename [output_tensor .name ] = structured_name
364
367
frozen_graph , initialized_tables = \
365
368
freeze_session (sess , input_names = input_names , output_names = output_names , get_tables = True )
366
369
return frozen_graph , input_names , output_names , initialized_tables , tensors_to_rename
@@ -447,7 +450,7 @@ def _restore_captured_resources(concrete_func, graph_captures_copy, func_capture
447
450
448
451
449
452
def _from_saved_model_v2 (model_path , input_names , output_names , tag , signature_def ,
450
- concrete_function_index , large_model ):
453
+ concrete_function_index , large_model , use_graph_names ):
451
454
"""Load tensorflow graph from saved_model."""
452
455
453
456
wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
@@ -495,18 +498,16 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
495
498
graph_captures = concrete_func .graph ._captures # pylint: disable=protected-access
496
499
captured_inputs = [t_name .name for _ , t_name in graph_captures .values ()]
497
500
inputs = [inp for inp in inputs if inp not in captured_inputs ]
498
- if concrete_func .structured_input_signature is not None :
499
- args , kwargs = concrete_func .structured_input_signature
500
- structured_inputs = [t .name for t in args if isinstance (t , tf .TensorSpec )] + sorted (kwargs .keys ())
501
- structured_inputs = set (inp + ":0" for inp in structured_inputs )
502
- if any (inp in structured_inputs for inp in inputs ):
503
- inputs = [inp for inp in inputs if inp in structured_inputs ]
501
+ if concrete_func .structured_input_signature is not None and not use_graph_names :
502
+ flat_structured_inp = tf .nest .flatten (concrete_func .structured_input_signature )
503
+ structured_inputs = [t .name for t in flat_structured_inp if isinstance (t , tf .TensorSpec )]
504
+ tensors_to_rename .update (zip (inputs , structured_inputs ))
504
505
else :
505
506
inputs = input_names
506
507
507
508
if output_names is None :
508
509
outputs = [tensor .name for tensor in concrete_func .outputs if tensor .dtype != tf .dtypes .resource ]
509
- if isinstance (concrete_func .structured_outputs , dict ):
510
+ if isinstance (concrete_func .structured_outputs , dict ) and not use_graph_names :
510
511
# outputs are sorted, sort structured_outputs the same way
511
512
structured_outputs = sorted (concrete_func .structured_outputs .keys ())
512
513
tensors_to_rename .update (zip (outputs , structured_outputs ))
@@ -515,7 +516,6 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
515
516
logger .info ("Output names: %r" , outputs )
516
517
else :
517
518
outputs = output_names
518
- logger .info ("Outputs not left as None; will use provided names not structured output names." )
519
519
520
520
frozen_graph , initialized_tables = from_trackable (imported , concrete_func , inputs , outputs , large_model )
521
521
@@ -524,7 +524,8 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
524
524
525
525
def from_saved_model (model_path , input_names , output_names , tag = None ,
526
526
signatures = None , concrete_function = None , large_model = False ,
527
- return_concrete_func = False , return_initialized_tables = False , return_tensors_to_rename = False ):
527
+ return_concrete_func = False , return_initialized_tables = False ,
528
+ return_tensors_to_rename = False , use_graph_names = False ):
528
529
"""Load tensorflow graph from saved_model."""
529
530
if signatures is None :
530
531
signatures = []
@@ -533,7 +534,7 @@ def from_saved_model(model_path, input_names, output_names, tag=None,
533
534
if is_tf2 ():
534
535
frozen_graph , input_names , output_names , concrete_func , imported , initialized_tables , tensors_to_rename = \
535
536
_from_saved_model_v2 (model_path , input_names , output_names ,
536
- tag , signatures , concrete_function , large_model )
537
+ tag , signatures , concrete_function , large_model , use_graph_names )
537
538
result = [frozen_graph , input_names , output_names ]
538
539
if return_concrete_func :
539
540
result += [concrete_func , imported ]
@@ -544,7 +545,7 @@ def from_saved_model(model_path, input_names, output_names, tag=None,
544
545
else :
545
546
with tf_session () as sess :
546
547
frozen_graph , input_names , output_names , initialized_tables , tensors_to_rename = \
547
- _from_saved_model_v1 (sess , model_path , input_names , output_names , tag , signatures )
548
+ _from_saved_model_v1 (sess , model_path , input_names , output_names , tag , signatures , use_graph_names )
548
549
result = [frozen_graph , input_names , output_names ]
549
550
if return_initialized_tables :
550
551
result += [initialized_tables ]
0 commit comments