Skip to content

Commit f333ce5

Browse files
authored
Merge pull request #929 from jignparm/jignparm/finetunedbert
Multiple fixes for Bert Model (fine-tuned)
2 parents 4b6b8d1 + 49942b1 commit f333ce5

File tree

3 files changed

+9
-16
lines changed

3 files changed

+9
-16
lines changed

tf2onnx/graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1067,7 +1067,9 @@ def make_onnx_graph_io(self, ids):
10671067
shape = self.get_shape(name)
10681068

10691069
utils.make_sure(dtype is not None, "missing output dtype for " + name)
1070-
utils.make_sure(shape is not None, "missing output shape for " + name)
1070+
# TODO: allow None output shape or not? e.g. shape=(?,)
1071+
#utils.make_sure(shape is not None, "missing output shape for " + name)
1072+
if shape is None: logger.warning("missing output shape for %s", name)
10711073

10721074
v = utils.make_onnx_inputs_outputs(name, dtype, shape)
10731075
tensor_value_infos.append(v)

tf2onnx/onnx_opset/tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ def version_1(cls, ctx, node, **kwargs):
196196
shape = ctx.get_shape(node.input[0])
197197
utils.make_sure(shape is not None, "squeeze input shape cannot be None")
198198
axis = [i for i, j in enumerate(shape) if j == 1]
199+
if not axis: axis = [0]
199200
node.set_attr("axes", axis)
200201

201202
@classmethod

tf2onnx/tfonnx.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,21 +54,15 @@ def rewrite_constant_fold(g, ops):
5454
"Sqrt": np.sqrt,
5555
"Sub": np.subtract,
5656
}
57-
ref_cnt_per_node = {}
58-
for idx, op in enumerate(ops):
59-
for op_input in op.inputs:
60-
if op_input.name not in ref_cnt_per_node:
61-
ref_cnt_per_node[op_input.name] = 0
62-
ref_cnt_per_node[op_input.name] += 1
6357

6458
# pylint: disable=too-many-nested-blocks
6559
keep_looking = True
6660
while keep_looking:
6761
keep_looking = False
6862
for idx, op in enumerate(ops):
6963
func = func_map.get(op.type)
70-
if func is None:
71-
continue
64+
if func is None: continue
65+
if set(op.output) & set(g.outputs): continue
7266
try:
7367
inputs = []
7468
for node in op.inputs:
@@ -109,18 +103,14 @@ def rewrite_constant_fold(g, ops):
109103
old_node_name = op.name
110104
logger.debug("create const node [%s] replacing [%s]", new_node_name, old_node_name)
111105
ops[idx] = g.make_const(new_node_name, val)
112-
ref_cnt_per_node[new_node_name] = ref_cnt_per_node[old_node_name]
113106

114107
logger.debug("replace old output [%s] with new output [%s]", old_output_name, new_output_name)
115108
# need to re-write the consumers input name to use the const name
116109
consumers = g.find_output_consumers(old_output_name)
117110
if consumers:
118111
for consumer in consumers:
119112
g.replace_input(consumer, old_output_name, new_output_name)
120-
for node in op.inputs:
121-
ref_cnt_per_node[node.name] -= 1
122-
if ref_cnt_per_node[node.name] == 0:
123-
g.remove_node(node.name)
113+
124114
# keep looking until there is nothing we can fold.
125115
# We keep the graph in topological order so if we folded,
126116
# the result might help a following op.
@@ -459,8 +449,8 @@ def compat_handler(ctx, node, **kwargs):
459449

460450
# pre-processing graph rewrites
461451
# bi-directional re-writer should be placed after single directional re-writer
462-
rewriters = [rewrite_quantize_and_dequantize, rewrite_transpose, rewrite_flatten, rewrite_gemm,
463-
rewrite_random_uniform, rewrite_random_uniform_fold_const,
452+
rewriters = [rewrite_constant_fold, rewrite_quantize_and_dequantize, rewrite_transpose, rewrite_flatten,
453+
rewrite_gemm, rewrite_random_uniform, rewrite_random_uniform_fold_const,
464454
rewrite_random_normal, rewrite_dropout, rewrite_eye,
465455
rewrite_leakyrelu, rewrite_thresholded_relu, rewrite_conv2d_with_pad,
466456
rewrite_single_direction_lstm, rewrite_bi_direction_lstm,

0 commit comments

Comments (0)