@@ -115,7 +115,8 @@ def test_abs(self):
115
115
x_ = tf .abs (x )
116
116
_ = tf .identity (x_ , name = "output" )
117
117
g = process_tf_graph (sess .graph )
118
- self .assertEqual ('digraph { Abs [op_type=Abs] output [op_type=Identity] input:0 -> Abs Abs:0 -> output }' ,
118
+ self .assertEqual ('digraph { input [op_type=Placeholder shape="[2, 3]"]' \
119
+ ' Abs [op_type=Abs] output [op_type=Identity] input:0 -> Abs Abs:0 -> output }' ,
119
120
onnx_to_graphviz (g ))
120
121
121
122
def test_randomuniform (self ):
@@ -154,9 +155,11 @@ def test_dropout(self):
154
155
_ = tf .identity (x_ , name = "output" )
155
156
g = process_tf_graph (sess .graph )
156
157
actual = onnx_to_graphviz (g )
157
- expected = 'digraph { Add [op_type=Add] Dropout__3 [op_type=Dropout] output1 [op_type=Identity] ' \
158
- 'output2 [op_type=Identity] output [op_type=Identity] input1:0 -> Add input2:0 -> ' \
159
- 'Add Add:0 -> Dropout__3 Dropout__3:0 -> output1 output1:0 -> output2 output2:0 -> output }'
158
+ expected = 'digraph { prob [op_type=Placeholder shape="[]"] input2 [op_type=Placeholder shape="[1, 3]"] ' \
159
+ 'input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Add] Dropout__3 [op_type=Dropout] ' \
160
+ 'output1 [op_type=Identity] output2 [op_type=Identity] output [op_type=Identity] ' \
161
+ 'input1:0 -> Add input2:0 -> Add Add:0 -> Dropout__3 Dropout__3:0 -> output1 ' \
162
+ 'output1:0 -> output2 output2:0 -> output }'
160
163
self .assertEqual (expected , actual )
161
164
162
165
def test_add (self ):
@@ -167,8 +170,8 @@ def test_add(self):
167
170
_ = tf .identity (x_ , name = "output" )
168
171
g = process_tf_graph (sess .graph )
169
172
self .assertEqual (
170
- 'digraph { Add [op_type=Add] output [op_type=Identity] input1:0 -> Add input2:0 -> '
171
- 'Add Add:0 -> output }' ,
173
+ 'digraph { input2 [op_type=Placeholder shape="[1, 3]"] input1 [op_type=Placeholder shape="[2, 3]"] '
174
+ 'Add [op_type=Add] output [op_type=Identity] input1:0 -> Add input2:0 -> Add Add:0 -> output }' ,
172
175
onnx_to_graphviz (g ))
173
176
174
177
def test_squareddifference (self ):
@@ -179,7 +182,8 @@ def test_squareddifference(self):
179
182
_ = tf .identity (x_ , name = "output" )
180
183
g = process_tf_graph (sess .graph )
181
184
self .assertEqual (
182
- 'digraph { SquaredDifference [op_type=Sub] SquaredDifference__2 [op_type=Mul] '
185
+ 'digraph { input2 [op_type=Placeholder shape="[1, 3]"] input1 [op_type=Placeholder shape="[1, 3]"] '
186
+ 'SquaredDifference [op_type=Sub] SquaredDifference__2 [op_type=Mul] '
183
187
'output [op_type=Identity] input1:0 -> SquaredDifference input2:0 -> SquaredDifference '
184
188
'SquaredDifference:0 -> SquaredDifference__2 SquaredDifference:0 -> SquaredDifference__2 '
185
189
'SquaredDifference__2:0 -> output }' ,
@@ -192,7 +196,8 @@ def test_reducesum(self):
192
196
_ = tf .identity (x_ , name = "output" )
193
197
g = process_tf_graph (sess .graph )
194
198
self .assertEqual (
195
- 'digraph { Sum [op_type=ReduceSum] output [op_type=Identity] input1:0 -> Sum Sum:0 -> output }' ,
199
+ 'digraph { Const [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
200
+ 'Sum [op_type=ReduceSum] output [op_type=Identity] input1:0 -> Sum Sum:0 -> output }' ,
196
201
onnx_to_graphviz (g ))
197
202
198
203
def test_argminmax (self ):
@@ -202,7 +207,8 @@ def test_argminmax(self):
202
207
_ = tf .identity (x_ , name = "output" )
203
208
g = process_tf_graph (sess .graph )
204
209
self .assertEqual (
205
- 'digraph { ArgMin [op_type=ArgMin] output [op_type=Identity] input1:0 -> ArgMin ArgMin:0 -> output }' ,
210
+ 'digraph { "ArgMin/dimension" [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] ' \
211
+ 'ArgMin [op_type=ArgMin] output [op_type=Identity] input1:0 -> ArgMin ArgMin:0 -> output }' ,
206
212
onnx_to_graphviz (g ))
207
213
208
214
def test_rsqrt (self ):
@@ -212,8 +218,9 @@ def test_rsqrt(self):
212
218
_ = tf .identity (x_ , name = "output" )
213
219
g = process_tf_graph (sess .graph )
214
220
self .assertEqual (
215
- 'digraph { Rsqrt [op_type=Sqrt] Rsqrt__2 [op_type=Reciprocal] output [op_type=Identity] '
216
- 'input1:0 -> Rsqrt Rsqrt:0 -> Rsqrt__2 Rsqrt__2:0 -> output }' ,
221
+ 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Rsqrt [op_type=Sqrt] '
222
+ 'Rsqrt__2 [op_type=Reciprocal] output [op_type=Identity] input1:0 -> Rsqrt '
223
+ 'Rsqrt:0 -> Rsqrt__2 Rsqrt__2:0 -> output }' ,
217
224
onnx_to_graphviz (g ))
218
225
219
226
def test_relu6 (self ):
@@ -223,7 +230,9 @@ def test_relu6(self):
223
230
_ = tf .identity (x_ , name = "output" )
224
231
g = process_tf_graph (sess .graph )
225
232
self .assertEqual (
226
- 'digraph { Relu6 [op_type=Max] Relu6__4 [op_type=Min] output [op_type=Identity] input1:0 -> Relu6 '
233
+ 'digraph { Relu6__3 [op_type=Const] Relu6__2 [op_type=Const] '
234
+ 'input1 [op_type=Placeholder shape="[2, 3]"] Relu6 [op_type=Max] '
235
+ 'Relu6__4 [op_type=Min] output [op_type=Identity] input1:0 -> Relu6 '
227
236
'Relu6__2 -> Relu6 Relu6:0 -> Relu6__4 Relu6__3 -> Relu6__4 Relu6__4:0 -> output }' ,
228
237
onnx_to_graphviz (g ))
229
238
@@ -251,10 +260,12 @@ def test_conv2d(self):
251
260
252
261
g = process_tf_graph (sess .graph )
253
262
self .assertEqual (
254
- 'digraph { Conv2D__2 [op_type=Transpose] kernel [op_type=Reshape] Conv2D__3 [op_type=Transpose] '
255
- 'Conv2D [op_type=Conv] Conv2D__4 [op_type=Transpose] output [op_type=Identity] '
256
- 'input1:0 -> Conv2D__2 k:0 -> kernel "kernel/shape":0 -> kernel kernel:0 -> Conv2D__3 '
257
- 'Conv2D__2:0 -> Conv2D Conv2D__3:0 -> Conv2D Conv2D:0 -> Conv2D__4 Conv2D__4:0 -> output }' ,
263
+ 'digraph { input1 [op_type=Placeholder shape="[1, 4, 4, 1]"] Conv2D__2 [op_type=Transpose] '
264
+ '"kernel/shape" [op_type=Const] k [op_type=Const] kernel [op_type=Reshape] '
265
+ 'Conv2D__3 [op_type=Transpose] Conv2D [op_type=Conv] Conv2D__4 [op_type=Transpose] '
266
+ 'output [op_type=Identity] input1:0 -> Conv2D__2 k:0 -> kernel "kernel/shape":0 -> kernel '
267
+ 'kernel:0 -> Conv2D__3 Conv2D__2:0 -> Conv2D Conv2D__3:0 -> Conv2D '
268
+ 'Conv2D:0 -> Conv2D__4 Conv2D__4:0 -> output }' ,
258
269
onnx_to_graphviz (g ))
259
270
260
271
def test_squeeze (self ):
@@ -264,8 +275,8 @@ def test_squeeze(self):
264
275
_ = tf .identity (x_ , name = "output" )
265
276
g = process_tf_graph (sess .graph )
266
277
self .assertEqual (
267
- 'digraph { Squeeze [op_type=Squeeze] output [op_type=Identity] input1:0 -> Squeeze '
268
- 'Squeeze:0 -> output }' ,
278
+ 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Squeeze [op_type=Squeeze] ' \
279
+ 'output [op_type=Identity] input1:0 -> Squeeze Squeeze:0 -> output }' ,
269
280
onnx_to_graphviz (g ))
270
281
271
282
def test_cast (self ):
@@ -275,7 +286,8 @@ def test_cast(self):
275
286
_ = tf .identity (x_ , name = "output" )
276
287
g = process_tf_graph (sess .graph )
277
288
self .assertEqual (
278
- 'digraph { Cast [op_type=Cast] output [op_type=Identity] input1:0 -> Cast Cast:0 -> output }' ,
289
+ 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Cast [op_type=Cast] output [op_type=Identity] ' \
290
+ 'input1:0 -> Cast Cast:0 -> output }' ,
279
291
onnx_to_graphviz (g ))
280
292
281
293
def test_reshape (self ):
@@ -285,7 +297,8 @@ def test_reshape(self):
285
297
_ = tf .identity (x_ , name = "output" )
286
298
g = process_tf_graph (sess .graph )
287
299
self .assertEqual (
288
- 'digraph { Reshape [op_type=Reshape] output [op_type=Identity] input1:0 -> Reshape '
300
+ 'digraph { "Reshape/shape" [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
301
+ 'Reshape [op_type=Reshape] output [op_type=Identity] input1:0 -> Reshape '
289
302
'"Reshape/shape":0 -> Reshape Reshape:0 -> output }' ,
290
303
onnx_to_graphviz (g ))
291
304
@@ -308,8 +321,8 @@ def rewrite_test(g, ops):
308
321
_ = tf .identity (x_ , name = "output" )
309
322
g = process_tf_graph (sess .graph , custom_rewriter = [rewrite_test ])
310
323
self .assertEqual (
311
- 'digraph { Add [op_type=Mul] output [op_type=Identity] input1:0 -> '
312
- 'Add input1:0 -> Add Add:0 -> output }' ,
324
+ 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Mul] '
325
+ 'output [op_type=Identity] input1:0 -> Add input1:0 -> Add Add:0 -> output }' ,
313
326
onnx_to_graphviz (g ))
314
327
315
328
def test_custom_op (self ):
@@ -333,7 +346,8 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
333
346
custom_op_handlers = {"Print" : print_handler },
334
347
extra_opset = helper .make_opsetid (_TENSORFLOW_DOMAIN , 1 ))
335
348
self .assertEqual (
336
- 'digraph { Print [op_type=Identity] output [op_type=Identity] input1:0 -> Print Print:0 -> output }' ,
349
+ 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Print [op_type=Identity] '
350
+ 'output [op_type=Identity] input1:0 -> Print Print:0 -> output }' ,
337
351
onnx_to_graphviz (g ))
338
352
339
353
0 commit comments