Skip to content

Commit c66b232

Browse files
committed
Fix unit tests
1 parent a138ccb commit c66b232

File tree

1 file changed

+18
-23
lines changed

1 file changed

+18
-23
lines changed

tf2onnx/tfonnx.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -42,40 +42,38 @@ def rewrite_constant_fold(g, ops):
4242
tensorflow missed something, make another pass over the graph and fix want we care about.
4343
"""
4444
func_map = {
45-
# "Add": np.add,
46-
# "GreaterEqual": np.greater_equal,
47-
# "Cast": np.cast,
45+
"Add": np.add,
46+
"GreaterEqual": np.greater_equal,
47+
"Cast": np.cast,
4848
"ConcatV2": np.concatenate,
49-
# "Less": np.less,
50-
# "ListDiff": np.setdiff1d,
51-
# "Mul": np.multiply,
52-
# "Pack": np.stack,
53-
# "Range": np.arange,
54-
# "Sqrt": np.sqrt,
55-
# "Sub": np.subtract,
49+
"Less": np.less,
50+
"ListDiff": np.setdiff1d,
51+
"Mul": np.multiply,
52+
"Pack": np.stack,
53+
"Range": np.arange,
54+
"Sqrt": np.sqrt,
55+
"Sub": np.subtract,
5656
}
57-
ref_cnt_per_node = {}
58-
for idx, op in enumerate(ops):
59-
for op_input in op.inputs:
60-
if op_input.name not in ref_cnt_per_node:
61-
ref_cnt_per_node[op_input.name] = 0
62-
ref_cnt_per_node[op_input.name] += 1
6357

6458
# pylint: disable=too-many-nested-blocks
6559
keep_looking = True
6660
while keep_looking:
6761
keep_looking = False
6862
for idx, op in enumerate(ops):
6963
func = func_map.get(op.type)
70-
if func is None:
71-
continue
64+
if func is None: continue
65+
if set(op.output) & set(g.outputs): continue
7266
try:
7367
inputs = []
68+
skip = False
7469
for node in op.inputs:
7570
if not node.is_const():
71+
skip = True
7672
break
7773
inputs.append(node.get_tensor_value(as_list=False))
7874

75+
if skip: continue
76+
7977
logger.debug("op name %s, %s, %s", op.name, len(op.input), len(inputs))
8078
if inputs and len(op.input) == len(inputs):
8179
logger.info("folding node type=%s, name=%s" % (op.type, op.name))
@@ -109,18 +107,15 @@ def rewrite_constant_fold(g, ops):
109107
old_node_name = op.name
110108
logger.debug("create const node [%s] replacing [%s]", new_node_name, old_node_name)
111109
ops[idx] = g.make_const(new_node_name, val)
112-
ref_cnt_per_node[new_node_name] = ref_cnt_per_node[old_node_name]
113110

114111
logger.debug("replace old output [%s] with new output [%s]", old_output_name, new_output_name)
115112
# need to re-write the consumers input name to use the const name
116113
consumers = g.find_output_consumers(old_output_name)
117114
if consumers:
118115
for consumer in consumers:
119116
g.replace_input(consumer, old_output_name, new_output_name)
120-
for node in op.inputs:
121-
ref_cnt_per_node[node.name] -= 1
122-
if ref_cnt_per_node[node.name] == 0:
123-
g.remove_node(node.name)
117+
g.remove_node(old_node_name)
118+
124119
# keep looking until there is nothing we can fold.
125120
# We keep the graph in topological order so if we folded,
126121
# the result might help a following op.

0 commit comments

Comments
 (0)