
Commit e2aa60c

allow shape override for inputs
1 parent 4a7e895 commit e2aa60c

5 files changed: +51 -18 lines changed

README.md

Lines changed: 3 additions & 1 deletion
@@ -68,7 +68,7 @@ Parameters:
 - inputs/outputs: Tensorflow graph's input/output names, which can be found with the [summarize graph tool](#summarize_graph).
 - target: There are different onnx versions and workarounds for runtimes that can be set with ```--target TARGET```.
 - opset: by default we use the newest opset installed with the onnx package (for example onnx-1.2.2 would have opset 7). By specifying ```--opset``` the user can override the default to generate a graph with the desired opset. For example, ```--opset 5``` would create an onnx graph that uses only ops available in opset 5. Because older opsets in most cases have fewer ops, some models might not convert with an older opset.
-- custom-ops: the runtime may support custom ops that are not defined in onnx. A user can ask the converter to map to custom ops by listing them with the --custom-ops option. Tensorflow ops listed here will be mapped to a custom op of the same name as the tensorflow op, but in the onnx domain ai.onnx.converters.tensorflow. For example, ```--custom-ops Print``` will insert an op ```Print``` in the onnx domain ```ai.onnx.converters.tensorflow``` into the graph. We also support a python api for custom ops, documented later in this readme.
+- custom-ops: the runtime may support custom ops that are not defined in onnx. A user can ask the converter to map to custom ops by listing them with the --custom-ops option. Tensorflow ops listed here will be mapped to a custom op of the same name as the tensorflow op, but in the onnx domain ai.onnx.converters.tensorflow. For example, ```--custom-ops Print``` will insert an op ```Print``` in the onnx domain ```ai.onnx.converters.tensorflow``` into the graph. We also support a python api for custom ops, documented later in this readme.

 Usage example (run the following commands in the tensorflow-onnx root directory):
 ```
@@ -79,6 +79,8 @@ python -m tf2onnx.convert\
     --output tests/models/fc-layers/model.onnx\
     --verbose
 ```
+Some models specify placeholders with unknown rank, which cannot be mapped to onnx.
+In those cases one can append the shape to the input name in ```[]```, for example ```--inputs X:0[1,28,28,3]```.

 ## <a name="summarize_graph"></a>Tool to get Graph Inputs & Outputs

tests/run_pretrained_models.py

Lines changed: 15 additions & 6 deletions
@@ -103,7 +103,7 @@ class Test(object):

     def __init__(self, url, local, make_input, input_names, output_names,
                  disabled=False, more_inputs=None, rtol=0.01, atol=0.,
-                 check_only_shape=False, model_type="frozen"):
+                 check_only_shape=False, model_type="frozen", force_input_shape=False):
         self.url = url
         self.make_input = make_input
         self.local = local
@@ -118,6 +118,7 @@ def __init__(self, url, local, make_input, input_names, output_names,
         self.tf_runtime = 0
         self.onnx_runtime = 0
         self.model_type = model_type
+        self.force_input_shape = force_input_shape

     def download_file(self):
         """Download file from url."""
@@ -171,9 +172,9 @@ def run_tensorflow(self, sess, inputs):
         return result

     @staticmethod
-    def to_onnx(tf_graph, opset=None):
+    def to_onnx(tf_graph, opset=None, shape_override=None):
         """Convert graph to onnx."""
-        return process_tf_graph(tf_graph, continue_on_error=False, opset=opset)
+        return process_tf_graph(tf_graph, continue_on_error=False, opset=opset, shape_override=shape_override)

     def run_caffe2(self, name, onnx_graph, inputs):
         """Run test against caffe2 backend."""
@@ -253,6 +254,8 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
         """Run complete test against backend."""
         print(name)
         self.perf = perf
+
+        # get the model
         if self.url:
             _, dir_name = self.download_file()
             model_path = os.path.join(dir_name, self.local)
@@ -270,6 +273,7 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
             tf.train.write_graph(frozen_graph, dir_name, "frozen.pb", as_text=False)
             model_path = os.path.join(dir_name, "frozen.pb")

+        # create the input data
         inputs = self.make_input(self.input_names)
         if self.more_inputs:
             for k, v in self.more_inputs.items():
@@ -278,8 +282,9 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
         graph_def = graph_pb2.GraphDef()
         with open(model_path, "rb") as f:
             graph_def.ParseFromString(f.read())
-        graph_def = tf2onnx.tfonnx.tf_optimize(None, inputs, self.output_names, graph_def)

+        graph_def = tf2onnx.tfonnx.tf_optimize(None, inputs, self.output_names, graph_def)
+        shape_override = {}
         g = tf.import_graph_def(graph_def, name='')
         with tf.Session(graph=g) as sess:

@@ -290,12 +295,16 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
                 if type != "float32":
                     v = inputs[k]
                     inputs[k] = v.astype(dtype)
+            if self.force_input_shape:
+                shape_override = self.input_names

+            # run the model with tensorflow
             tf_results = self.run_tensorflow(sess, inputs)
             onnx_graph = None
             print("\ttensorflow", "OK")
             try:
-                onnx_graph = self.to_onnx(sess.graph, opset=opset)
+                # convert model to onnx
+                onnx_graph = self.to_onnx(sess.graph, opset=opset, shape_override=shape_override)
                 print("\tto_onnx", "OK")
                 if debug:
                     onnx_graph.dump_graph()
@@ -362,7 +371,7 @@ def tests_from_yaml(fname):
         input_func = v.get("input_get")
         input_func = _INPUT_FUNC_MAPPING[input_func]
         kwargs = {}
-        for kw in ["rtol", "atol", "disabled", "more_inputs", "check_only_shape", "model_type"]:
+        for kw in ["rtol", "atol", "disabled", "more_inputs", "check_only_shape", "model_type", "force_input_shape"]:
             if v.get(kw) is not None:
                 kwargs[kw] = v[kw]
tests/run_pretrained_models.yaml

Lines changed: 1 addition & 0 deletions
@@ -129,6 +129,7 @@ mobilenet_v2_1.4_224:
   url: https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz
   model: mobilenet_v2_1.4_224_frozen.pb
   input_get: get_beach
+  force_input_shape: true
   inputs:
     "input:0": [1, 224, 224, 3]
   outputs:
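
Taken together with the run_pretrained_models.py changes above, setting ```force_input_shape: true``` makes the test harness reuse the yaml ```inputs``` mapping as the shape override. A minimal sketch of that wiring, with the values copied from the mobilenet entry and the surrounding Test object omitted:
```
# Values as they appear in the yaml entry above.
input_names = {"input:0": [1, 224, 224, 3]}   # yaml "inputs": tensor name -> shape
force_input_shape = True

# run_test() starts from an empty override and only fills it when the flag
# is set, so existing test entries keep their old behavior.
shape_override = {}
if force_input_shape:
    shape_override = input_names

# shape_override is then handed to Test.to_onnx(), which forwards it to
# process_tf_graph(..., shape_override=shape_override).
print(shape_override)   # {'input:0': [1, 224, 224, 3]}
```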

tf2onnx/convert.py

Lines changed: 18 additions & 2 deletions
@@ -9,6 +9,7 @@
 from __future__ import print_function

 import argparse
+import re
 import sys

 import onnx
@@ -36,8 +37,22 @@ def get_args():
     parser.add_argument("--verbose", help="verbose output", action="store_true")
     args = parser.parse_args()

+    args.shape_override = None
     if args.inputs:
-        args.inputs = args.inputs.split(",")
+        inputs = []
+        shapes = {}
+        # inputs take in most cases the format name:0, where 0 is the output number.
+        # In some cases placeholders don't have a rank, which onnx can't handle, so we let users override
+        # the shape by appending it, ie: [1,28,28,3]
+        #
+        pattern = r"(?:([\w:]+)(\[[\d,]+\])?),?"
+        splits = re.split(pattern, args.inputs)
+        for i in range(1, len(splits), 3):
+            inputs.append(splits[i])
+            if splits[i+1] is not None:
+                shapes[splits[i]] = [int(n) for n in splits[i+1][1:-1].split(",")]
+        args.inputs = inputs
+        args.shape_override = shapes
     if args.outputs:
         args.outputs = args.outputs.split(",")
     if args.target:
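
For reference, a standalone sketch of the ```--inputs``` parsing that ```get_args()``` now performs; the helper name ```parse_inputs``` and the example string are illustrative only, while the regex and the loop are the ones from the diff above:
```
import re

def parse_inputs(arg):
    """Illustrative helper: split a --inputs value such as "X:0[1,28,28,3],Y:0"
    into the input names plus a name->shape override dict."""
    inputs = []
    shapes = {}
    pattern = r"(?:([\w:]+)(\[[\d,]+\])?),?"
    # re.split() keeps the capture groups, so the result looks like
    # ['', 'X:0', '[1,28,28,3]', '', 'Y:0', None, ''] - names sit at
    # indices 1, 4, 7, ... with the optional shape right after each name.
    splits = re.split(pattern, arg)
    for i in range(1, len(splits), 3):
        inputs.append(splits[i])
        if splits[i + 1] is not None:
            shapes[splits[i]] = [int(n) for n in splits[i + 1][1:-1].split(",")]
    return inputs, shapes

print(parse_inputs("X:0[1,28,28,3],Y:0"))
# (['X:0', 'Y:0'], {'X:0': [1, 28, 28, 3]})
```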
@@ -85,7 +100,8 @@ def main():
                          target=args.target,
                          opset=args.opset,
                          custom_op_handlers=custom_ops,
-                         extra_opset=extra_opset)
+                         extra_opset=extra_opset,
+                         shape_override=args.shape_override)

     model_proto = g.make_model(
         "converted from {}".format(args.input), args.inputs, args.outputs,

tf2onnx/tfonnx.py

Lines changed: 14 additions & 9 deletions
@@ -34,7 +34,7 @@
 DEFAULT_TARGET = [TARGET_RS4, TARGET_CAFFE2]


-def tensorflow_to_onnx(graph):
+def tensorflow_to_onnx(graph, shape_override):
     """
     Load tensorflow graph into an onnx graph with minimal rewrites so
     we can use the onnx graph as intermediate graph.
@@ -58,10 +58,12 @@ def tensorflow_to_onnx(graph):
     # create dict with output to shape mappings
     for node in ops:
         for out in node.outputs:
-            try:
-                shape = out.get_shape().as_list()
-            except Exception as ex:
-                shape = []
+            shape = shape_override.get(out.name)
+            if shape is None:
+                try:
+                    shape = out.get_shape().as_list()
+                except Exception as ex:
+                    shape = []
             dtypes[out.name] = utils.map_tf_dtype(out.dtype)
             output_shapes[out.name] = shape

@@ -1306,9 +1308,9 @@ def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
             onnx_node = func(g, node, node.name, args)
         except Exception as ex:
             type_, value_, traceback_ = sys.exc_info()
-            ex = traceback.format_exception(type_, value_, traceback_)
+            ex_ext = traceback.format_exception(type_, value_, traceback_)
             if continue_on_error:
-                print(ex)
+                print(ex_ext)
                 onnx_nodes.append(node)
             else:
                 raise ex
@@ -1336,7 +1338,8 @@ def tf_optimize(sess, inputs, outputs, graph_def):


 def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
-                     opset=None, custom_op_handlers=None, custom_rewriter=None, extra_opset=None):
+                     opset=None, custom_op_handlers=None, custom_rewriter=None,
+                     extra_opset=None, shape_override=None):
     """Convert tensorflow graph to onnx graph.
     Args:
         tf_graph: tensorflow graph
@@ -1359,10 +1362,12 @@ def topological_sort(ops):
             # if we continue on error, ignore graph cycles so we can report all missing ops
             pass

+    if shape_override is None:
+        shape_override = {}
     if target is None:
         target = DEFAULT_TARGET

-    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tensorflow_to_onnx(tf_graph)
+    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tensorflow_to_onnx(tf_graph, shape_override)

     g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset)
     ops = g.get_nodes()
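
With the new keyword argument in place, a caller can hand ```process_tf_graph``` an explicit shape for a rank-less placeholder. A minimal sketch, assuming a TF 1.x environment; the placeholder name ```X``` and the shape ```[1, 28, 28, 3]``` are illustrative only:
```
import tensorflow as tf

from tf2onnx.tfonnx import process_tf_graph

# Toy graph whose placeholder has no static rank (shape=None) - exactly
# the case the override is meant to handle.
with tf.Graph().as_default() as tf_graph:
    x = tf.placeholder(tf.float32, shape=None, name="X")
    _ = tf.identity(x, name="output")

# The override maps the tensor name (output 0 of the placeholder) to the
# shape the onnx graph should use instead of the unknown one.
g = process_tf_graph(tf_graph, shape_override={"X:0": [1, 28, 28, 3]})
print(len(g.get_nodes()), "nodes converted")
```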
