@@ -214,7 +214,7 @@ def from_function(func, input_names, output_names, large_model=False):
214
214
return graph_def
215
215
216
216
217
- def freeze_session (sess , input_names = None , output_names = None ):
217
+ def freeze_session (sess , input_names = None , output_names = None , get_tables = False ):
218
218
"""Freezes the state of a session into a pruned computation graph."""
219
219
output_node_names = [i .split (':' )[:- 1 ][0 ] for i in output_names ]
220
220
keep_var_names = [i .split (':' )[:- 1 ][0 ] for i in input_names ]
@@ -226,6 +226,19 @@ def freeze_session(sess, input_names=None, output_names=None):
226
226
for node in graph_def .node :
227
227
node .device = ""
228
228
graph_def = convert_variables_to_constants (sess , graph_def , output_node_names )
229
+ table_names , key_dtypes , value_dtypes = get_hash_table_info (graph_def )
230
+ if get_tables :
231
+ initialized_tables = {}
232
+ tf .tables_initializer ().run (session = sess )
233
+ for n , k_dtype , val_dtype in zip (table_names , key_dtypes , value_dtypes ):
234
+ h = lookup_ops .hash_table_v2 (k_dtype , val_dtype , shared_name = n )
235
+ try :
236
+ k , v = lookup_ops .lookup_table_export_v2 (h , k_dtype , val_dtype )
237
+ k , v = sess .run ([k , v ])
238
+ initialized_tables [n ] = (k , v )
239
+ except Exception : # pylint: disable=broad-except
240
+ logger .warning ("Could not initialize table with shared_name = %r" , n )
241
+ return graph_def , initialized_tables
229
242
return graph_def
230
243
231
244
@@ -348,18 +361,8 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
348
361
if output_tensor .name not in output_names :
349
362
output_names .append (output_tensor .name )
350
363
tensors_to_rename [output_tensor .name ] = structured_name
351
- frozen_graph = freeze_session (sess , input_names = input_names , output_names = output_names )
352
- table_names , key_dtypes , value_dtypes = get_hash_table_info (frozen_graph )
353
- initialized_tables = {}
354
- tf .tables_initializer ().run ()
355
- for n , k_dtype , val_dtype in zip (table_names , key_dtypes , value_dtypes ):
356
- h = lookup_ops .hash_table_v2 (k_dtype , val_dtype , shared_name = n )
357
- try :
358
- k , v = lookup_ops .lookup_table_export_v2 (h , k_dtype , val_dtype )
359
- k , v = sess .run ([k , v ])
360
- initialized_tables [n ] = (k , v )
361
- except Exception : # pylint: disable=broad-except
362
- logger .warning ("Could not initialize table with shared_name = %r" , n )
364
+ frozen_graph , initialized_tables = \
365
+ freeze_session (sess , input_names = input_names , output_names = output_names , get_tables = True )
363
366
return frozen_graph , input_names , output_names , initialized_tables , tensors_to_rename
364
367
365
368
0 commit comments