
Commit acb92b1

Merge pull request #491 from lucienwang1009/tf_infer_shape
infer shape for tensorflow graph
2 parents 73ac942 + aab7878 commit acb92b1

16 files changed: 783 additions and 271 deletions
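
This merge adds TensorFlow-level shape inference to tf2onnx (module `tf2onnx.shape_inference`), exercised by the new `tests/test_tf_shape_inference.py` shown below. As rough orientation, here is a minimal sketch (not part of this commit) of how that entry point is driven on a frozen GraphDef, mirroring `_run_test_case` in the new test file; `infer_shape_for_graph` and `utils.get_tf_tensor_shape` are the names used in the diff, while the wrapper `report_inferred_shapes` is illustrative only.

# Hedged sketch (not part of this commit): drive the new TF shape inference
# on a frozen GraphDef and print what it could work out for each tensor.
import tensorflow as tf

from tf2onnx import utils
from tf2onnx.shape_inference import infer_shape_for_graph


def report_inferred_shapes(graph_def):
    """Import a frozen GraphDef, run tf2onnx's TF-level shape inference,
    and print the (possibly partial) shape of every output tensor."""
    tf.reset_default_graph()
    tf.import_graph_def(graph_def, name='')
    with tf.Session() as sess:
        inferred_graph = infer_shape_for_graph(sess.graph)
        for op in inferred_graph.get_operations():
            for out in op.outputs:
                # get_tf_tensor_shape returns a list with None for unknown dims.
                print(out.name, utils.get_tf_tensor_shape(out))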

tests/common.py

Lines changed: 1 addition & 5 deletions
@@ -43,7 +43,7 @@
 class TestConfig(object):
     def __init__(self):
         self.platform = sys.platform
-        self.tf_version = self._get_tf_version()
+        self.tf_version = utils.get_tf_version()
         self.opset = int(os.environ.get("TF2ONNX_TEST_OPSET", constants.PREFERRED_OPSET))
         self.target = os.environ.get("TF2ONNX_TEST_TARGET", ",".join(constants.DEFAULT_TARGET)).split(',')
         self.backend = os.environ.get("TF2ONNX_TEST_BACKEND", "onnxruntime")
@@ -67,10 +67,6 @@ def is_caffe2_backend(self):
     def is_debug_mode(self):
         return utils.is_debug_mode()
 
-    def _get_tf_version(self):
-        import tensorflow as tf
-        return LooseVersion(tf.__version__)
-
     def _get_backend_version(self):
         version = None
         if self.backend == "onnxruntime":
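
The private `TestConfig._get_tf_version` helper moves into the shared `utils` module as `utils.get_tf_version()`. Judging from the removed lines, the shared helper presumably wraps `LooseVersion(tf.__version__)`; the sketch below shows that assumption and how a version gate such as the `check_tf_min_version("1.9")` decorator used in `tests/test_loops.py` further down can be built on it. The actual `utils.get_tf_version` implementation may differ.

# Hedged sketch: assumed behaviour of the shared version helper.
from distutils.version import LooseVersion


def get_tf_version():
    # Mirrors the removed TestConfig._get_tf_version; the real
    # utils.get_tf_version may be implemented differently.
    import tensorflow as tf
    return LooseVersion(tf.__version__)


if get_tf_version() >= LooseVersion("1.9"):
    print("tests gated on TF >= 1.9 (e.g. shape_invariants loops) will run")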

tests/test_backend.py

Lines changed: 1 addition & 1 deletion
@@ -1271,7 +1271,7 @@ def test_randomuniform_int(self):
     def test_randomuniform_dyn_shape(self):
         # test for dynamic shape coming from a shape op
         x_val = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
-        x = tf.placeholder(x_val.dtype, name=_TFINPUT)
+        x = tf.placeholder(x_val.dtype, [None, 3], name=_TFINPUT)
         x_ = tf.stack([x, x])
         x_ = tf.identity(x_)
         x_ = tf.shape(x_, name="shape")

tests/test_loops.py

Lines changed: 19 additions & 1 deletion
@@ -12,7 +12,7 @@
 import tensorflow as tf
 
 from backend_test_base import Tf2OnnxBackendTestBase
-from common import unittest_main
+from common import unittest_main, check_tf_min_version
 
 
 # pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -196,6 +196,24 @@ def fn1(elem):
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-5)
         tf.reset_default_graph()
 
+    @check_tf_min_version("1.9")
+    def test_simple_while_loop_var_shape(self):
+        # test for while_loop with variant shape variables
+        # may not meet ONNX Loop spec
+        i = tf.placeholder(tf.int32, (1), name="input_1")
+        const = tf.constant(np.array([2], dtype=np.int32))
+
+        c = lambda i: tf.reduce_all(tf.shape(i) < 10)
+        b = lambda i: tf.concat([i, const], 0)
+        r = tf.while_loop(c, b, [i], shape_invariants=[tf.TensorShape([None])])
+
+        _ = tf.identity(r, name="output")
+        input_names_with_port = ["input_1:0"]
+        feed_dict = {"input_1:0": np.array([0], dtype=np.int32)}
+
+        output_names_with_port = ["output:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
+
 
 if __name__ == '__main__':
     unittest_main()
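
The new test grows its loop variable with `tf.concat`, so its length changes across iterations and must be declared variant through `shape_invariants`; statically, TensorFlow then only knows the result as a rank-1 tensor of unknown length, which is exactly the situation the comment flags as possibly not meeting the ONNX Loop spec. Below is a minimal sketch of that static-shape behaviour, using the same construction as the test (the `(1,)` placeholder shape is written out explicitly here):

# Hedged sketch: static shape of a while_loop variable under shape_invariants.
import numpy as np
import tensorflow as tf

i = tf.placeholder(tf.int32, (1,), name="input_1")
const = tf.constant(np.array([2], dtype=np.int32))

c = lambda i: tf.reduce_all(tf.shape(i) < 10)  # stop once length reaches 10
b = lambda i: tf.concat([i, const], 0)         # body grows the vector each step
r = tf.while_loop(c, b, [i], shape_invariants=[tf.TensorShape([None])])

print(r.get_shape().as_list())  # [None]: rank known, length only known at run time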

tests/test_onnx_shape_inference.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 INPUT3 = "input3"
 
 
-class ShapeInferenceTests(Tf2OnnxBackendTestBase):
+class ONNXShapeInferenceTests(Tf2OnnxBackendTestBase):
     """
     Test shape inference, it's just a subset of all cases that can be inferred shape.
     For more information, please refer to onnx shape inference test:

tests/test_tf_shape_inference.py

Lines changed: 319 additions & 0 deletions
@@ -0,0 +1,319 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+"""Unit Tests for Tensorflow shape inference."""
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import os
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.ops import variables as variables_lib
+from tensorflow.python.ops import init_ops
+
+from backend_test_base import Tf2OnnxBackendTestBase
+from common import *  # pylint: disable=wildcard-import, unused-wildcard-import
+from tf2onnx import utils
+from tf2onnx.tfonnx import tf_optimize
+from tf2onnx.shape_inference import infer_shape_for_graph
+
+# pylint: disable=missing-docstring
+
+
+class TFShapeInferenceTests(Tf2OnnxBackendTestBase):
+    def _run_test_case(self, input_names_with_port, output_names_with_port):
+        graph_def = None
+        with tf.Session() as sess:
+            # freeze graph
+            origin_graph = sess.graph
+            variables_lib.global_variables_initializer().run()
+            output_name_without_port = [n.split(':')[0] for n in output_names_with_port]
+            graph_def = tf.graph_util.convert_variables_to_constants(
+                sess, sess.graph_def,
+                output_name_without_port
+            )
+
+        tf.reset_default_graph()
+        tf.import_graph_def(graph_def, name='')
+
+        # optimize graph
+        graph_def = tf_optimize(input_names_with_port, output_names_with_port,
+                                sess.graph_def, True)
+
+        with tf.Session() as sess:
+            if self.config.is_debug_mode:
+                if not os.path.exists(self.test_data_directory):
+                    os.makedirs(self.test_data_directory)
+                model_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
+                utils.save_protobuf(model_path, graph_def)
+                self.logger.debug("created file %s", model_path)
+
+        tf.reset_default_graph()
+        tf.import_graph_def(graph_def, name='')
+
+        with tf.Session() as sess:
+            inferred_graph = infer_shape_for_graph(sess.graph)
+            # compare each operation
+            for op in origin_graph.get_operations():
+                inferred_op = None
+                try:
+                    inferred_op = inferred_graph.get_operation_by_name(op.name)
+                except KeyError:
+                    continue
+                self._compare_shape_for_op(op, inferred_op)
+
+    def _compare_shape_for_op(self, op1, op2):
+        """Align outputs of op2 to op1."""
+        for out1, out2 in zip(op1.outputs, op2.outputs):
+            expected_shape = utils.get_tf_tensor_shape(out1)
+            if out1 is not None:
+                actual_shape = utils.get_tf_tensor_shape(out2)
+                self.assertTrue(utils.are_shapes_compatible(expected_shape, actual_shape))
+
+    def test_while_loop_with_ta_read_and_write(self):
+        i = tf.placeholder(tf.int32, (), name="input_1")
+        inputs = tf.placeholder(tf.float32, (10,), name="input_2")
+
+        inputs_2 = tf.identity(inputs)
+        input_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True).unstack(inputs_2)
+        output_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
+
+        c = lambda i, *_: tf.logical_and(tf.less(i, 10), i >= 0)
+
+        def b(i, out_ta):
+            new_i = tf.add(i, 1)
+            x = input_ta.read(i)
+            x = x + 3
+            out_ta_new = out_ta.write(i, x)
+            return new_i, out_ta_new
+
+        i_final, out_final = tf.while_loop(c, b, [i, output_ta])
+        _ = tf.identity(i_final, name="i")
+        _ = tf.identity(out_final.stack(), name="output_ta")
+        input_names_with_port = ["input_1:0", "input_2:0"]
+
+        output_names_with_port = ["i:0", "output_ta:0"]
+        self._run_test_case(input_names_with_port, output_names_with_port)
+
+    def test_map_fn(self):
+        def fn0(elem):
+            res = elem + elem * elem
+            return res
+
+        def fn1(elem):
+            res = elem[0] * elem[1] + elem[0]
+            return res
+
+        x_val = 100 * np.random.random_sample([2, 10]).astype(np.float32)
+        y_val = 100 * np.random.random_sample([2, 10]).astype(np.float32)
+
+        # test fn0
+        x = tf.placeholder(tf.float32, shape=x_val.shape, name="input_0")
+        x_ = tf.identity(x)
+        res_ = tf.map_fn(fn0, x_, dtype=tf.float32)
+        _ = tf.identity(res_, name="output_0")
+        input_names_with_port = ["input_0:0"]
+        output_names_with_port = ["output_0:0"]
+        self._run_test_case(input_names_with_port, output_names_with_port)
+        tf.reset_default_graph()
+
+        # test fn1
+        x = tf.placeholder(tf.float32, shape=x_val.shape, name="input_0")
+        y = tf.placeholder(tf.float32, shape=y_val.shape, name="input_1")
+        x_ = tf.identity(x)
+        y_ = tf.identity(y)
+        res_ = tf.map_fn(fn1, (x_, y_), dtype=tf.float32)
+        _ = tf.identity(res_, name="output_0")
+        input_names_with_port = ["input_0:0", "input_1:0"]
+        output_names_with_port = ["output_0:0"]
+        self._run_test_case(input_names_with_port, output_names_with_port)
+
+    def test_bidrectional_attention_wrapper_lstm_encoder(self):
+        size = 30
+        time_step = 3
+        input_size = 4
+        attn_size = size
+        batch_size = 9
+
+        # shape [batch size, time step, size]
+        # attention_state: usually the output of an RNN encoder.
+        # This tensor should be shaped `[batch_size, max_time, ...]`
+        encoder_time_step = time_step
+        encoder_x_val = np.random.randn(encoder_time_step, input_size).astype('f')
+        encoder_x_val = np.stack([encoder_x_val] * batch_size)
+        encoder_x = tf.placeholder(tf.float32, encoder_x_val.shape, name="input_1")
+        encoder_cell = tf.nn.rnn_cell.LSTMCell(size)
+        attention_states, _ = tf.nn.dynamic_rnn(encoder_cell, encoder_x, dtype=tf.float32)
+        # [9, 3, 30], [9, 30]
+        attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(attn_size,
+                                                                    attention_states)
+
+        match_input_fn = lambda curr_input, state: tf.concat([curr_input, state], axis=-1)
+        cell = tf.nn.rnn_cell.LSTMCell(size)
+        match_cell_fw = tf.contrib.seq2seq.AttentionWrapper(cell,
+                                                            attention_mechanism,
+                                                            attention_layer_size=attn_size,
+                                                            cell_input_fn=match_input_fn,
+                                                            output_attention=False)
+        match_cell_bk = tf.contrib.seq2seq.AttentionWrapper(cell,
+                                                            attention_mechanism,
+                                                            attention_layer_size=attn_size,
+                                                            cell_input_fn=match_input_fn,
+                                                            output_attention=False)
+
+        decoder_time_step = 6
+        decoder_x_val = np.random.randn(decoder_time_step, batch_size, input_size).astype('f')
+
+        decoder_x = tf.placeholder(tf.float32, decoder_x_val.shape, name="input_2")
+        seq_length = tf.placeholder(tf.int32, (batch_size), name="input_3")
+        (match_output_fw, match_output_bk), (match_state_fw, match_state_bk) = \
+            tf.nn.bidirectional_dynamic_rnn(cell_fw=match_cell_fw,
+                                            cell_bw=match_cell_bk,
+                                            inputs=decoder_x,
+                                            sequence_length=tf.identity(seq_length),
+                                            dtype=tf.float32,
+                                            time_major=True)
+
+        matched_output = tf.concat([match_output_fw, match_output_bk], axis=-1)
+        matched_state = tf.concat([match_state_fw.cell_state, match_state_bk.cell_state], -1)
+
+        _ = tf.identity(matched_output, name="output_0")
+        _ = tf.identity(matched_state, name="final_state")
+
+        input_names_with_port = ["input_1:0", "input_2:0", "input_3:0"]
+        output_names_with_port = ["output_0:0", "final_state:0"]
+        self._run_test_case(input_names_with_port, output_names_with_port)
+
+    def test_dynamic_decode_normal_stop(self):
+        batch_size = 2
+        num_units = 4
+        vocab_size = 5
+        embedding_size = 3
+        go_token = 0
+        end_token = 1
+
+        embedding = tf.constant(np.ones([vocab_size, embedding_size], dtype=np.float32))
+        state_val = np.reshape([np.ones([num_units], dtype=np.float32) * i for i in range(batch_size)],
+                               [batch_size, num_units])
+        encoder_state = tf.nn.rnn_cell.LSTMStateTuple(state_val, state_val)
+
+        cell_initializer = init_ops.constant_initializer(
+            np.array([[-0.9592235, 0.42451382, 0.7437744, -0.54485345, -0.80763197,
+                       0.19663906, -0.22738314, 0.7762785, 0.7464578, 0.27227187,
+                       0.7661047, 0.3596425, -0.8528242, -0.89316916, -0.48946142,
+                       0.87882376],
+                      [0.86586094, -0.75018406, 0.25992537, -0.69368935, 0.2515502,
+                       -0.26379275, 0.8954313, 0.5759742, -0.7753072, -0.4388857,
+                       0.95751476, -0.82085776, -0.9467752, -0.37055635, -0.18570113,
+                       -0.86504984],
+                      [0.02305841, 0.3850248, 0.893692, -0.6866486, -0.83703446,
+                       -0.9828961, 0.3989377, -0.59993076, 0.5330808, 0.6916566,
+                       0.98468065, -0.6047034, 0.10823512, 0.34599304, -0.7834821,
+                       -0.7852347],
+                      [0.81643987, 0.31507468, -0.51369476, -0.12273741, 0.9701307,
+                       -0.79669356, -0.34496522, -0.88750815, -0.17995334, 0.34707904,
+                       -0.09201193, 0.5363934, -0.87229705, -0.5073328, -0.95894027,
+                       0.5481839],
+                      [-0.84093595, -0.2341497, -0.86047816, 0.43370056, -0.39073753,
+                       0.37730122, 0.48026466, 0.3004985, -0.60727096, 0.9043884,
+                       -0.37619448, 0.22490788, -0.03739262, 0.61672115, 0.478899,
+                       -0.40780973],
+                      [0.31202435, -0.22045255, -0.6087918, 0.95115066, 0.00199413,
+                       -0.688287, -0.1103518, 0.4169519, 0.7913246, -0.9844644,
+                       -0.6193857, 0.38659644, -0.4726901, -0.44781208, -0.5174744,
+                       -0.605911],
+                      [0.66771054, 0.34912825, 0.22297978, -0.4990945, 0.24057317,
+                       -0.5540829, 0.92277217, 0.74939895, -0.35278273, -0.21587133,
+                       -0.28613377, -0.8794241, -0.40119147, 0.67175174, -0.22741508,
+                       0.37898326]], dtype=np.float32))
+        dense_initializer = init_ops.constant_initializer(
+            np.array([[0.56177187, -0.6233454, 0.73997784, 0.35032558, 0.6479795],
+                      [0.6831174, -0.34233975, 0.39330363, 0.45177555, -0.49649096],
+                      [-0.98890066, 0.6175642, 0.09800482, -0.6721206, 0.48805737],
+                      [0.19671416, 0.2623148, 0.742548, 0.13555217, 0.56009054]], dtype=np.float32))
+
+        cell = tf.nn.rnn_cell.LSTMCell(
+            num_units=num_units,
+            initializer=cell_initializer,
+            state_is_tuple=True)
+
+        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
+            embedding=embedding,
+            start_tokens=tf.tile([go_token], [batch_size]),
+            end_token=end_token)
+
+        output_layer = tf.layers.Dense(vocab_size, kernel_initializer=dense_initializer)
+        decoder = tf.contrib.seq2seq.BasicDecoder(
+            cell=cell,
+            helper=helper,
+            initial_state=encoder_state,
+            output_layer=output_layer)
+
+        outputs, state, sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
+            decoder=decoder,
+            maximum_iterations=6)
+
+        _ = tf.identity(outputs.rnn_output, name="rnn_output")
+        _ = tf.identity(outputs.sample_id, name="sample_id")
+        _ = tf.identity(state, name="state")
+        _ = tf.identity(sequence_lengths, name="sequence_lengths")
+
+        output_names_with_port = [
+            "rnn_output:0",
+            # "sample_id:0",  # incomplete type support for Transpose on onnxruntime 0.2.1
+            "state:0",
+        ]
+
+        self._run_test_case([], output_names_with_port)
+
+    def test_while_loop_in_cond(self):
+        x_val = np.array([1, 2, 3], dtype=np.float32)
+        y_val = np.array([4, 5, 6], dtype=np.float32)
+        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+        y = tf.placeholder(tf.float32, y_val.shape, name="input_2")
+
+        def cond_graph():
+            b = tf.constant(np.array([0], dtype=np.int32), dtype=tf.int32)
+            # while_loop
+            c = lambda y: tf.reduce_any(tf.less(y, 10))
+            b = lambda i: tf.add(y, 1)
+            return tf.while_loop(c, b, [y])
+
+        res = tf.cond(x[0] < y[0], lambda: x, cond_graph, name="test_cond")
+        _ = tf.identity(res, name="output")
+
+        input_names_with_port = ["input_1:0", "input_2:0"]
+        output_names_with_port = ["output:0"]
+        self._run_test_case(input_names_with_port, output_names_with_port)
+
+    def test_cond_in_while_loop(self):
+        i = tf.placeholder(tf.int32, (), name="input_1")
+        inputs = tf.placeholder(tf.float32, (10,), name="input_2")
+
+        inputs_2 = tf.identity(inputs)
+        input_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True).unstack(inputs_2)
+        output_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
+
+        c = lambda i, *_: tf.logical_and(tf.less(i, 10), i >= 0)
+
+        def b(i, out_ta):
+            new_i = tf.add(i, 1)
+            x = input_ta.read(i)
+            x = tf.cond(x > 0, lambda: x - 1, lambda: x + 3)
+            out_ta_new = out_ta.write(i, x)
+            return new_i, out_ta_new
+
+        i_final, out_final = tf.while_loop(c, b, [i, output_ta])
+        _ = tf.identity(i_final, name="i")
+        _ = tf.identity(out_final.stack(), name="output_ta")
+        input_names_with_port = ["input_1:0", "input_2:0"]
+
+        output_names_with_port = ["i:0", "output_ta:0"]
+        self._run_test_case(input_names_with_port, output_names_with_port)
+
+
+if __name__ == "__main__":
+    unittest_main()
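
`_compare_shape_for_op` above accepts an inferred shape whenever `utils.are_shapes_compatible` says it does not contradict the shape TensorFlow knew in the original graph. The helper below, named `_shapes_compatible_sketch` to make clear it is not the tf2onnx implementation, illustrates one plausible reading of that check: unknown shapes or unknown dimensions match anything, while known dimensions must agree.

# Hedged sketch: a plausible shape-compatibility rule; tf2onnx's actual
# utils.are_shapes_compatible may be stricter or looser than this.
def _shapes_compatible_sketch(expected, actual):
    if expected is None or actual is None:
        return True   # a fully unknown shape is compatible with anything
    if len(expected) != len(actual):
        return False  # rank mismatch can never be compatible
    for dim_e, dim_a in zip(expected, actual):
        if dim_e is not None and dim_a is not None and dim_e != dim_a:
            return False  # both dimensions known but different
    return True


assert _shapes_compatible_sketch([None, 3], [2, 3])
assert not _shapes_compatible_sketch([2, 3], [3, 3])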
