Skip to content

Commit cba8f8d

Browse files
Use structured inputs by default from CLI (#1534)
* Remove captured inputs when using CLI Signed-off-by: Tom Wildenhain <[email protected]> * Switch to use structured input names by default Signed-off-by: Tom Wildenhain <[email protected]> * Bugfixes Signed-off-by: Tom Wildenhain <[email protected]> * Update tutorials Signed-off-by: Tom Wildenhain <[email protected]> * Remove logging lines Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 5e0b4f3 commit cba8f8d

File tree

7 files changed

+33
-30
lines changed

7 files changed

+33
-30
lines changed

examples/benchmark_tfmodel_ort.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def measure_time(fct, imgs):
3838
# Download model from https://tfhub.dev/captain-pool/esrgan-tf2/1
3939
# python -m tf2onnx.convert --saved-model esrgan --output "esrgan-tf2.onnx" --opset 12
4040
ort = ort.InferenceSession('esrgan-tf2.onnx')
41-
fct_ort = lambda img: ort.run(None, {'input_0:0': img})
41+
fct_ort = lambda img: ort.run(None, {'input_0': img})
4242
results_ort, duration_ort = measure_time(fct_ort, imgs)
4343
print(len(imgs), duration_ort)
4444

examples/end2end_tfhub.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
########################################
6363
# Runs onnxruntime.
6464
session = InferenceSession("efficientnetb0clas.onnx")
65-
got = session.run(None, {'input_1:0': input})
65+
got = session.run(None, {'input_1': input})
6666
print(got[0])
6767

6868
########################################
@@ -73,5 +73,5 @@
7373
# Measures processing time.
7474
print('tf:', timeit.timeit('model.predict(input)',
7575
number=10, globals=globals()))
76-
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
76+
print('ort:', timeit.timeit("session.run(None, {'input_1': input})",
7777
number=10, globals=globals()))

examples/end2end_tfkeras.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
########################################
5858
# Runs onnxruntime.
5959
session = InferenceSession("simple_rnn.onnx")
60-
got = session.run(None, {'input_1:0': input})
60+
got = session.run(None, {'input_1': input})
6161
print(got[0])
6262

6363
########################################
@@ -68,5 +68,5 @@
6868
# Measures processing time.
6969
print('tf:', timeit.timeit('model.predict(input)',
7070
number=100, globals=globals()))
71-
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
71+
print('ort:', timeit.timeit("session.run(None, {'input_1': input})",
7272
number=100, globals=globals()))

examples/getting_started.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def f(a, b):
5858

5959
print("ORT result")
6060
sess = ort.InferenceSession("model.onnx")
61-
res = sess.run(None, {'dense_input:0': x_val})
61+
res = sess.run(None, {'dense_input': x_val})
6262
print(res[0])
6363

6464
print("Conversion succeeded")

tests/run_pretrained_models.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def run_test(self, name, backend="onnxruntime", onnx_file=None, opset=None, extr
375375
initialized_tables = {}
376376
outputs = self.output_names
377377
tflite_path = None
378-
to_rename = None
378+
to_rename = {}
379379
if self.model_type in ["checkpoint"]:
380380
graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
381381
elif self.model_type in ["saved_model"]:
@@ -438,7 +438,7 @@ def run_tflite():
438438
inputs = {}
439439
for k in input_names:
440440
v = self.input_names[k]
441-
inputs[k.split(":")[0]] = tf.constant(self.make_input(v))
441+
inputs[to_rename.get(k, k)] = tf.constant(self.make_input(v))
442442
tf_func = tf.function(concrete_func)
443443
logger.info("Running TF")
444444
tf_results_d = tf_func(**inputs)
@@ -553,11 +553,10 @@ def run_tflite():
553553
try:
554554
onnx_results = None
555555
if backend == "onnxruntime":
556-
if to_rename is None:
557-
struc_outputs = self.output_names
558-
else:
559-
struc_outputs = [to_rename.get(k, k) for k in self.output_names]
560-
onnx_results = self.run_onnxruntime(name, model_proto, inputs, struc_outputs, external_tensor_storage)
556+
struc_outputs = [to_rename.get(k, k) for k in self.output_names]
557+
struc_inputs = {to_rename.get(k, k): v for k, v in inputs.items()}
558+
onnx_results = self.run_onnxruntime(
559+
name, model_proto, struc_inputs, struc_outputs, external_tensor_storage)
561560
else:
562561
raise ValueError("unknown backend")
563562
logger.info("Run_ONNX OK")

tf2onnx/convert.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def get_args():
6565
"change into Identity ops using their default value")
6666
parser.add_argument("--rename-inputs", help="input names to use in final model (optional)")
6767
parser.add_argument("--rename-outputs", help="output names to use in final model (optional)")
68+
parser.add_argument("--use-graph-names", help="(saved model only) skip renaming io using signature names",
69+
action="store_true")
6870
parser.add_argument("--opset", type=int, default=None, help="opset version to use for onnx domain")
6971
parser.add_argument("--dequantize", help="Remove quantization from model. Only supported for tflite currently.",
7072
action="store_true")
@@ -212,7 +214,8 @@ def main():
212214
if args.saved_model:
213215
graph_def, inputs, outputs, initialized_tables, tensors_to_rename = tf_loader.from_saved_model(
214216
args.saved_model, args.inputs, args.outputs, args.tag, args.signature_def, args.concrete_function,
215-
args.large_model, return_initialized_tables=True, return_tensors_to_rename=True)
217+
args.large_model, return_initialized_tables=True, return_tensors_to_rename=True,
218+
use_graph_names=args.use_graph_names)
216219
model_path = args.saved_model
217220
if args.keras:
218221
graph_def, inputs, outputs = tf_loader.from_keras(

tf2onnx/tf_loader.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def from_checkpoint(model_path, input_names, output_names):
310310
return frozen_graph, input_names, output_names
311311

312312

313-
def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signature_names):
313+
def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signature_names, use_graph_names):
314314
"""Load tensorflow graph from saved_model."""
315315

316316
wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
@@ -345,22 +345,25 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
345345
# TF1.12 changed the api
346346
get_signature_def = lambda meta_graph_def, k: meta_graph_def.signature_def[k]
347347

348+
tensors_to_rename = {}
348349
if input_names is None:
349350
input_names = []
350351
for k in signatures:
351352
inputs_tensor_info = get_signature_def(imported, k).inputs
352-
for _, input_tensor in inputs_tensor_info.items():
353+
for structured_name, input_tensor in inputs_tensor_info.items():
353354
if input_tensor.name not in input_names:
354355
input_names.append(input_tensor.name)
355-
tensors_to_rename = {}
356+
if not use_graph_names:
357+
tensors_to_rename[input_tensor.name] = structured_name
356358
if output_names is None:
357359
output_names = []
358360
for k in signatures:
359361
outputs_tensor_info = get_signature_def(imported, k).outputs
360362
for structured_name, output_tensor in outputs_tensor_info.items():
361363
if output_tensor.name not in output_names:
362364
output_names.append(output_tensor.name)
363-
tensors_to_rename[output_tensor.name] = structured_name
365+
if not use_graph_names:
366+
tensors_to_rename[output_tensor.name] = structured_name
364367
frozen_graph, initialized_tables = \
365368
freeze_session(sess, input_names=input_names, output_names=output_names, get_tables=True)
366369
return frozen_graph, input_names, output_names, initialized_tables, tensors_to_rename
@@ -447,7 +450,7 @@ def _restore_captured_resources(concrete_func, graph_captures_copy, func_capture
447450

448451

449452
def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_def,
450-
concrete_function_index, large_model):
453+
concrete_function_index, large_model, use_graph_names):
451454
"""Load tensorflow graph from saved_model."""
452455

453456
wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
@@ -495,18 +498,16 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
495498
graph_captures = concrete_func.graph._captures # pylint: disable=protected-access
496499
captured_inputs = [t_name.name for _, t_name in graph_captures.values()]
497500
inputs = [inp for inp in inputs if inp not in captured_inputs]
498-
if concrete_func.structured_input_signature is not None:
499-
args, kwargs = concrete_func.structured_input_signature
500-
structured_inputs = [t.name for t in args if isinstance(t, tf.TensorSpec)] + sorted(kwargs.keys())
501-
structured_inputs = set(inp + ":0" for inp in structured_inputs)
502-
if any(inp in structured_inputs for inp in inputs):
503-
inputs = [inp for inp in inputs if inp in structured_inputs]
501+
if concrete_func.structured_input_signature is not None and not use_graph_names:
502+
flat_structured_inp = tf.nest.flatten(concrete_func.structured_input_signature)
503+
structured_inputs = [t.name for t in flat_structured_inp if isinstance(t, tf.TensorSpec)]
504+
tensors_to_rename.update(zip(inputs, structured_inputs))
504505
else:
505506
inputs = input_names
506507

507508
if output_names is None:
508509
outputs = [tensor.name for tensor in concrete_func.outputs if tensor.dtype != tf.dtypes.resource]
509-
if isinstance(concrete_func.structured_outputs, dict):
510+
if isinstance(concrete_func.structured_outputs, dict) and not use_graph_names:
510511
# outputs are sorted, sort structured_outputs the same way
511512
structured_outputs = sorted(concrete_func.structured_outputs.keys())
512513
tensors_to_rename.update(zip(outputs, structured_outputs))
@@ -515,7 +516,6 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
515516
logger.info("Output names: %r", outputs)
516517
else:
517518
outputs = output_names
518-
logger.info("Outputs not left as None; will use provided names not structured output names.")
519519

520520
frozen_graph, initialized_tables = from_trackable(imported, concrete_func, inputs, outputs, large_model)
521521

@@ -524,7 +524,8 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
524524

525525
def from_saved_model(model_path, input_names, output_names, tag=None,
526526
signatures=None, concrete_function=None, large_model=False,
527-
return_concrete_func=False, return_initialized_tables=False, return_tensors_to_rename=False):
527+
return_concrete_func=False, return_initialized_tables=False,
528+
return_tensors_to_rename=False, use_graph_names=False):
528529
"""Load tensorflow graph from saved_model."""
529530
if signatures is None:
530531
signatures = []
@@ -533,7 +534,7 @@ def from_saved_model(model_path, input_names, output_names, tag=None,
533534
if is_tf2():
534535
frozen_graph, input_names, output_names, concrete_func, imported, initialized_tables, tensors_to_rename = \
535536
_from_saved_model_v2(model_path, input_names, output_names,
536-
tag, signatures, concrete_function, large_model)
537+
tag, signatures, concrete_function, large_model, use_graph_names)
537538
result = [frozen_graph, input_names, output_names]
538539
if return_concrete_func:
539540
result += [concrete_func, imported]
@@ -544,7 +545,7 @@ def from_saved_model(model_path, input_names, output_names, tag=None,
544545
else:
545546
with tf_session() as sess:
546547
frozen_graph, input_names, output_names, initialized_tables, tensors_to_rename = \
547-
_from_saved_model_v1(sess, model_path, input_names, output_names, tag, signatures)
548+
_from_saved_model_v1(sess, model_path, input_names, output_names, tag, signatures, use_graph_names)
548549
result = [frozen_graph, input_names, output_names]
549550
if return_initialized_tables:
550551
result += [initialized_tables]

0 commit comments

Comments
 (0)