Skip to content

Commit 16d0c28

Browse files
committed
fix test_graph
1 parent dd3b69a commit 16d0c28

File tree

1 file changed

+19
-16
lines changed

1 file changed

+19
-16
lines changed

tests/test_graph.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import tf2onnx
2121
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2222
from tf2onnx.tfonnx import process_tf_graph
23-
from common import unittest_main
23+
from common import get_test_config, unittest_main
2424

2525
_TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"
2626

@@ -96,6 +96,8 @@ def setUp(self):
9696
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
9797
tf.logging.set_verbosity(tf.logging.WARN)
9898

99+
self.config = get_test_config()
100+
99101
tf2onnx.utils.INTERNAL_NAME = 1
100102
tf.reset_default_graph()
101103
arg = namedtuple("Arg", "input inputs outputs verbose continue_on_error")
@@ -115,7 +117,7 @@ def test_abs(self):
115117
x = tf.placeholder(tf.float32, [2, 3], name="input")
116118
x_ = tf.abs(x)
117119
_ = tf.identity(x_, name="output")
118-
g = process_tf_graph(sess.graph)
120+
g = process_tf_graph(sess.graph, opset=self.config.opset)
119121
self.assertEqual('digraph { input [op_type=Placeholder shape="[2, 3]"]' \
120122
' Abs [op_type=Abs] output [op_type=Identity] input:0 -> Abs Abs:0 -> output }',
121123
onnx_to_graphviz(g))
@@ -127,7 +129,7 @@ def test_randomuniform(self):
127129
x_ = tf.identity(x_, name="output1")
128130
x_ = tf.identity(x_, name="output2")
129131
_ = tf.identity(x_, name="output")
130-
g = process_tf_graph(sess.graph)
132+
g = process_tf_graph(sess.graph, opset=self.config.opset)
131133
self.assertEqual(
132134
'digraph { RandomUniform__2 [op_type=RandomUniform shape="[2, 3]"] output1 [op_type=Identity] '
133135
'output2 [op_type=Identity] output [op_type=Identity] RandomUniform__2:0 -> output1 '
@@ -138,7 +140,7 @@ def test_randomnormal(self):
138140
with tf.Session() as sess:
139141
x_ = tf.random_normal([2, 3], name="rand")
140142
_ = tf.identity(x_, name="output")
141-
g = process_tf_graph(sess.graph)
143+
g = process_tf_graph(sess.graph, opset=self.config.opset)
142144
actual = onnx_to_graphviz(g)
143145
expected = 'digraph { RandomNormal__2 [op_type=RandomNormal shape="[2, 3]"] output [op_type=Identity] ' \
144146
'RandomNormal__2:0 -> output }'
@@ -154,7 +156,7 @@ def test_dropout(self):
154156
x_ = tf.identity(x_, name="output1")
155157
x_ = tf.identity(x_, name="output2")
156158
_ = tf.identity(x_, name="output")
157-
g = process_tf_graph(sess.graph)
159+
g = process_tf_graph(sess.graph, opset=self.config.opset)
158160
actual = onnx_to_graphviz(g)
159161
expected = 'digraph { prob [op_type=Placeholder shape="[]"] input2 [op_type=Placeholder shape="[1, 3]"] ' \
160162
'input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Add] Dropout__3 [op_type=Dropout] ' \
@@ -169,7 +171,7 @@ def test_add(self):
169171
x2 = tf.placeholder(tf.float32, [1, 3], name="input2")
170172
x_ = tf.add(x1, x2)
171173
_ = tf.identity(x_, name="output")
172-
g = process_tf_graph(sess.graph)
174+
g = process_tf_graph(sess.graph, opset=self.config.opset)
173175
self.assertEqual(
174176
'digraph { input2 [op_type=Placeholder shape="[1, 3]"] input1 [op_type=Placeholder shape="[2, 3]"] '
175177
'Add [op_type=Add] output [op_type=Identity] input1:0 -> Add input2:0 -> Add Add:0 -> output }',
@@ -181,7 +183,7 @@ def test_squareddifference(self):
181183
x2 = tf.placeholder(tf.float32, [1, 3], name="input2")
182184
x_ = tf.squared_difference(x1, x2)
183185
_ = tf.identity(x_, name="output")
184-
g = process_tf_graph(sess.graph)
186+
g = process_tf_graph(sess.graph, opset=self.config.opset)
185187
self.assertEqual(
186188
'digraph { input2 [op_type=Placeholder shape="[1, 3]"] input1 [op_type=Placeholder shape="[1, 3]"] '
187189
'SquaredDifference [op_type=Sub] SquaredDifference__2 [op_type=Mul] '
@@ -195,7 +197,7 @@ def test_reducesum(self):
195197
x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
196198
x_ = tf.reduce_sum(x1)
197199
_ = tf.identity(x_, name="output")
198-
g = process_tf_graph(sess.graph)
200+
g = process_tf_graph(sess.graph, opset=self.config.opset)
199201
self.assertEqual(
200202
'digraph { Const [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
201203
'Sum [op_type=ReduceSum] output [op_type=Identity] input1:0 -> Sum Sum:0 -> output }',
@@ -206,7 +208,7 @@ def test_argminmax(self):
206208
x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
207209
x_ = tf.argmin(x1, axis=0)
208210
_ = tf.identity(x_, name="output")
209-
g = process_tf_graph(sess.graph)
211+
g = process_tf_graph(sess.graph, opset=self.config.opset)
210212
self.assertEqual(
211213
'digraph { "ArgMin/dimension" [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
212214
'ArgMin [op_type=ArgMin] output [op_type=Identity] input1:0 -> ArgMin ArgMin:0 -> output }',
@@ -217,7 +219,7 @@ def test_rsqrt(self):
217219
x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
218220
x_ = tf.rsqrt(x1)
219221
_ = tf.identity(x_, name="output")
220-
g = process_tf_graph(sess.graph)
222+
g = process_tf_graph(sess.graph, opset=self.config.opset)
221223
self.assertEqual(
222224
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Rsqrt [op_type=Sqrt] '
223225
'Rsqrt__2 [op_type=Reciprocal] output [op_type=Identity] input1:0 -> Rsqrt '
@@ -229,7 +231,7 @@ def test_relu6(self):
229231
x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
230232
x_ = tf.nn.relu6(x1)
231233
_ = tf.identity(x_, name="output")
232-
g = process_tf_graph(sess.graph)
234+
g = process_tf_graph(sess.graph, opset=self.config.opset)
233235
self.assertEqual(
234236
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Relu6 [op_type=Relu] Relu6__2 [op_type=Clip] '
235237
'output [op_type=Identity] input1:0 -> Relu6 Relu6:0 -> Relu6__2 Relu6__2:0 -> output }',
@@ -257,7 +259,7 @@ def test_conv2d(self):
257259
sess.run(tf.global_variables_initializer())
258260
_ = sess.run(conv, feed_dict={image_: image})
259261

260-
g = process_tf_graph(sess.graph)
262+
g = process_tf_graph(sess.graph, opset=self.config.opset)
261263
self.assertEqual(
262264
'digraph { input1 [op_type=Placeholder shape="[1, 4, 4, 1]"] Conv2D__3 [op_type=Transpose] '
263265
'"kernel/shape" [op_type=Const] kernel__2 [op_type=Cast] k [op_type=Const] '
@@ -272,7 +274,7 @@ def test_squeeze(self):
272274
x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
273275
x_ = tf.squeeze(x1)
274276
_ = tf.identity(x_, name="output")
275-
g = process_tf_graph(sess.graph)
277+
g = process_tf_graph(sess.graph, opset=self.config.opset)
276278
self.assertEqual(
277279
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Squeeze [op_type=Squeeze] '
278280
'output [op_type=Identity] input1:0 -> Squeeze Squeeze:0 -> output }',
@@ -283,7 +285,7 @@ def test_cast(self):
283285
x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
284286
x_ = tf.cast(x1, tf.int32)
285287
_ = tf.identity(x_, name="output")
286-
g = process_tf_graph(sess.graph)
288+
g = process_tf_graph(sess.graph, opset=self.config.opset)
287289
self.assertEqual(
288290
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Cast [op_type=Cast] output [op_type=Identity] '
289291
'input1:0 -> Cast Cast:0 -> output }',
@@ -294,7 +296,7 @@ def test_reshape(self):
294296
x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
295297
x_ = tf.reshape(x1, [3, 2])
296298
_ = tf.identity(x_, name="output")
297-
g = process_tf_graph(sess.graph)
299+
g = process_tf_graph(sess.graph, opset=self.config.opset)
298300
self.assertEqual(
299301
'digraph { "Reshape/shape" [op_type=Const] Reshape__2 [op_type=Cast] '
300302
'input1 [op_type=Placeholder shape="[2, 3]"] Reshape [op_type=Reshape] '
@@ -319,7 +321,7 @@ def rewrite_test(g, ops):
319321
x = tf.placeholder(tf.float32, [2, 3], name="input1")
320322
x_ = tf.add(x, x)
321323
_ = tf.identity(x_, name="output")
322-
g = process_tf_graph(sess.graph, custom_rewriter=[rewrite_test])
324+
g = process_tf_graph(sess.graph, opset=self.config.opset, custom_rewriter=[rewrite_test])
323325
self.assertEqual(
324326
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Mul] '
325327
'output [op_type=Identity] input1:0 -> Add input1:0 -> Add Add:0 -> output }',
@@ -345,6 +347,7 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
345347
_ = tf.identity(x_, name="output")
346348
g = process_tf_graph(sess.graph,
347349
custom_op_handlers={"Print": (print_handler, ["Identity", "mode"])},
350+
opset=self.config.opset,
348351
extra_opset=helper.make_opsetid(_TENSORFLOW_DOMAIN, 1))
349352
self.assertEqual(
350353
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Print [op_type=Identity] '

0 commit comments

Comments
 (0)