Skip to content

Commit 994a2fb

Browse files
authored
Merge pull request #882 from onnx/gs/api
add tf_optimize back to tf2onnx since apps are using it
2 parents 712518f + e318583 commit 994a2fb

File tree

5 files changed

+79
-75
lines changed

5 files changed

+79
-75
lines changed

tests/backend_test_base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,8 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
146146
tf_reset_default_graph()
147147
with tf_session() as sess:
148148
tf.import_graph_def(graph_def, name='')
149-
input_tensors = {i: sess.graph.get_tensor_by_name(i) for i in list(feed_dict.keys())}
150-
output_tensors = {i: sess.graph.get_tensor_by_name(i) for i in output_names_with_port}
151-
graph_def = tf_optimize(input_tensors, output_tensors, graph_def, fold_constant=constant_fold)
149+
graph_def = tf_optimize(list(feed_dict.keys()), output_names_with_port,
150+
graph_def, fold_constant=constant_fold)
152151

153152
tf_reset_default_graph()
154153
with tf_session() as sess:

tests/test_tf_shape_inference.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,7 @@ def _run_test_case(self, input_names_with_port, output_names_with_port):
4545
tf.import_graph_def(graph_def, name='')
4646

4747
# optimize graph
48-
input_tensors = {i: sess.graph.get_tensor_by_name(i) for i in input_names_with_port}
49-
output_tensors = {i: sess.graph.get_tensor_by_name(i) for i in output_names_with_port}
50-
graph_def = tf_optimize(input_tensors, output_tensors, sess.graph_def, True)
48+
graph_def = tf_optimize(input_names_with_port, output_names_with_port, sess.graph_def, True)
5149

5250
with tf_session() as sess:
5351
if self.config.is_debug_mode:

tf2onnx/convert.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ def main():
135135
model_path = args.keras
136136

137137
if args.verbose:
138-
logger.info("inputs: %s", inputs.keys())
139-
logger.info("outputs: %s", outputs.keys())
138+
logger.info("inputs: %s", inputs)
139+
logger.info("outputs: %s", outputs)
140140

141141
with tf.Graph().as_default() as tf_graph:
142142
tf.import_graph_def(graph_def, name='')
@@ -148,8 +148,8 @@ def main():
148148
custom_op_handlers=custom_ops,
149149
extra_opset=extra_opset,
150150
shape_override=args.shape_override,
151-
input_names=list(inputs.keys()),
152-
output_names=list(outputs.keys()),
151+
input_names=inputs,
152+
output_names=outputs,
153153
inputs_as_nchw=args.inputs_as_nchw)
154154

155155
onnx_graph = optimizer.optimize_graph(g)

tf2onnx/tf_loader.py

Lines changed: 66 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -83,16 +83,28 @@ def not_implemented_tf_placeholder(*args, **kwargs):
8383
extract_sub_graph = tf.graph_util.extract_sub_graph
8484

8585

86+
def inputs_without_resource(sess, input_names):
87+
try:
88+
new_input_names = []
89+
for n in input_names:
90+
t = sess.graph.get_tensor_by_name(n)
91+
if t.dtype != tf.dtypes.resource:
92+
new_input_names.append(n)
93+
input_names = new_input_names
94+
except: # pylint: disable=bare-except
95+
pass
96+
return input_names
97+
98+
8699
def from_function(func, input_names, output_names):
87100
frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False)
88101
graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
89-
# output_tensors = {i.name: i for i in frozen_func.outputs}
102+
# output_names = [i.name for i in frozen_func.outputs]
90103
tf_reset_default_graph()
91104
with tf_session() as sess:
92105
tf.import_graph_def(graph_def, name='')
93-
input_tensors = {i: sess.graph.get_tensor_by_name(i) for i in input_names}
94-
output_tensors = {i: sess.graph.get_tensor_by_name(i) for i in output_names}
95-
graph_def = tf_optimize(input_tensors, output_tensors, graph_def)
106+
input_names = inputs_without_resource(sess, input_names)
107+
graph_def = tf_optimize(input_names, output_names, graph_def)
96108
return graph_def
97109

98110

@@ -101,7 +113,6 @@ def freeze_session(sess, input_names=None, output_names=None):
101113
output_node_names = [i.split(':')[:-1][0] for i in output_names]
102114
keep_var_names = [i.split(':')[:-1][0] for i in input_names]
103115
with sess.graph.as_default():
104-
# freeze_var_names = list(set(v.op.name for v in tf_global_variables()).difference(keep_var_names or []))
105116
output_node_names = output_node_names or []
106117
output_node_names += [v.op.name for v in tf_global_variables()]
107118
output_node_names += keep_var_names
@@ -135,16 +146,16 @@ def from_graphdef(model_path, input_names, output_names):
135146
with tf_gfile.GFile(model_path, 'rb') as f:
136147
graph_def.ParseFromString(f.read())
137148
tf.import_graph_def(graph_def, name='')
149+
input_names = inputs_without_resource(sess, input_names)
138150
frozen_graph = freeze_session(sess, input_names=input_names, output_names=output_names)
139151
input_names = remove_redundant_inputs(frozen_graph, input_names)
140-
inputs = {i: sess.graph.get_tensor_by_name(i) for i in input_names}
141-
outputs = {i: sess.graph.get_tensor_by_name(i) for i in output_names}
142152

143153
tf_reset_default_graph()
144154
with tf_session() as sess:
145-
frozen_graph = tf_optimize(inputs, outputs, frozen_graph)
155+
input_names = inputs_without_resource(sess, input_names)
156+
frozen_graph = tf_optimize(input_names, output_names, frozen_graph)
146157
tf_reset_default_graph()
147-
return frozen_graph, inputs, outputs
158+
return frozen_graph, input_names, output_names
148159

149160

150161
def from_checkpoint(model_path, input_names, output_names):
@@ -156,23 +167,19 @@ def from_checkpoint(model_path, input_names, output_names):
156167
saver = tf_import_meta_graph(model_path, clear_devices=True)
157168
# restore from model_path minus the ".meta"
158169
saver.restore(sess, model_path[:-5])
170+
input_names = inputs_without_resource(sess, input_names)
159171
frozen_graph = freeze_session(sess, input_names=input_names, output_names=output_names)
160172
input_names = remove_redundant_inputs(frozen_graph, input_names)
161-
inputs = {i: sess.graph.get_tensor_by_name(i) for i in input_names}
162-
outputs = {i: sess.graph.get_tensor_by_name(i) for i in output_names}
163173

164174
tf_reset_default_graph()
165175
with tf_session() as sess:
166-
frozen_graph = tf_optimize(inputs, outputs, frozen_graph)
176+
frozen_graph = tf_optimize(input_names, output_names, frozen_graph)
167177
tf_reset_default_graph()
168-
return frozen_graph, inputs, outputs
178+
return frozen_graph, input_names, output_names
169179

170180

171181
def _from_saved_model_v1(sess, model_path, input_names, output_names, signatures):
172182
"""Load tensorflow graph from saved_model."""
173-
# make sure we start with clean default graph
174-
inputs = {}
175-
outputs = {}
176183

177184
imported = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_path)
178185
for k in imported.signature_def.keys():
@@ -189,24 +196,21 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, signatures
189196
# TF1.12 changed the api
190197
get_signature_def = lambda meta_graph_def, k: meta_graph_def.signature_def[k]
191198

199+
input_names = []
200+
output_names = []
192201
for k in signatures:
193202
inputs_tensor_info = get_signature_def(imported, k).inputs
194203
for _, input_tensor in inputs_tensor_info.items():
195-
inputs[input_tensor.name] = sess.graph.get_tensor_by_name(input_tensor.name)
204+
input_names.append(input_tensor.name)
196205
outputs_tensor_info = get_signature_def(imported, k).outputs
197206
for _, output_tensor in outputs_tensor_info.items():
198-
outputs[output_tensor.name] = sess.graph.get_tensor_by_name(output_tensor.name)
199-
200-
frozen_graph = freeze_session(sess, input_names=list(inputs.keys()), output_names=list(outputs.keys()))
201-
return frozen_graph, inputs, outputs
207+
output_names.append(output_tensor.name)
208+
frozen_graph = freeze_session(sess, input_names=input_names, output_names=output_names)
209+
return frozen_graph, input_names, output_names
202210

203211

204212
def _from_saved_model_v2(model_path, input_names, output_names, signatures):
205213
"""Load tensorflow graph from saved_model."""
206-
# make sure we start with clean default graph
207-
inputs = {}
208-
outputs = {}
209-
210214
imported = tf.saved_model.load(model_path) # pylint: disable=no-value-for-parameter
211215

212216
# f = meta_graph_def.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
@@ -217,34 +221,34 @@ def _from_saved_model_v2(model_path, input_names, output_names, signatures):
217221
signatures.append(k)
218222
for k in signatures:
219223
concrete_func = imported.signatures[k]
220-
inputs = {input_tensor.name: input_tensor for input_tensor in concrete_func.inputs}
221-
outputs = {output_tensor.name: output_tensor for output_tensor in concrete_func.outputs}
224+
input_names = [input_tensor.name for input_tensor in concrete_func.inputs
225+
if input_tensor.dtype != tf.dtypes.resource]
226+
output_names = [output_tensor.name for output_tensor in concrete_func.outputs
227+
if output_tensor.dtype != tf.dtypes.resource]
222228

223-
frozen_graph = from_function(concrete_func, list(inputs.keys()), list(outputs.keys()))
224-
return frozen_graph, inputs, outputs
229+
frozen_graph = from_function(concrete_func, input_names, output_names)
230+
return frozen_graph, input_names, output_names
225231

226232

227233
def from_saved_model(model_path, input_names, output_names, signatures=None):
228234
"""Load tensorflow graph from saved_model."""
229-
# make sure we start with clean default graph
230235
if signatures is None:
231236
signatures = []
232237
tf_reset_default_graph()
233238
if is_tf2():
234-
frozen_graph, inputs, outputs = \
239+
frozen_graph, input_names, output_names = \
235240
_from_saved_model_v2(model_path, input_names, output_names, signatures)
236-
inputs = {k: v for k, v in inputs.items() if v.dtype != tf.dtypes.resource}
237241
else:
238242
with tf_session() as sess:
239-
frozen_graph, inputs, outputs = \
243+
frozen_graph, input_names, output_names = \
240244
_from_saved_model_v1(sess, model_path, input_names, output_names, signatures)
241245

242246
if len(signatures) > 1:
243247
logger.warning("found multiple signatures %s in saved_model, pass --signature_def in command line",
244248
signatures)
245249

246250
tf_reset_default_graph()
247-
return frozen_graph, inputs, outputs
251+
return frozen_graph, input_names, output_names
248252

249253

250254
def from_keras(model_path, input_names, output_names):
@@ -261,34 +265,35 @@ def from_keras(model_path, input_names, output_names):
261265

262266
function = _saving_utils.trace_model_call(keras_model)
263267
concrete_func = function.get_concrete_function()
264-
inputs = {input_tensor.name: input_tensor for input_tensor in concrete_func.inputs if
265-
input_tensor.name in input_names}
266-
outputs = {output_tensor.name: output_tensor for output_tensor in concrete_func.outputs if
267-
output_tensor.name in output_names}
268-
frozen_graph = from_function(concrete_func, list(inputs.keys()), list(outputs.keys()))
268+
# allow to pass inputs and outputs from caller if we don't want all of them
269+
input_names = [input_tensor.name for input_tensor in concrete_func.inputs
270+
if input_tensor.dtype != tf.dtypes.resource]
271+
output_names = [output_tensor.name for output_tensor in concrete_func.outputs
272+
if output_tensor.dtype != tf.dtypes.resource]
273+
274+
frozen_graph = from_function(concrete_func, input_names, output_names)
269275
else:
270276
# Handles Keras when Eager mode is disabled.
271277
_keras.backend.clear_session()
272278
_keras.backend.set_learning_phase(False)
273279
keras_model = _keras.models.load_model(model_path, custom_objects)
280+
# allow to pass inputs and outputs from caller if we don't want all of them
281+
input_names = keras_model.inputs
282+
output_names = keras_model.outputs
274283
sess = _keras.backend.get_session()
284+
input_names = inputs_without_resource(sess, input_names)
275285
frozen_graph = freeze_session(sess, input_names=input_names, output_names=output_names)
276-
inputs = {i: sess.graph.get_tensor_by_name(i) for i in keras_model.inputs}
277-
outputs = {i: sess.graph.get_tensor_by_name(i) for i in keras_model.outputs}
278286
tf_reset_default_graph()
279287
with tf_session() as sess:
280-
frozen_graph = tf_optimize(inputs, outputs, frozen_graph)
288+
frozen_graph = tf_optimize(input_names, output_names, frozen_graph)
281289
tf_reset_default_graph()
282-
return frozen_graph, inputs, outputs
290+
return frozen_graph, input_names, output_names
283291

284292

285-
def tf_optimize_grappler(input_tensors, output_tensors, graph_def, fold_constant=None):
293+
def tf_optimize_grappler(input_names, output_names, graph_def, fold_constant=None):
286294
from tensorflow.core.protobuf import meta_graph_pb2 as meta_graph_pb2, config_pb2, rewriter_config_pb2
287295
from tensorflow.python.grappler import tf_optimizer as tf_opt
288296

289-
# don't use resource type as input
290-
output_tensors = list(output_tensors)
291-
292297
config = config_pb2.ConfigProto()
293298
rewrite_options = config.graph_options.rewrite_options
294299
config.graph_options.infer_shapes = True
@@ -300,34 +305,27 @@ def tf_optimize_grappler(input_tensors, output_tensors, graph_def, fold_constant
300305
]
301306
meta_graph = tf.compat.v1.train.export_meta_graph(graph_def=graph_def)
302307
fetch_collection = meta_graph_pb2.CollectionDef()
303-
for t in list(input_tensors) + output_tensors:
308+
for t in input_names + output_names:
304309
fetch_collection.node_list.value.append(t)
305310
meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
306311
graph_def = tf_opt.OptimizeGraph(config, meta_graph)
307312
return graph_def
308313

309314

310-
def tf_optimize(input_tensors, output_tensors, graph_def, fold_constant=True):
315+
def tf_optimize(input_names, output_names, graph_def, fold_constant=True):
311316
"""Extract inference subgraph and optimize graph."""
312-
assert isinstance(input_tensors, dict)
313-
assert isinstance(output_tensors, dict)
314-
try:
315-
input_tensors = {
316-
name: tensor for name, tensor in input_tensors.items()
317-
if tensor.dtype != tf.dtypes.resource
318-
}
319-
except: # pylint: disable=bare-except
320-
pass
317+
assert isinstance(input_names, list)
318+
assert isinstance(output_names, list)
321319

322320
# TODO: is this needed ?
323-
needed_names = [utils.node_name(i) for i in input_tensors.keys()] + \
324-
[utils.node_name(i) for i in output_tensors.keys()]
321+
needed_names = [utils.node_name(i) for i in input_names] + \
322+
[utils.node_name(i) for i in output_names]
325323
graph_def = extract_sub_graph(graph_def, needed_names)
326324

327325
if fold_constant:
328326
want_grappler = is_tf2() or LooseVersion(tf.__version__) >= "1.15"
329327
if want_grappler:
330-
graph_def = tf_optimize_grappler(input_tensors, output_tensors, graph_def, fold_constant)
328+
graph_def = tf_optimize_grappler(input_names, output_names, graph_def, fold_constant)
331329
else:
332330
# the older transform path
333331
from tensorflow.tools.graph_transforms import TransformGraph # pylint: disable=redefined-outer-name
@@ -341,7 +339,7 @@ def tf_optimize(input_tensors, output_tensors, graph_def, fold_constant=True):
341339
"fold_batch_norms",
342340
"fold_old_batch_norms",
343341
])
344-
graph_def = TransformGraph(graph_def, input_tensors.keys(), output_tensors.keys(), transforms)
342+
graph_def = TransformGraph(graph_def, input_names, output_names, transforms)
345343

346344
return graph_def
347345

@@ -367,7 +365,6 @@ def is_function(g):
367365
return 'tensorflow.python.framework.func_graph.FuncGraph' in str(type(g))
368366
return False
369367

370-
371368
_FUNCTIONS = {}
372369

373370

@@ -387,7 +384,12 @@ def toposort(data):
387384
fdef = fdef.definition
388385
if input_shapes and len(fdef.signature.input_arg) < len(input_shapes):
389386
input_shapes = input_shapes[:len(fdef.signature.input_arg)]
390-
func = function_def_to_graph(fdef, input_shapes=input_shapes)
387+
try:
388+
func = function_def_to_graph(fdef, input_shapes=input_shapes)
389+
except: # pylint: disable=bare-except
390+
# if there is a missmatch between caller and function use the functions shape
391+
logger.warning("shape missmatch between caller and function: %s", k)
392+
func = function_def_to_graph(fdef)
391393
_FUNCTIONS[k] = func
392394
_, _, _, _, _, tfunctions = tflist_to_onnx(func, {})
393395
functions.update(tfunctions)

tf2onnx/tfonnx.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,6 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
361361
Return:
362362
onnx graph
363363
"""
364-
# TODO: remove verbose argument in future release
365364
if verbose:
366365
logger.warning("Argument verbose for process_tf_graph is deprecated. Please use --verbose option instead.")
367366
del verbose
@@ -507,3 +506,9 @@ def compat_handler(ctx, node, **kwargs):
507506
"\tonnx unmapped: {}".format(op_cnt, attr_cnt, mapped_op, unmapped_op))
508507

509508
return g
509+
510+
511+
def tf_optimize(input_names, output_names, graph_def, fold_constant=True):
512+
"""optimize tensorflow graph. This is in tf_loader but some apps call this
513+
so we proxy into tf_loader to keep them working."""
514+
return tf2onnx.tf_loader.tf_optimize(input_names, output_names, graph_def, fold_constant)

0 commit comments

Comments
 (0)