Skip to content

Commit 51cb0de

Browse files
authored
check model's input and output naming convention (#127)
* check model's input and output naming convention * use isdigit() instead of isnumeric() * merge the loops and handle the all underscore name cases * print out names that are not compliant * minor code tuning * use tensorflow v1.9.0 because tensorflow v1.10.0 requires protobuf v3.6.0, which windows platform doesn't support yet
1 parent 727cc8f commit 51cb0de

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

onnxmltools/convert/common/_topology.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# --------------------------------------------------------------------------
66

77
import re
8+
import warnings
89
from distutils.version import StrictVersion
910
from ...proto import onnx
1011
from ...proto import helper
@@ -649,17 +650,31 @@ def convert_topology(topology, model_name, doc_string, targeted_onnx):
649650
other_outputs[variable.raw_name] = variable
650651

651652
# Add roots the graph according to their order in the original model
653+
invalid_name = []
652654
for name in topology.raw_model.input_names:
655+
# Check input naming convention
656+
input_name = name.replace('_', '')
657+
if input_name and (input_name[0].isdigit() or (not input_name.isalnum())):
658+
invalid_name.append(name)
653659
if name in tensor_inputs:
654660
container.add_input(tensor_inputs[name])
661+
if invalid_name:
662+
warnings.warn('Some input names are not compliant with ONNX naming convention: %s' % invalid_name)
655663
for name in topology.raw_model.input_names:
656664
if name in other_inputs:
657665
container.add_input(other_inputs[name])
658666

659667
# Add leaves the graph according to their order in the original model
668+
invalid_name = []
660669
for name in topology.raw_model.output_names:
670+
# Check output naming convention
671+
output_name = name.replace('_', '')
672+
if output_name and (output_name[0].isdigit() or (not output_name.isalnum())):
673+
invalid_name.append(name)
661674
if name in tensor_outputs:
662675
container.add_output(tensor_outputs[name])
676+
if invalid_name:
677+
warnings.warn('Some output names are not compliant with ONNX naming convention: %s' % invalid_name)
663678
for name in topology.raw_model.output_names:
664679
if name in other_outputs:
665680
container.add_output(other_outputs[name])

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
numpy
22
protobuf
33
codecov
4-
tensorflow
4+
tensorflow==1.9.0
55
keras==2.1.6
66
coremltools
77
pandas

0 commit comments

Comments
 (0)