Skip to content

Commit 571f977

Browse files
committed
code refactor according to reviewer's comments
1 parent 1db2c68 commit 571f977

File tree

3 files changed

+57
-42
lines changed

3 files changed

+57
-42
lines changed

tests/test_backend.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -201,33 +201,6 @@ def test_atrig_ops(self):
201201
_ = tf.identity(op_, name=_TFOUTPUT)
202202
self._run_test_case([_OUTPUT], {_INPUT: x_val})
203203

204-
def test_map_fn(self):
205-
def fn0(elem):
206-
res = elem + elem * elem
207-
return res
208-
209-
def fn1(elem):
210-
res = elem[0] * elem[1] + elem[0]
211-
return res
212-
213-
x_val = 100 * np.random.random_sample([2, 10]).astype(np.float32)
214-
y_val = 100 * np.random.random_sample([2, 10]).astype(np.float32)
215-
216-
# test fn0
217-
x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
218-
res_ = tf.map_fn(fn0, x, dtype=tf.float32)
219-
_ = tf.identity(res_, name=_TFOUTPUT1)
220-
self._run_test_case([_OUTPUT1], {_INPUT: x_val}, rtol=0)
221-
tf.reset_default_graph()
222-
223-
# test fn1
224-
x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
225-
y = tf.placeholder(tf.float32, shape=y_val.shape, name=_TFINPUT1)
226-
res_ = tf.map_fn(fn1, (x, y), dtype=tf.float32)
227-
_ = tf.identity(res_, name=_TFOUTPUT1)
228-
self._run_test_case([_OUTPUT1], {_INPUT: x_val, _INPUT1: y_val}, rtol=0)
229-
tf.reset_default_graph()
230-
231204
@unittest.skipIf(BACKEND in ["caffe2"], "not supported correctly in caffe2")
232205
@unittest.skipIf(*support_op_conversion_since(7, "multinomial"))
233206
def test_multinomial(self):

tests/test_loops.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
1818

19+
1920
class LoopTests(Tf2OnnxBackendTestBase):
2021

2122
def test_simple_while_loop(self):
@@ -31,7 +32,6 @@ def test_simple_while_loop(self):
3132
output_names_with_port = ["output:0"]
3233
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
3334

34-
3535
def test_simple_while_loop_2(self):
3636
i = tf.placeholder(tf.int32, (), name="input_1")
3737
c = lambda i: tf.logical_and(tf.less(i, 10), tf.greater_equal(i, 3))
@@ -45,7 +45,6 @@ def test_simple_while_loop_2(self):
4545
output_names_with_port = ["output:0"]
4646
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
4747

48-
4948
def test_while_loop_with_ta_write(self):
5049
i = tf.placeholder(tf.int32, (), name="input_1")
5150
output_ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
@@ -68,7 +67,6 @@ def b(i, out_ta):
6867
output_names_with_port = ["output:0", "i:0"]
6968
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
7069

71-
7270
def test_while_loop_with_ta_read(self):
7371
i = tf.placeholder(tf.int32, (), name="input_1")
7472
inputs = tf.placeholder(tf.float32, (10,), name="input_2")
@@ -82,6 +80,7 @@ def test_while_loop_with_ta_read(self):
8280
c = lambda i, *_: tf.logical_and(tf.less(i, 10), i >= 0)
8381
res = tf.constant(0.)
8482
res2 = tf.constant(1.)
83+
8584
def b(i, res, res2):
8685
new_i = tf.add(i, 1)
8786
x = input_ta.read(i)
@@ -113,6 +112,7 @@ def test_while_loop_with_ta_read_reference_outer_input_directly(self):
113112
c = lambda i, *_: tf.logical_and(tf.less(i, 10), i >= 0)
114113
res = tf.constant(0.)
115114
res2 = tf.constant(1.)
115+
116116
def b(i, res, res2):
117117
new_i = tf.add(i, 1)
118118
x = input_ta.read(i)
@@ -132,7 +132,6 @@ def b(i, res, res2):
132132
output_names_with_port = ["i:0", "x:0", "y:0"]
133133
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
134134

135-
136135
def test_while_loop_with_ta_read_and_write(self):
137136
i = tf.placeholder(tf.int32, (), name="input_1")
138137
inputs = tf.placeholder(tf.float32, (10,), name="input_2")
@@ -160,5 +159,42 @@ def b(i, out_ta):
160159
output_names_with_port = ["i:0", "output_ta:0"]
161160
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
162161

162+
def test_map_fn(self):
163+
def fn0(elem):
164+
res = elem + elem * elem
165+
return res
166+
167+
def fn1(elem):
168+
res = elem[0] * elem[1] + elem[0]
169+
return res
170+
171+
x_val = 100 * np.random.random_sample([2, 10]).astype(np.float32)
172+
y_val = 100 * np.random.random_sample([2, 10]).astype(np.float32)
173+
174+
# test fn0
175+
x = tf.placeholder(tf.float32, shape=x_val.shape, name="input_0")
176+
x_ = tf.identity(x)
177+
res_ = tf.map_fn(fn0, x_, dtype=tf.float32)
178+
_ = tf.identity(res_, name="output_0")
179+
feed_dict = {"input_0:0": x_val}
180+
input_names_with_port = ["input_0:0"]
181+
output_names_with_port = ["output_0:0"]
182+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-5)
183+
tf.reset_default_graph()
184+
185+
# test fn1
186+
x = tf.placeholder(tf.float32, shape=x_val.shape, name="input_0")
187+
y = tf.placeholder(tf.float32, shape=y_val.shape, name="input_1")
188+
x_ = tf.identity(x)
189+
y_ = tf.identity(y)
190+
res_ = tf.map_fn(fn1, (x_, y_), dtype=tf.float32)
191+
_ = tf.identity(res_, name="output_0")
192+
feed_dict = {"input_0:0": x_val, "input_1:0": y_val}
193+
input_names_with_port = ["input_0:0", "input_1:0"]
194+
output_names_with_port = ["output_0:0"]
195+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-5)
196+
tf.reset_default_graph()
197+
198+
163199
if __name__ == '__main__':
164200
Tf2OnnxBackendTestBase.trigger(LoopTests)

tf2onnx/shape_inference.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from __future__ import unicode_literals
1111
import logging
1212
from onnx import onnx_pb
13+
from tf2onnx import utils
1314

1415
# pylint: disable=logging-not-lazy,missing-docstring,consider-swap-variables
1516

@@ -232,17 +233,9 @@ def infer_output_shapes_with_partial_inputs(g, node):
232233
return False
233234

234235
# find TensorArrayWrite
235-
tensor_array_consumers = g.find_output_consumers(handle_node.output[0])
236-
tensor_array_write_found = False
237-
for i in tensor_array_consumers:
238-
if tensor_array_write_found:
239-
break
240-
consumer_nodes = g.find_output_consumers(i.output[0])
241-
for j in consumer_nodes:
242-
if i.type == "Enter" and j.type == "TensorArrayWriteV3":
243-
tensor_array_write_node = j
244-
tensor_array_write_found = True
245-
break
236+
tensor_array_write_node = _find_tensorarray_write(g, handle_node)
237+
if not tensor_array_write_node:
238+
return False
246239
# get TensorArray shape from input tensor of the found TensorArrayWrite node
247240
value_node = tensor_array_write_node.inputs[2]
248241
shape = g.get_shape(value_node.output[0])
@@ -347,3 +340,16 @@ def broadcast_shape_inference(shape_0, shape_1):
347340
return None
348341
i -= 1
349342
return new_shape
343+
344+
345+
def _find_tensorarray_write(graph, node):
346+
utils.make_sure(node.type == "TensorArrayV3", "node should be tensorarray")
347+
348+
tensor_array_consumers = graph.find_output_consumers(node.output[0])
349+
for i in tensor_array_consumers:
350+
if i.type == "Enter":
351+
consumer_nodes = graph.find_output_consumers(i.output[0])
352+
for j in consumer_nodes:
353+
if j.type == "TensorArrayWriteV3":
354+
return j
355+
return None

0 commit comments

Comments
 (0)