Skip to content

Commit fa230d6

Browse files
authored
Merge pull request #102 from onnx/gs/lstm
fix for bug 101
2 parents 0c67734 + ef652ff commit fa230d6

File tree

3 files changed

+29
-14
lines changed

3 files changed

+29
-14
lines changed

tests/test_internals.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,14 @@ def test_match_flipped(self):
163163
match_results = list(matcher.match_ops(ops))
164164
self.assertEqual(1, len(match_results))
165165

166+
def test_cmdarg_parse(self):
167+
arg = "input/V-1_2:0,input/X:0[1,2,3],Y:1[4,5],Z:3,A:1,B"
168+
expected_inputs = ['input/V-1_2:0', 'input/X:0', 'Y:1', 'Z:3', 'A:1', 'B']
169+
expected_shape = {'Y:1': [4, 5], 'input/X:0': [1, 2, 3]}
170+
inputs, shape_override = tf2onnx.utils.split_nodename_and_shape(arg)
171+
self.assertEqual(expected_inputs, inputs)
172+
self.assertEqual(expected_shape, shape_override)
173+
166174

167175
if __name__ == '__main__':
168176
unittest.main()

tf2onnx/convert.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,7 @@ def get_args():
3939

4040
args.shape_override = None
4141
if args.inputs:
42-
inputs = []
43-
shapes = {}
44-
# input takes in most cases the format name:0, where 0 is the output number
45-
# in some cases placeholders don't have a rank which onnx can't handle so we let uses override the shape
46-
# by appending the same, ie : [1,28,28,3]
47-
#
48-
pattern = r"(?:([\w:]+)(\[[\d,]+\])?),?"
49-
splits = re.split(pattern, args.inputs)
50-
for i in range(1, len(splits), 3):
51-
inputs.append(splits[i])
52-
if splits[i+1] is not None:
53-
shapes[splits[i]] = [int(n) for n in splits[i+1][1:-1].split(",")]
54-
args.inputs = inputs
55-
args.shape_override = shapes
42+
args.inputs, args.shape_override = tf2onnx.utils.split_nodename_and_shape(args.inputs)
5643
if args.outputs:
5744
args.outputs = args.outputs.split(",")
5845
if args.target:

tf2onnx/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
from __future__ import division
99
from __future__ import print_function
1010

11+
import re
1112
import numpy as np
1213
import tensorflow as tf
1314
from onnx import helper, onnx_pb
1415
from tensorflow.core.framework import types_pb2, tensor_pb2
1516

17+
1618
#
1719
# mapping dtypes from tensorflow to onnx
1820
#
@@ -94,6 +96,24 @@ def make_name(name):
9496
return "{}__{}".format(name, INTERNAL_NAME)
9597

9698

99+
def split_nodename_and_shape(name):
100+
# pattern for a node name
101+
inputs = []
102+
shapes = {}
103+
# input takes in most cases the format name:0, where 0 is the output number
104+
# in some cases placeholders don't have a rank which onnx can't handle so we let uses override the shape
105+
# by appending the same, ie : [1,28,28,3]
106+
name_pattern = r"(?:([\w\d/\-_:]+)(\[[\d,]+\])?),?"
107+
splits = re.split(name_pattern, name)
108+
for i in range(1, len(splits), 3):
109+
inputs.append(splits[i])
110+
if splits[i + 1] is not None:
111+
shapes[splits[i]] = [int(n) for n in splits[i + 1][1:-1].split(",")]
112+
if len(shapes) == 0:
113+
shapes = None
114+
return inputs, shapes
115+
116+
97117
def tf_to_onnx_tensor(tensor, name=""):
98118
"""Convert tensorflow tensor to onnx tensor."""
99119
new_type = TF_TO_ONNX_DTYPE[tensor.dtype]

0 commit comments

Comments
 (0)