Skip to content

Commit 7ec5f2c

Browse files
committed
add tf_optimize to tf2onnx since apps are using it
1 parent 4767711 commit 7ec5f2c

File tree

5 files changed

+81
-69
lines changed

5 files changed

+81
-69
lines changed

tests/backend_test_base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,8 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
146146
tf_reset_default_graph()
147147
with tf_session() as sess:
148148
tf.import_graph_def(graph_def, name='')
149-
input_tensors = {i: sess.graph.get_tensor_by_name(i) for i in list(feed_dict.keys())}
150-
output_tensors = {i: sess.graph.get_tensor_by_name(i) for i in output_names_with_port}
151-
graph_def = tf_optimize(input_tensors, output_tensors, graph_def, fold_constant=constant_fold)
149+
graph_def = tf_optimize(list(feed_dict.keys()), output_names_with_port,
150+
graph_def, fold_constant=constant_fold)
152151

153152
tf_reset_default_graph()
154153
with tf_session() as sess:

tests/test_tf_shape_inference.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,7 @@ def _run_test_case(self, input_names_with_port, output_names_with_port):
4545
tf.import_graph_def(graph_def, name='')
4646

4747
# optimize graph
48-
input_tensors = {i: sess.graph.get_tensor_by_name(i) for i in input_names_with_port}
49-
output_tensors = {i: sess.graph.get_tensor_by_name(i) for i in output_names_with_port}
50-
graph_def = tf_optimize(input_tensors, output_tensors, sess.graph_def, True)
48+
graph_def = tf_optimize(input_names_with_port, output_names_with_port, sess.graph_def, True)
5149

5250
with tf_session() as sess:
5351
if self.config.is_debug_mode:

tf2onnx/convert.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ def main():
135135
model_path = args.keras
136136

137137
if args.verbose:
138-
logger.info("inputs: %s", inputs.keys())
139-
logger.info("outputs: %s", outputs.keys())
138+
logger.info("inputs: %s", inputs)
139+
logger.info("outputs: %s", outputs)
140140

141141
with tf.Graph().as_default() as tf_graph:
142142
tf.import_graph_def(graph_def, name='')
@@ -148,8 +148,8 @@ def main():
148148
custom_op_handlers=custom_ops,
149149
extra_opset=extra_opset,
150150
shape_override=args.shape_override,
151-
input_names=list(inputs.keys()),
152-
output_names=list(outputs.keys()),
151+
input_names=inputs,
152+
output_names=outputs,
153153
inputs_as_nchw=args.inputs_as_nchw)
154154

155155
onnx_graph = optimizer.optimize_graph(g)

tf2onnx/tf_loader.py

Lines changed: 68 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -83,16 +83,28 @@ def not_implemented_tf_placeholder(*args, **kwargs):
8383
extract_sub_graph = tf.graph_util.extract_sub_graph
8484

8585

86+
def inputs_without_resource(sess, input_names):
87+
try:
88+
new_input_names = []
89+
for n in input_names:
90+
t = sess.graph.get_tensor_by_name(n)
91+
if t.dtype != tf.dtypes.resource:
92+
new_input_names.append(n)
93+
input_names = new_input_names
94+
except: # pylint: disable=bare-except
95+
pass
96+
return input_names
97+
98+
8699
def from_function(func, input_names, output_names):
87100
frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False)
88101
graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
89-
# output_tensors = {i.name: i for i in frozen_func.outputs}
102+
# output_names = [i.name for i in frozen_func.outputs]
90103
tf_reset_default_graph()
91104
with tf_session() as sess:
92105
tf.import_graph_def(graph_def, name='')
93-
input_tensors = {i: sess.graph.get_tensor_by_name(i) for i in input_names}
94-
output_tensors = {i: sess.graph.get_tensor_by_name(i) for i in output_names}
95-
graph_def = tf_optimize(input_tensors, output_tensors, graph_def)
106+
input_names = inputs_without_resource(sess, input_names)
107+
graph_def = tf_optimize(input_names, output_names, graph_def)
96108
return graph_def
97109

98110

@@ -101,7 +113,6 @@ def freeze_session(sess, input_names=None, output_names=None):
101113
output_node_names = [i.split(':')[:-1][0] for i in output_names]
102114
keep_var_names = [i.split(':')[:-1][0] for i in input_names]
103115
with sess.graph.as_default():
104-
# freeze_var_names = list(set(v.op.name for v in tf_global_variables()).difference(keep_var_names or []))
105116
output_node_names = output_node_names or []
106117
output_node_names += [v.op.name for v in tf_global_variables()]
107118
output_node_names += keep_var_names
@@ -135,16 +146,16 @@ def from_graphdef(model_path, input_names, output_names):
135146
with tf_gfile.GFile(model_path, 'rb') as f:
136147
graph_def.ParseFromString(f.read())
137148
tf.import_graph_def(graph_def, name='')
149+
input_names = inputs_without_resource(sess, input_names)
138150
frozen_graph = freeze_session(sess, input_names=input_names, output_names=output_names)
139151
input_names = remove_redundant_inputs(frozen_graph, input_names)
140-
inputs = {i: sess.graph.get_tensor_by_name(i) for i in input_names}
141-
outputs = {i: sess.graph.get_tensor_by_name(i) for i in output_names}
142152

143153
tf_reset_default_graph()
144154
with tf_session() as sess:
145-
frozen_graph = tf_optimize(inputs, outputs, frozen_graph)
155+
input_names = inputs_without_resource(sess, input_names)
156+
frozen_graph = tf_optimize(input_names, output_names, frozen_graph)
146157
tf_reset_default_graph()
147-
return frozen_graph, inputs, outputs
158+
return frozen_graph, input_names, output_names
148159

149160

150161
def from_checkpoint(model_path, input_names, output_names):
@@ -156,16 +167,15 @@ def from_checkpoint(model_path, input_names, output_names):
156167
saver = tf_import_meta_graph(model_path, clear_devices=True)
157168
# restore from model_path minus the ".meta"
158169
saver.restore(sess, model_path[:-5])
170+
input_names = inputs_without_resource(sess, input_names)
159171
frozen_graph = freeze_session(sess, input_names=input_names, output_names=output_names)
160172
input_names = remove_redundant_inputs(frozen_graph, input_names)
161-
inputs = {i: sess.graph.get_tensor_by_name(i) for i in input_names}
162-
outputs = {i: sess.graph.get_tensor_by_name(i) for i in output_names}
163173

164174
tf_reset_default_graph()
165175
with tf_session() as sess:
166-
frozen_graph = tf_optimize(inputs, outputs, frozen_graph)
176+
frozen_graph = tf_optimize(input_names, output_names, frozen_graph)
167177
tf_reset_default_graph()
168-
return frozen_graph, inputs, outputs
178+
return frozen_graph, input_names, output_names
169179

170180

171181
def _from_saved_model_v1(sess, model_path, input_names, output_names, signatures):
@@ -197,16 +207,15 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, signatures
197207
for _, output_tensor in outputs_tensor_info.items():
198208
outputs[output_tensor.name] = sess.graph.get_tensor_by_name(output_tensor.name)
199209

200-
frozen_graph = freeze_session(sess, input_names=list(inputs.keys()), output_names=list(outputs.keys()))
201-
return frozen_graph, inputs, outputs
210+
input_names = list(inputs.keys())
211+
input_names = inputs_without_resource(sess, input_names)
212+
output_names = list(outputs.keys())
213+
frozen_graph = freeze_session(sess, input_names=input_names, output_names=output_names)
214+
return frozen_graph, input_names, output_names
202215

203216

204217
def _from_saved_model_v2(model_path, input_names, output_names, signatures):
205218
"""Load tensorflow graph from saved_model."""
206-
# make sure we start with clean default graph
207-
inputs = {}
208-
outputs = {}
209-
210219
imported = tf.saved_model.load(model_path) # pylint: disable=no-value-for-parameter
211220

212221
# f = meta_graph_def.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
@@ -217,34 +226,33 @@ def _from_saved_model_v2(model_path, input_names, output_names, signatures):
217226
signatures.append(k)
218227
for k in signatures:
219228
concrete_func = imported.signatures[k]
220-
inputs = {input_tensor.name: input_tensor for input_tensor in concrete_func.inputs}
221-
outputs = {output_tensor.name: output_tensor for output_tensor in concrete_func.outputs}
229+
input_names = [input_tensor.name for input_tensor in concrete_func.inputs]
230+
output_names = [output_tensor.name for output_tensor in concrete_func.outputs]
222231

223-
frozen_graph = from_function(concrete_func, list(inputs.keys()), list(outputs.keys()))
224-
return frozen_graph, inputs, outputs
232+
input_names = inputs_without_resource(sess, input_names)
233+
frozen_graph = from_function(concrete_func, input_names, output_names)
234+
return frozen_graph, input_names, output_names
225235

226236

227237
def from_saved_model(model_path, input_names, output_names, signatures=None):
228238
"""Load tensorflow graph from saved_model."""
229-
# make sure we start with clean default graph
230239
if signatures is None:
231240
signatures = []
232241
tf_reset_default_graph()
233242
if is_tf2():
234-
frozen_graph, inputs, outputs = \
243+
frozen_graph, input_names, output_names = \
235244
_from_saved_model_v2(model_path, input_names, output_names, signatures)
236-
inputs = {k: v for k, v in inputs.items() if v.dtype != tf.dtypes.resource}
237245
else:
238246
with tf_session() as sess:
239-
frozen_graph, inputs, outputs = \
247+
frozen_graph, input_names, output_names = \
240248
_from_saved_model_v1(sess, model_path, input_names, output_names, signatures)
241249

242250
if len(signatures) > 1:
243251
logger.warning("found multiple signatures %s in saved_model, pass --signature_def in command line",
244252
signatures)
245253

246254
tf_reset_default_graph()
247-
return frozen_graph, inputs, outputs
255+
return frozen_graph, input_names, output_names
248256

249257

250258
def from_keras(model_path, input_names, output_names):
@@ -261,34 +269,39 @@ def from_keras(model_path, input_names, output_names):
261269

262270
function = _saving_utils.trace_model_call(keras_model)
263271
concrete_func = function.get_concrete_function()
264-
inputs = {input_tensor.name: input_tensor for input_tensor in concrete_func.inputs if
265-
input_tensor.name in input_names}
266-
outputs = {output_tensor.name: output_tensor for output_tensor in concrete_func.outputs if
267-
output_tensor.name in output_names}
268-
frozen_graph = from_function(concrete_func, list(inputs.keys()), list(outputs.keys()))
272+
# allow to pass inputs and outputs from caller if we don't want all of them
273+
input_names = [input_tensor.name for input_tensor in concrete_func.inputs
274+
if input_names is None or input_tensor.name in input_names]
275+
output_names = [output_tensor.name for output_tensor in concrete_func.outputs
276+
if output_names is None or output_tensor.name in output_names]
277+
frozen_graph = from_function(concrete_func, input_names, output_names)
269278
else:
270279
# Handles Keras when Eager mode is disabled.
271280
_keras.backend.clear_session()
272281
_keras.backend.set_learning_phase(False)
273282
keras_model = _keras.models.load_model(model_path, custom_objects)
283+
# allow to pass inputs and outputs from caller if we don't want all of them
284+
if input_names:
285+
input_names = [i for i in keras_model.inputs if i in input_names]
286+
else:
287+
input_names = keras_model.inputs
288+
if output_names:
289+
output_names = [i for i in keras_model.outputs if i in output_names]
290+
else:
291+
output_names = keras_model.outputs
274292
sess = _keras.backend.get_session()
275293
frozen_graph = freeze_session(sess, input_names=input_names, output_names=output_names)
276-
inputs = {i: sess.graph.get_tensor_by_name(i) for i in keras_model.inputs}
277-
outputs = {i: sess.graph.get_tensor_by_name(i) for i in keras_model.outputs}
278294
tf_reset_default_graph()
279295
with tf_session() as sess:
280-
frozen_graph = tf_optimize(inputs, outputs, frozen_graph)
296+
frozen_graph = tf_optimize(input_names, output_names, frozen_graph)
281297
tf_reset_default_graph()
282-
return frozen_graph, inputs, outputs
298+
return frozen_graph, input_names, output_names
283299

284300

285-
def tf_optimize_grappler(input_tensors, output_tensors, graph_def, fold_constant=None):
301+
def tf_optimize_grappler(input_names, output_names, graph_def, fold_constant=None):
286302
from tensorflow.core.protobuf import meta_graph_pb2 as meta_graph_pb2, config_pb2, rewriter_config_pb2
287303
from tensorflow.python.grappler import tf_optimizer as tf_opt
288304

289-
# don't use resource type as input
290-
output_tensors = list(output_tensors)
291-
292305
config = config_pb2.ConfigProto()
293306
rewrite_options = config.graph_options.rewrite_options
294307
config.graph_options.infer_shapes = True
@@ -300,34 +313,27 @@ def tf_optimize_grappler(input_tensors, output_tensors, graph_def, fold_constant
300313
]
301314
meta_graph = tf.compat.v1.train.export_meta_graph(graph_def=graph_def)
302315
fetch_collection = meta_graph_pb2.CollectionDef()
303-
for t in list(input_tensors) + output_tensors:
316+
for t in input_names + output_names:
304317
fetch_collection.node_list.value.append(t)
305318
meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
306319
graph_def = tf_opt.OptimizeGraph(config, meta_graph)
307320
return graph_def
308321

309322

310-
def tf_optimize(input_tensors, output_tensors, graph_def, fold_constant=True):
323+
def tf_optimize(input_names, output_names, graph_def, fold_constant=True):
311324
"""Extract inference subgraph and optimize graph."""
312-
assert isinstance(input_tensors, dict)
313-
assert isinstance(output_tensors, dict)
314-
try:
315-
input_tensors = {
316-
name: tensor for name, tensor in input_tensors.items()
317-
if tensor.dtype != tf.dtypes.resource
318-
}
319-
except: # pylint: disable=bare-except
320-
pass
325+
assert isinstance(input_names, list)
326+
assert isinstance(output_names, list)
321327

322328
# TODO: is this needed ?
323-
needed_names = [utils.node_name(i) for i in input_tensors.keys()] + \
324-
[utils.node_name(i) for i in output_tensors.keys()]
329+
needed_names = [utils.node_name(i) for i in input_names] + \
330+
[utils.node_name(i) for i in output_names]
325331
graph_def = extract_sub_graph(graph_def, needed_names)
326332

327333
if fold_constant:
328334
want_grappler = is_tf2() or LooseVersion(tf.__version__) >= "1.15"
329335
if want_grappler:
330-
graph_def = tf_optimize_grappler(input_tensors, output_tensors, graph_def, fold_constant)
336+
graph_def = tf_optimize_grappler(input_names, output_names, graph_def, fold_constant)
331337
else:
332338
# the older transform path
333339
from tensorflow.tools.graph_transforms import TransformGraph # pylint: disable=redefined-outer-name
@@ -341,7 +347,7 @@ def tf_optimize(input_tensors, output_tensors, graph_def, fold_constant=True):
341347
"fold_batch_norms",
342348
"fold_old_batch_norms",
343349
])
344-
graph_def = TransformGraph(graph_def, input_tensors.keys(), output_tensors.keys(), transforms)
350+
graph_def = TransformGraph(graph_def, input_names, output_names, transforms)
345351

346352
return graph_def
347353

@@ -367,7 +373,6 @@ def is_function(g):
367373
return 'tensorflow.python.framework.func_graph.FuncGraph' in str(type(g))
368374
return False
369375

370-
371376
_FUNCTIONS = {}
372377

373378

@@ -387,7 +392,12 @@ def toposort(data):
387392
fdef = fdef.definition
388393
if input_shapes and len(fdef.signature.input_arg) < len(input_shapes):
389394
input_shapes = input_shapes[:len(fdef.signature.input_arg)]
390-
func = function_def_to_graph(fdef, input_shapes=input_shapes)
395+
try:
396+
func = function_def_to_graph(fdef, input_shapes=input_shapes)
397+
except: # pylint: disable=bare-except
398+
# if there is a mismatch between caller and function use the function's shape
399+
logger.warning("shape missmatch between caller and function: %s", k)
400+
func = function_def_to_graph(fdef)
391401
_FUNCTIONS[k] = func
392402
_, _, _, _, _, tfunctions = tflist_to_onnx(func, {})
393403
functions.update(tfunctions)

tf2onnx/tfonnx.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,6 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
361361
Return:
362362
onnx graph
363363
"""
364-
# TODO: remove verbose argument in future release
365364
if verbose:
366365
logger.warning("Argument verbose for process_tf_graph is deprecated. Please use --verbose option instead.")
367366
del verbose
@@ -507,3 +506,9 @@ def compat_handler(ctx, node, **kwargs):
507506
"\tonnx unmapped: {}".format(op_cnt, attr_cnt, mapped_op, unmapped_op))
508507

509508
return g
509+
510+
511+
def tf_optimize(input_names, output_names, graph_def, fold_constant=True):
512+
"""optimize tensorflow graph. This is in tf_loader but some apps call this
513+
so we proxy into tf_loader to keep them working."""
514+
return tf2onnx.tf_loader.tf_optimize(input_names, output_names, graph_def, fold_constant)

0 commit comments

Comments
 (0)