
Commit 717721a

Merge pull request #316 from zhijxu-MS/map_fn
add test cases of map_fn, fix bug of shape_inference
2 parents 90f2bbe + e29e418 commit 717721a

3 files changed: +72 -12 lines changed

tests/test_backend.py

Lines changed: 0 additions & 1 deletion
@@ -211,7 +211,6 @@ def test_multinomial(self):
         self._run_test_case([_OUTPUT], {_INPUT: x_val}, check_value=False,
                             check_shape=True, check_dtype=True)

-
     @unittest.skipIf(BACKEND in ["caffe2"], "not supported correctly in caffe2")
     @unittest.skipIf(*support_op_conversion_since(7, "multinomial"))
     def test_multinomial1(self):

tests/test_loops.py

Lines changed: 40 additions & 4 deletions
@@ -16,6 +16,7 @@

 # pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test

+
 class LoopTests(Tf2OnnxBackendTestBase):

     def test_simple_while_loop(self):
@@ -31,7 +32,6 @@ def test_simple_while_loop(self):
         output_names_with_port = ["output:0"]
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)

-
     def test_simple_while_loop_2(self):
         i = tf.placeholder(tf.int32, (), name="input_1")
         c = lambda i: tf.logical_and(tf.less(i, 10), tf.greater_equal(i, 3))
@@ -45,7 +45,6 @@ def test_simple_while_loop_2(self):
         output_names_with_port = ["output:0"]
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)

-
     def test_while_loop_with_ta_write(self):
         i = tf.placeholder(tf.int32, (), name="input_1")
         output_ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
@@ -68,7 +67,6 @@ def b(i, out_ta):
         output_names_with_port = ["output:0", "i:0"]
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)

-
     def test_while_loop_with_ta_read(self):
         i = tf.placeholder(tf.int32, (), name="input_1")
         inputs = tf.placeholder(tf.float32, (10,), name="input_2")
@@ -82,6 +80,7 @@ def test_while_loop_with_ta_read(self):
         c = lambda i, *_: tf.logical_and(tf.less(i, 10), i >= 0)
         res = tf.constant(0.)
         res2 = tf.constant(1.)
+
         def b(i, res, res2):
             new_i = tf.add(i, 1)
             x = input_ta.read(i)
@@ -113,6 +112,7 @@ def test_while_loop_with_ta_read_reference_outer_input_directly(self):
         c = lambda i, *_: tf.logical_and(tf.less(i, 10), i >= 0)
         res = tf.constant(0.)
         res2 = tf.constant(1.)
+
         def b(i, res, res2):
             new_i = tf.add(i, 1)
             x = input_ta.read(i)
@@ -132,7 +132,6 @@ def b(i, res, res2):
         output_names_with_port = ["i:0", "x:0", "y:0"]
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)

-
     def test_while_loop_with_ta_read_and_write(self):
         i = tf.placeholder(tf.int32, (), name="input_1")
         inputs = tf.placeholder(tf.float32, (10,), name="input_2")
@@ -160,5 +159,42 @@ def b(i, out_ta):
         output_names_with_port = ["i:0", "output_ta:0"]
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)

+    def test_map_fn(self):
+        def fn0(elem):
+            res = elem + elem * elem
+            return res
+
+        def fn1(elem):
+            res = elem[0] * elem[1] + elem[0]
+            return res
+
+        x_val = 100 * np.random.random_sample([2, 10]).astype(np.float32)
+        y_val = 100 * np.random.random_sample([2, 10]).astype(np.float32)
+
+        # test fn0
+        x = tf.placeholder(tf.float32, shape=x_val.shape, name="input_0")
+        x_ = tf.identity(x)
+        res_ = tf.map_fn(fn0, x_, dtype=tf.float32)
+        _ = tf.identity(res_, name="output_0")
+        feed_dict = {"input_0:0": x_val}
+        input_names_with_port = ["input_0:0"]
+        output_names_with_port = ["output_0:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-5)
+        tf.reset_default_graph()
+
+        # test fn1
+        x = tf.placeholder(tf.float32, shape=x_val.shape, name="input_0")
+        y = tf.placeholder(tf.float32, shape=y_val.shape, name="input_1")
+        x_ = tf.identity(x)
+        y_ = tf.identity(y)
+        res_ = tf.map_fn(fn1, (x_, y_), dtype=tf.float32)
+        _ = tf.identity(res_, name="output_0")
+        feed_dict = {"input_0:0": x_val, "input_1:0": y_val}
+        input_names_with_port = ["input_0:0", "input_1:0"]
+        output_names_with_port = ["output_0:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-5)
+        tf.reset_default_graph()
+
+
 if __name__ == '__main__':
     Tf2OnnxBackendTestBase.trigger(LoopTests)
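Context for reviewers: tf.map_fn applies the function to each slice along dimension 0, and TensorFlow lowers it to a while loop with TensorArray read/write/gather ops, which is why these cases live with the loop tests. A minimal standalone sketch of the two-tensor pattern, outside the test harness (assuming TF 1.x; the session run and shapes are illustrative, not part of the test):

```python
import numpy as np
import tensorflow as tf  # TF 1.x API, matching the tests

def fn1(elem):
    # elem is a (x_row, y_row) tuple, one slice of each input per step
    return elem[0] * elem[1] + elem[0]

x = tf.placeholder(tf.float32, shape=[2, 10], name="input_0")
y = tf.placeholder(tf.float32, shape=[2, 10], name="input_1")
# dtype must be given here: elems is a tuple, but fn1 returns a single tensor
res = tf.map_fn(fn1, (x, y), dtype=tf.float32)

with tf.Session() as sess:
    x_val = np.random.random_sample([2, 10]).astype(np.float32)
    y_val = np.random.random_sample([2, 10]).astype(np.float32)
    out = sess.run(res, feed_dict={x: x_val, y: y_val})
    assert out.shape == (2, 10)
```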

tf2onnx/shape_inference.py

Lines changed: 32 additions & 7 deletions
@@ -10,6 +10,7 @@
 from __future__ import unicode_literals
 import logging
 from onnx import onnx_pb
+from tf2onnx import utils

 # pylint: disable=logging-not-lazy,missing-docstring,consider-swap-variables

@@ -91,10 +92,10 @@ def infer_shape_for_node(g, node):
         shape_attr = node.get_attr("shape")
         new_shape = None
         if shape_attr.type == onnx_pb.TensorProto.INT32:
-            new_shape = shape_attr.ints
+            new_shape = list(shape_attr.ints)
         elif shape_attr.type == onnx_pb.TensorProto.FLOAT:
             # for a scalar placeholder, its type is float
-            val = shape_attr.floats
+            val = list(shape_attr.floats)
             if val:
                 raise ValueError("placeholder shape has floats value, and not scalar value")
         else:
@@ -243,13 +244,24 @@ def infer_output_shapes_with_partial_inputs(g, node):
         return True

     if node.type == "TensorArrayGatherV3":
-        # TensorArrayGatherV3's output: all of the elements in the TensorArray,
-        # concatenated along a new axis (the new dimension 0)
-        flow_in_node = node.inputs[2]
-        if flow_in_node.type != "Exit":
+        # TensorArrayGatherV3's output is all of the elements in the TensorArray,
+        # concatenated along a new axis (the new dimension 0), so the element shape
+        # must be found first. TensorArrayWriteV3 writes the elements into the
+        # TensorArray, so the process is: find the TensorArrayWriteV3 node, take the
+        # shape of the value it writes, and prepend one dimension for the new axis.
+
+        handle_node = node.inputs[0]
+        if handle_node.type != "TensorArrayV3":
             return False

-        shape = g.get_shape(flow_in_node.output[0])
+        # find the TensorArrayWriteV3 node
+        tensor_array_write_node = _find_tensorarray_write(g, handle_node)
+        if not tensor_array_write_node:
+            return False
+        # get the element shape from the value input of the found TensorArrayWriteV3
+        value_node = tensor_array_write_node.inputs[2]
+        shape = g.get_shape(value_node.output[0])
+        # update the gather output's shape info
         if shape is not None:
             new_shape = [-1] + shape
             g.set_shape(node.output[0], new_shape)
@@ -350,3 +362,16 @@ def broadcast_shape_inference(shape_0, shape_1):
             return None
         i -= 1
     return new_shape
+
+
+def _find_tensorarray_write(graph, node):
+    utils.make_sure(node.type == "TensorArrayV3", "node should be tensorarray")
+
+    tensor_array_consumers = graph.find_output_consumers(node.output[0])
+    for i in tensor_array_consumers:
+        if i.type == "Enter":
+            consumer_nodes = graph.find_output_consumers(i.output[0])
+            for j in consumer_nodes:
+                if j.type == "TensorArrayWriteV3":
+                    return j
+    return None
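In short, the fix replaces the old reliance on the gather's flow input coming from an Exit node with a walk of the graph pattern that a while loop actually produces. A rough sketch of the pattern and the resulting shape arithmetic (node names are illustrative; only the op types matter to `_find_tensorarray_write`):

```python
# Illustrative node chain produced by a while loop over a TensorArray:
#
#   TensorArrayV3 (handle) --> Enter --> TensorArrayWriteV3 (value = one element)
#        |
#        +--> TensorArrayGatherV3 (stacks all elements along a new axis 0)
#
# _find_tensorarray_write starts from the TensorArrayV3 handle, follows its
# consumers through the loop-boundary Enter node, and returns the first
# TensorArrayWriteV3 it finds. The shape of that write's value input is the
# element shape; the gather output prepends an unknown iteration count:
elem_shape = [10]                  # e.g. each iteration writes one row of 10 floats
gather_shape = [-1] + elem_shape   # iteration count unknown at inference time
assert gather_shape == [-1, 10]
```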
