20
20
import tf2onnx
21
21
from tf2onnx .graph_matcher import OpTypePattern , GraphMatcher
22
22
from tf2onnx .tfonnx import process_tf_graph
23
- from common import unittest_main
23
+ from common import get_test_config , unittest_main
24
24
25
25
_TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"
26
26
@@ -96,6 +96,8 @@ def setUp(self):
96
96
os .environ ['TF_CPP_MIN_LOG_LEVEL' ] = '3'
97
97
tf .logging .set_verbosity (tf .logging .WARN )
98
98
99
+ self .config = get_test_config ()
100
+
99
101
tf2onnx .utils .INTERNAL_NAME = 1
100
102
tf .reset_default_graph ()
101
103
arg = namedtuple ("Arg" , "input inputs outputs verbose continue_on_error" )
@@ -115,7 +117,7 @@ def test_abs(self):
115
117
x = tf .placeholder (tf .float32 , [2 , 3 ], name = "input" )
116
118
x_ = tf .abs (x )
117
119
_ = tf .identity (x_ , name = "output" )
118
- g = process_tf_graph (sess .graph )
120
+ g = process_tf_graph (sess .graph , opset = self . config . opset )
119
121
self .assertEqual ('digraph { input [op_type=Placeholder shape="[2, 3]"]' \
120
122
' Abs [op_type=Abs] output [op_type=Identity] input:0 -> Abs Abs:0 -> output }' ,
121
123
onnx_to_graphviz (g ))
@@ -127,7 +129,7 @@ def test_randomuniform(self):
127
129
x_ = tf .identity (x_ , name = "output1" )
128
130
x_ = tf .identity (x_ , name = "output2" )
129
131
_ = tf .identity (x_ , name = "output" )
130
- g = process_tf_graph (sess .graph )
132
+ g = process_tf_graph (sess .graph , opset = self . config . opset )
131
133
self .assertEqual (
132
134
'digraph { RandomUniform__2 [op_type=RandomUniform shape="[2, 3]"] output1 [op_type=Identity] '
133
135
'output2 [op_type=Identity] output [op_type=Identity] RandomUniform__2:0 -> output1 '
@@ -138,7 +140,7 @@ def test_randomnormal(self):
138
140
with tf .Session () as sess :
139
141
x_ = tf .random_normal ([2 , 3 ], name = "rand" )
140
142
_ = tf .identity (x_ , name = "output" )
141
- g = process_tf_graph (sess .graph )
143
+ g = process_tf_graph (sess .graph , opset = self . config . opset )
142
144
actual = onnx_to_graphviz (g )
143
145
expected = 'digraph { RandomNormal__2 [op_type=RandomNormal shape="[2, 3]"] output [op_type=Identity] ' \
144
146
'RandomNormal__2:0 -> output }'
@@ -154,7 +156,7 @@ def test_dropout(self):
154
156
x_ = tf .identity (x_ , name = "output1" )
155
157
x_ = tf .identity (x_ , name = "output2" )
156
158
_ = tf .identity (x_ , name = "output" )
157
- g = process_tf_graph (sess .graph )
159
+ g = process_tf_graph (sess .graph , opset = self . config . opset )
158
160
actual = onnx_to_graphviz (g )
159
161
expected = 'digraph { prob [op_type=Placeholder shape="[]"] input2 [op_type=Placeholder shape="[1, 3]"] ' \
160
162
'input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Add] Dropout__3 [op_type=Dropout] ' \
@@ -169,7 +171,7 @@ def test_add(self):
169
171
x2 = tf .placeholder (tf .float32 , [1 , 3 ], name = "input2" )
170
172
x_ = tf .add (x1 , x2 )
171
173
_ = tf .identity (x_ , name = "output" )
172
- g = process_tf_graph (sess .graph )
174
+ g = process_tf_graph (sess .graph , opset = self . config . opset )
173
175
self .assertEqual (
174
176
'digraph { input2 [op_type=Placeholder shape="[1, 3]"] input1 [op_type=Placeholder shape="[2, 3]"] '
175
177
'Add [op_type=Add] output [op_type=Identity] input1:0 -> Add input2:0 -> Add Add:0 -> output }' ,
@@ -181,7 +183,7 @@ def test_squareddifference(self):
181
183
x2 = tf .placeholder (tf .float32 , [1 , 3 ], name = "input2" )
182
184
x_ = tf .squared_difference (x1 , x2 )
183
185
_ = tf .identity (x_ , name = "output" )
184
- g = process_tf_graph (sess .graph )
186
+ g = process_tf_graph (sess .graph , opset = self . config . opset )
185
187
self .assertEqual (
186
188
'digraph { input2 [op_type=Placeholder shape="[1, 3]"] input1 [op_type=Placeholder shape="[1, 3]"] '
187
189
'SquaredDifference [op_type=Sub] SquaredDifference__2 [op_type=Mul] '
@@ -195,7 +197,7 @@ def test_reducesum(self):
195
197
x1 = tf .placeholder (tf .float32 , [2 , 3 ], name = "input1" )
196
198
x_ = tf .reduce_sum (x1 )
197
199
_ = tf .identity (x_ , name = "output" )
198
- g = process_tf_graph (sess .graph )
200
+ g = process_tf_graph (sess .graph , opset = self . config . opset )
199
201
self .assertEqual (
200
202
'digraph { Const [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
201
203
'Sum [op_type=ReduceSum] output [op_type=Identity] input1:0 -> Sum Sum:0 -> output }' ,
@@ -206,7 +208,7 @@ def test_argminmax(self):
206
208
x1 = tf .placeholder (tf .float32 , [2 , 3 ], name = "input1" )
207
209
x_ = tf .argmin (x1 , axis = 0 )
208
210
_ = tf .identity (x_ , name = "output" )
209
- g = process_tf_graph (sess .graph )
211
+ g = process_tf_graph (sess .graph , opset = self . config . opset )
210
212
self .assertEqual (
211
213
'digraph { "ArgMin/dimension" [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
212
214
'ArgMin [op_type=ArgMin] output [op_type=Identity] input1:0 -> ArgMin ArgMin:0 -> output }' ,
@@ -217,7 +219,7 @@ def test_rsqrt(self):
217
219
x1 = tf .placeholder (tf .float32 , [2 , 3 ], name = "input1" )
218
220
x_ = tf .rsqrt (x1 )
219
221
_ = tf .identity (x_ , name = "output" )
220
- g = process_tf_graph (sess .graph )
222
+ g = process_tf_graph (sess .graph , opset = self . config . opset )
221
223
self .assertEqual (
222
224
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Rsqrt [op_type=Sqrt] '
223
225
'Rsqrt__2 [op_type=Reciprocal] output [op_type=Identity] input1:0 -> Rsqrt '
@@ -229,7 +231,7 @@ def test_relu6(self):
229
231
x1 = tf .placeholder (tf .float32 , [2 , 3 ], name = "input1" )
230
232
x_ = tf .nn .relu6 (x1 )
231
233
_ = tf .identity (x_ , name = "output" )
232
- g = process_tf_graph (sess .graph )
234
+ g = process_tf_graph (sess .graph , opset = self . config . opset )
233
235
self .assertEqual (
234
236
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Relu6 [op_type=Relu] Relu6__2 [op_type=Clip] '
235
237
'output [op_type=Identity] input1:0 -> Relu6 Relu6:0 -> Relu6__2 Relu6__2:0 -> output }' ,
@@ -257,7 +259,7 @@ def test_conv2d(self):
257
259
sess .run (tf .global_variables_initializer ())
258
260
_ = sess .run (conv , feed_dict = {image_ : image })
259
261
260
- g = process_tf_graph (sess .graph )
262
+ g = process_tf_graph (sess .graph , opset = self . config . opset )
261
263
self .assertEqual (
262
264
'digraph { input1 [op_type=Placeholder shape="[1, 4, 4, 1]"] Conv2D__3 [op_type=Transpose] '
263
265
'"kernel/shape" [op_type=Const] kernel__2 [op_type=Cast] k [op_type=Const] '
@@ -272,7 +274,7 @@ def test_squeeze(self):
272
274
x1 = tf .placeholder (tf .float32 , [2 , 3 ], name = "input1" )
273
275
x_ = tf .squeeze (x1 )
274
276
_ = tf .identity (x_ , name = "output" )
275
- g = process_tf_graph (sess .graph )
277
+ g = process_tf_graph (sess .graph , opset = self . config . opset )
276
278
self .assertEqual (
277
279
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Squeeze [op_type=Squeeze] '
278
280
'output [op_type=Identity] input1:0 -> Squeeze Squeeze:0 -> output }' ,
@@ -283,7 +285,7 @@ def test_cast(self):
283
285
x1 = tf .placeholder (tf .float32 , [2 , 3 ], name = "input1" )
284
286
x_ = tf .cast (x1 , tf .int32 )
285
287
_ = tf .identity (x_ , name = "output" )
286
- g = process_tf_graph (sess .graph )
288
+ g = process_tf_graph (sess .graph , opset = self . config . opset )
287
289
self .assertEqual (
288
290
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Cast [op_type=Cast] output [op_type=Identity] '
289
291
'input1:0 -> Cast Cast:0 -> output }' ,
@@ -294,7 +296,7 @@ def test_reshape(self):
294
296
x1 = tf .placeholder (tf .float32 , [2 , 3 ], name = "input1" )
295
297
x_ = tf .reshape (x1 , [3 , 2 ])
296
298
_ = tf .identity (x_ , name = "output" )
297
- g = process_tf_graph (sess .graph )
299
+ g = process_tf_graph (sess .graph , opset = self . config . opset )
298
300
self .assertEqual (
299
301
'digraph { "Reshape/shape" [op_type=Const] Reshape__2 [op_type=Cast] '
300
302
'input1 [op_type=Placeholder shape="[2, 3]"] Reshape [op_type=Reshape] '
@@ -319,7 +321,7 @@ def rewrite_test(g, ops):
319
321
x = tf .placeholder (tf .float32 , [2 , 3 ], name = "input1" )
320
322
x_ = tf .add (x , x )
321
323
_ = tf .identity (x_ , name = "output" )
322
- g = process_tf_graph (sess .graph , custom_rewriter = [rewrite_test ])
324
+ g = process_tf_graph (sess .graph , opset = self . config . opset , custom_rewriter = [rewrite_test ])
323
325
self .assertEqual (
324
326
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Mul] '
325
327
'output [op_type=Identity] input1:0 -> Add input1:0 -> Add Add:0 -> output }' ,
@@ -345,6 +347,7 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
345
347
_ = tf .identity (x_ , name = "output" )
346
348
g = process_tf_graph (sess .graph ,
347
349
custom_op_handlers = {"Print" : (print_handler , ["Identity" , "mode" ])},
350
+ opset = self .config .opset ,
348
351
extra_opset = helper .make_opsetid (_TENSORFLOW_DOMAIN , 1 ))
349
352
self .assertEqual (
350
353
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Print [op_type=Identity] '
0 commit comments