Skip to content

Commit 4cfe46a

Browse files
make topological sort ordered
1 parent 7dff5e7 commit 4cfe46a

File tree

3 files changed

+39
-38
lines changed

3 files changed

+39
-38
lines changed

tests/test_graph.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def test_dropout(self):
144144
with tf.Session() as sess:
145145
x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
146146
x2 = tf.placeholder(tf.float32, [1, 3], name="input2")
147-
prop = tf.placeholder(tf.float32, name="prob")
147+
prop = tf.placeholder(tf.float32, (), name="prob")
148148
x_ = tf.add(x1, x2)
149149
x_ = tf.nn.dropout(x_, prop)
150150
x_ = tf.identity(x_, name="output1")
@@ -163,21 +163,22 @@ def test_dropout_2(self):
163163
with tf.Session() as sess:
164164
x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
165165
x2 = tf.placeholder(tf.float32, [1, 3], name="input2")
166-
prop = tf.placeholder(tf.float32, name="prob")
166+
prop = tf.placeholder(tf.float32, (), name="prob")
167167
x_ = tf.add(x1, x2)
168168
x_ = tf.nn.dropout(x_, prop)
169169
x_ = tf.identity(x_, name="output1")
170170
x_ = tf.identity(x_, name="output2")
171171
_ = tf.identity(x_, name="output")
172172
g = process_tf_graph(sess.graph, opset=self.config.opset)
173173
actual = onnx_to_graphviz(g)
174-
expected = 'digraph { "dropout/sub/x" [op_type=Const] "sub/x" [op_type=Const] ' \
175-
'prob [op_type=Placeholder shape="[]"] sub [op_type=Sub] "dropout/sub" [op_type=Sub] ' \
176-
'input2 [op_type=Placeholder shape="[1, 3]"] input1 [op_type=Placeholder shape="[2, 3]"] ' \
177-
'Add [op_type=Add] output1 [op_type=Identity] output2 [op_type=Identity] ' \
178-
'output [op_type=Identity] "sub/x":0 -> sub prob:0 -> sub "dropout/sub/x":0 -> ' \
179-
'"dropout/sub" sub:0 -> "dropout/sub" input1:0 -> Add input2:0 -> Add Add:0 -> ' \
180-
'output1 output1:0 -> output2 output2:0 -> output }'
174+
expected = 'digraph { "sub/x" [op_type=Const] prob [op_type=Placeholder shape="[]"] ' \
175+
'sub [op_type=Sub] input2 [op_type=Placeholder shape="[1, 3]"] ' \
176+
'input1 [op_type=Placeholder shape="[2, 3]"] "dropout/sub/x" [op_type=Const] ' \
177+
'"dropout/sub" [op_type=Sub] Add [op_type=Add] output1 [op_type=Identity] ' \
178+
'output2 [op_type=Identity] output [op_type=Identity] "sub/x":0 -> sub ' \
179+
'prob:0 -> sub "dropout/sub/x":0 -> "dropout/sub" sub:0 -> "dropout/sub" ' \
180+
'input1:0 -> Add input2:0 -> Add Add:0 -> output1 output1:0 -> output2 ' \
181+
'output2:0 -> output }'
181182
self.assertEqual(expected, actual)
182183

183184
def test_add(self):
@@ -214,8 +215,8 @@ def test_reducesum(self):
214215
_ = tf.identity(x_, name="output")
215216
g = process_tf_graph(sess.graph, opset=self.config.opset)
216217
self.assertEqual(
217-
'digraph { Const [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
218-
'Sum [op_type=ReduceSum] output [op_type=Identity] input1:0 -> Sum Sum:0 -> output }',
218+
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Sum [op_type=ReduceSum] '
219+
'output [op_type=Identity] Const [op_type=Const] input1:0 -> Sum Sum:0 -> output }',
219220
onnx_to_graphviz(g))
220221

221222
def test_argminmax(self):
@@ -225,7 +226,7 @@ def test_argminmax(self):
225226
_ = tf.identity(x_, name="output")
226227
g = process_tf_graph(sess.graph, opset=self.config.opset)
227228
self.assertEqual(
228-
'digraph { "ArgMin/dimension" [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
229+
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] "ArgMin/dimension" [op_type=Const] '
229230
'ArgMin [op_type=ArgMin] output [op_type=Identity] input1:0 -> ArgMin ArgMin:0 -> output }',
230231
onnx_to_graphviz(g))
231232

@@ -276,12 +277,12 @@ def test_conv2d(self):
276277

277278
g = process_tf_graph(sess.graph, opset=self.config.opset)
278279
self.assertEqual(
279-
'digraph { input1 [op_type=Placeholder shape="[1, 4, 4, 1]"] Conv2D__3 [op_type=Transpose] '
280-
'"kernel/shape" [op_type=Const] kernel__2 [op_type=Cast] k [op_type=Const] '
281-
'kernel [op_type=Reshape] Conv2D__4 [op_type=Transpose] Conv2D [op_type=Conv] '
282-
'Conv2D__5 [op_type=Transpose] output [op_type=Identity] input1:0 -> Conv2D__3 '
283-
'"kernel/shape":0 -> kernel__2 k:0 -> kernel kernel__2:0 -> kernel kernel:0 -> Conv2D__4 '
284-
'Conv2D__3:0 -> Conv2D Conv2D__4:0 -> Conv2D Conv2D:0 -> Conv2D__5 Conv2D__5:0 -> output }',
280+
'digraph { "kernel/shape" [op_type=Const] kernel__2 [op_type=Cast] k [op_type=Const] '
281+
'kernel [op_type=Reshape] input1 [op_type=Placeholder shape="[1, 4, 4, 1]"] Conv2D__4 '
282+
'[op_type=Transpose] Conv2D__3 [op_type=Transpose] Conv2D [op_type=Conv] Conv2D__5 [op_type=Transpose] '
283+
'output [op_type=Identity] "kernel/shape":0 -> kernel__2 k:0 -> kernel kernel__2:0 -> kernel '
284+
'kernel:0 -> Conv2D__4 input1:0 -> Conv2D__3 Conv2D__3:0 -> Conv2D Conv2D__4:0 -> Conv2D Conv2D:0 -> '
285+
'Conv2D__5 Conv2D__5:0 -> output }',
285286
onnx_to_graphviz(g))
286287

287288
def test_squeeze(self):
@@ -313,10 +314,9 @@ def test_reshape(self):
313314
_ = tf.identity(x_, name="output")
314315
g = process_tf_graph(sess.graph, opset=self.config.opset)
315316
self.assertEqual(
316-
'digraph { "Reshape/shape" [op_type=Const] Reshape__2 [op_type=Cast] '
317-
'input1 [op_type=Placeholder shape="[2, 3]"] Reshape [op_type=Reshape] '
318-
'output [op_type=Identity] "Reshape/shape":0 -> Reshape__2 input1:0 -> Reshape '
319-
'Reshape__2:0 -> Reshape Reshape:0 -> output }',
317+
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] "Reshape/shape" [op_type=Const] '
318+
'Reshape__2 [op_type=Cast] Reshape [op_type=Reshape] output [op_type=Identity] '
319+
'"Reshape/shape":0 -> Reshape__2 input1:0 -> Reshape Reshape__2:0 -> Reshape Reshape:0 -> output }',
320320
onnx_to_graphviz(g))
321321

322322
def test_custom_rewrite(self):

tests/test_internals.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,10 @@ def test_insert_node1(self):
8787
result = onnx_to_graphviz(g)
8888
expected = 'digraph { Placeholder__4 [op_type=Placeholder] ' \
8989
'n1 [op_type=Abs] n7 [op_type=Abs] n2 [op_type=Abs] n3 [op_type=Abs] ' \
90-
'n4 [op_type=Add] n5 [op_type=Abs] n5_graph_outputs_Identity__3 [op_type=Identity] ' \
91-
'n6 [op_type=Identity] input -> n1 n1:0 -> n7 n7:0 -> n2 n1:0 -> n3 ' \
92-
'n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 n5_raw_output___2:0 -> n5_graph_outputs_Identity__3 ' \
93-
'n5_raw_output___2:0 -> n6 }'
90+
'n4 [op_type=Add] n5 [op_type=Abs] n6 [op_type=Identity] ' \
91+
'n5_graph_outputs_Identity__3 [op_type=Identity] input -> n1 n1:0 -> n7 ' \
92+
'n7:0 -> n2 n1:0 -> n3 n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 n5_raw_output___2:0 -> n6 ' \
93+
'n5_raw_output___2:0 -> n5_graph_outputs_Identity__3 }'
9494
self.assertEqual(expected, result)
9595

9696
def test_insert_node2(self):
@@ -102,9 +102,9 @@ def test_insert_node2(self):
102102
result = onnx_to_graphviz(g)
103103
expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] n7 [op_type=Abs] ' \
104104
'n3 [op_type=Abs] n2 [op_type=Abs] n4 [op_type=Add] n5 [op_type=Abs] ' \
105-
'n5_graph_outputs_Identity__3 [op_type=Identity] n6 [op_type=Identity] ' \
106-
'input -> n1 n1:0 -> n7 n7:0 -> n3 n7:0 -> n2 n2:0 -> n4 n3:0 -> n4 ' \
107-
'n4:0 -> n5 n5_raw_output___2:0 -> n5_graph_outputs_Identity__3 n5_raw_output___2:0 -> n6 }'
105+
'n6 [op_type=Identity] n5_graph_outputs_Identity__3 [op_type=Identity] ' \
106+
'input -> n1 n1:0 -> n7 n7:0 -> n3 n7:0 -> n2 n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 ' \
107+
'n5_raw_output___2:0 -> n6 n5_raw_output___2:0 -> n5_graph_outputs_Identity__3 }'
108108
self.assertEqual(expected, result)
109109

110110
def test_remove_input(self):
@@ -116,10 +116,10 @@ def test_remove_input(self):
116116
g.topological_sort(ops)
117117
result = onnx_to_graphviz(g)
118118
expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] n3 [op_type=Abs] ' \
119-
'n2 [op_type=Abs] n4 [op_type=Add] n5 [op_type=Abs] ' \
120-
'n5_graph_outputs_Identity__3 [op_type=Identity] n6 [op_type=Identity] ' \
121-
'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> n4 n4:0 -> n5 ' \
122-
'n5_raw_output___2:0 -> n5_graph_outputs_Identity__3 n5_raw_output___2:0 -> n6 }'
119+
'n2 [op_type=Abs] n4 [op_type=Add] n5 [op_type=Abs] n6 [op_type=Identity] ' \
120+
'n5_graph_outputs_Identity__3 [op_type=Identity] input -> n1 n1:0 -> n3 ' \
121+
'n1:0 -> n2 n2:0 -> n4 n4:0 -> n5 n5_raw_output___2:0 -> n6 ' \
122+
'n5_raw_output___2:0 -> n5_graph_outputs_Identity__3 }'
123123
self.assertEqual(expected, result)
124124

125125
def test_rewrite_subgraph(self):
@@ -145,10 +145,9 @@ def test_rewrite_subgraph(self):
145145
result = onnx_to_graphviz(g)
146146
expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] ' \
147147
'n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__5 [op_type=Sub] ' \
148-
'n5_graph_outputs_Identity__3 [op_type=Identity] n6 [op_type=Identity] ' \
149-
'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__5 ' \
150-
'n3:0 -> ReplacedOp__5 ReplacedOp__5:0 -> n5_graph_outputs_Identity__3 ' \
151-
'ReplacedOp__5:0 -> n6 }'
148+
'n6 [op_type=Identity] n5_graph_outputs_Identity__3 [op_type=Identity] ' \
149+
'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__5 n3:0 -> ReplacedOp__5 ' \
150+
'ReplacedOp__5:0 -> n6 ReplacedOp__5:0 -> n5_graph_outputs_Identity__3 }'
152151
self.assertEqual(expected, result)
153152

154153
def test_match_flipped(self):

tf2onnx/graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,8 @@ def copy_shape(self, input_name, output_name):
723723

724724
def topological_sort(self, ops):
725725
"""Topological sort of graph."""
726+
# sort by name, the result will be reversed alphabeta
727+
ops.sort(key=lambda op: op.name)
726728

727729
def _push_stack(stack, node, in_stack):
728730
stack.append(node)
@@ -748,7 +750,7 @@ def _get_unvisited_child(g, node, not_visited):
748750
all_input |= set(implicit_inputs)
749751
# remove those empty inputs
750752
all_input = list(filter(lambda a: a != '', all_input))
751-
for inp in all_input:
753+
for inp in sorted(all_input):
752754
j = self.get_node_by_output(inp)
753755
utils.make_sure(j is not None, "Cannot find node with output {}".format(inp))
754756
if self.parent_graph and j.name not in op_name_to_index:

0 commit comments

Comments
 (0)