@@ -144,7 +144,7 @@ def test_dropout(self):
144
144
with tf .Session () as sess :
145
145
x1 = tf .placeholder (tf .float32 , [2 , 3 ], name = "input1" )
146
146
x2 = tf .placeholder (tf .float32 , [1 , 3 ], name = "input2" )
147
- prop = tf .placeholder (tf .float32 , name = "prob" )
147
+ prop = tf .placeholder (tf .float32 , (), name = "prob" )
148
148
x_ = tf .add (x1 , x2 )
149
149
x_ = tf .nn .dropout (x_ , prop )
150
150
x_ = tf .identity (x_ , name = "output1" )
@@ -163,21 +163,22 @@ def test_dropout_2(self):
163
163
with tf .Session () as sess :
164
164
x1 = tf .placeholder (tf .float32 , [2 , 3 ], name = "input1" )
165
165
x2 = tf .placeholder (tf .float32 , [1 , 3 ], name = "input2" )
166
- prop = tf .placeholder (tf .float32 , name = "prob" )
166
+ prop = tf .placeholder (tf .float32 , (), name = "prob" )
167
167
x_ = tf .add (x1 , x2 )
168
168
x_ = tf .nn .dropout (x_ , prop )
169
169
x_ = tf .identity (x_ , name = "output1" )
170
170
x_ = tf .identity (x_ , name = "output2" )
171
171
_ = tf .identity (x_ , name = "output" )
172
172
g = process_tf_graph (sess .graph , opset = self .config .opset )
173
173
actual = onnx_to_graphviz (g )
174
- expected = 'digraph { "dropout/sub/x" [op_type=Const] "sub/x" [op_type=Const] ' \
175
- 'prob [op_type=Placeholder shape="[]"] sub [op_type=Sub] "dropout/sub" [op_type=Sub] ' \
176
- 'input2 [op_type=Placeholder shape="[1, 3]"] input1 [op_type=Placeholder shape="[2, 3]"] ' \
177
- 'Add [op_type=Add] output1 [op_type=Identity] output2 [op_type=Identity] ' \
178
- 'output [op_type=Identity] "sub/x":0 -> sub prob:0 -> sub "dropout/sub/x":0 -> ' \
179
- '"dropout/sub" sub:0 -> "dropout/sub" input1:0 -> Add input2:0 -> Add Add:0 -> ' \
180
- 'output1 output1:0 -> output2 output2:0 -> output }'
174
+ expected = 'digraph { "sub/x" [op_type=Const] prob [op_type=Placeholder shape="[]"] ' \
175
+ 'sub [op_type=Sub] input2 [op_type=Placeholder shape="[1, 3]"] ' \
176
+ 'input1 [op_type=Placeholder shape="[2, 3]"] "dropout/sub/x" [op_type=Const] ' \
177
+ '"dropout/sub" [op_type=Sub] Add [op_type=Add] output1 [op_type=Identity] ' \
178
+ 'output2 [op_type=Identity] output [op_type=Identity] "sub/x":0 -> sub ' \
179
+ 'prob:0 -> sub "dropout/sub/x":0 -> "dropout/sub" sub:0 -> "dropout/sub" ' \
180
+ 'input1:0 -> Add input2:0 -> Add Add:0 -> output1 output1:0 -> output2 ' \
181
+ 'output2:0 -> output }'
181
182
self .assertEqual (expected , actual )
182
183
183
184
def test_add (self ):
@@ -214,8 +215,8 @@ def test_reducesum(self):
214
215
_ = tf .identity (x_ , name = "output" )
215
216
g = process_tf_graph (sess .graph , opset = self .config .opset )
216
217
self .assertEqual (
217
- 'digraph { Const [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
218
- 'Sum [op_type=ReduceSum] output [op_type=Identity ] input1:0 -> Sum Sum:0 -> output }' ,
218
+ 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Sum [op_type=ReduceSum ] '
219
+ 'output [op_type=Identity] Const [op_type=Const ] input1:0 -> Sum Sum:0 -> output }' ,
219
220
onnx_to_graphviz (g ))
220
221
221
222
def test_argminmax (self ):
@@ -225,7 +226,7 @@ def test_argminmax(self):
225
226
_ = tf .identity (x_ , name = "output" )
226
227
g = process_tf_graph (sess .graph , opset = self .config .opset )
227
228
self .assertEqual (
228
- 'digraph { "ArgMin/dimension" [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
229
+ 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] "ArgMin/dimension" [op_type=Const ] '
229
230
'ArgMin [op_type=ArgMin] output [op_type=Identity] input1:0 -> ArgMin ArgMin:0 -> output }' ,
230
231
onnx_to_graphviz (g ))
231
232
@@ -276,12 +277,12 @@ def test_conv2d(self):
276
277
277
278
g = process_tf_graph (sess .graph , opset = self .config .opset )
278
279
self .assertEqual (
279
- 'digraph { input1 [op_type=Placeholder shape="[1, 4, 4, 1]"] Conv2D__3 [op_type=Transpose ] '
280
- '" kernel/shape" [op_type=Const] kernel__2 [op_type=Cast] k [op_type=Const] '
281
- 'kernel [op_type=Reshape] Conv2D__4 [op_type=Transpose] Conv2D [op_type=Conv] '
282
- 'Conv2D__5 [op_type=Transpose] output [op_type=Identity] input1 :0 -> Conv2D__3 '
283
- '" kernel/shape" :0 -> kernel__2 k :0 -> kernel kernel__2 :0 -> kernel kernel :0 -> Conv2D__4 '
284
- 'Conv2D__3:0 -> Conv2D Conv2D__4:0 -> Conv2D Conv2D:0 -> Conv2D__5 Conv2D__5:0 -> output }' ,
280
+ 'digraph { "kernel/shape" [op_type=Const] kernel__2 [op_type=Cast] k [op_type=Const ] '
281
+ 'kernel [op_type=Reshape] input1 [op_type=Placeholder shape="[1, 4, 4, 1]"] Conv2D__4 '
282
+ '[op_type=Transpose] Conv2D__3 [op_type=Transpose] Conv2D [op_type=Conv] Conv2D__5 [op_type=Transpose ] '
283
+ 'output [op_type=Identity] "kernel/shape" :0 -> kernel__2 k:0 -> kernel kernel__2:0 -> kernel '
284
+ 'kernel:0 -> Conv2D__4 input1 :0 -> Conv2D__3 Conv2D__3 :0 -> Conv2D Conv2D__4 :0 -> Conv2D Conv2D:0 -> '
285
+ 'Conv2D__5 Conv2D__5:0 -> output }' ,
285
286
onnx_to_graphviz (g ))
286
287
287
288
def test_squeeze (self ):
@@ -313,10 +314,9 @@ def test_reshape(self):
313
314
_ = tf .identity (x_ , name = "output" )
314
315
g = process_tf_graph (sess .graph , opset = self .config .opset )
315
316
self .assertEqual (
316
- 'digraph { "Reshape/shape" [op_type=Const] Reshape__2 [op_type=Cast] '
317
- 'input1 [op_type=Placeholder shape="[2, 3]"] Reshape [op_type=Reshape] '
318
- 'output [op_type=Identity] "Reshape/shape":0 -> Reshape__2 input1:0 -> Reshape '
319
- 'Reshape__2:0 -> Reshape Reshape:0 -> output }' ,
317
+ 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] "Reshape/shape" [op_type=Const] '
318
+ 'Reshape__2 [op_type=Cast] Reshape [op_type=Reshape] output [op_type=Identity] '
319
+ '"Reshape/shape":0 -> Reshape__2 input1:0 -> Reshape Reshape__2:0 -> Reshape Reshape:0 -> output }' ,
320
320
onnx_to_graphviz (g ))
321
321
322
322
def test_custom_rewrite (self ):
0 commit comments