Skip to content

Commit 70d94c2

Browse files
fix dropout bug
1 parent 757ed9a commit 70d94c2

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

tf2onnx/graph.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,7 @@ def _get_unvisited_child(g, node, not_visited):
687687
all_input = list(filter(lambda a: a != '', all_input))
688688
for inp in all_input:
689689
j = self.get_node_by_output(inp)
690+
utils.make_sure(j is not None, "Cannot find node with output {}".format(inp))
690691
if self.parent_graph and j.name not in op_name_to_index:
691692
# there might be some outer-scoped inputs for an inner Graph.
692693
pass
@@ -754,6 +755,7 @@ def make_graph(self, doc, graph_name="tf2onnx"):
754755
placeholder_default_const_ops = []
755756
for op in placeholder_ops:
756757
if op.type == "PlaceholderWithDefault":
758+
utils.make_sure(op.inputs[0] is not None, "Cannot find node with output {}".format(op.input[0]))
757759
utils.make_sure(op.inputs[0].is_const(),
758760
"non-const default value for PlaceholderWithDefault is not supported.")
759761
# copy the tensor value, set its name to current node's output, add as initializer
@@ -1027,6 +1029,11 @@ def extract_sub_graph_nodes(self, outputs_name, input_checker=None, ignore_unuse
10271029
if node.is_graph_input():
10281030
if node not in res_set:
10291031
res_set.add(node)
1032+
if node.type == "PlaceholderWithDefault" and \
1033+
node.inputs[0].is_const() and \
1034+
node.inputs[0] not in res_set:
1035+
res_set.add(node.inputs[0])
1036+
10301037
return list(res_set)
10311038

10321039
def delete_unused_nodes(self, outputs_name):

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def rewrite_dropout(g, ops):
192192
OpTypePattern('Floor', inputs=[
193193
OpTypePattern('Add', inputs=[
194194
OpTypePattern(None, name="input3"),
195-
OpTypePattern('RandomUniform'),
195+
OpTypePattern('RandomUniform|RandomUniformLike'),
196196
])
197197
]),
198198
])

0 commit comments

Comments
 (0)