Skip to content

Commit d7da573

Browse files
authored
Merge pull request #380 from onnx/gs/bert-regression
fix regression for bert, we broke if multiple model outputs came from the same node
2 parents 170dab0 + 4944dd2 commit d7da573

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

tf2onnx/graph.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,8 +349,20 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
349349
self.reset_nodes(ops)
350350

351351
# add identity node after each output, in case it is renamed during conversion.
352+
nodes_seen = set()
353+
multi_output_nodes = set()
352354
for o in self.outputs:
353355
n = self.get_node_by_output_in_current_graph(o)
356+
if n in nodes_seen:
357+
multi_output_nodes.add(n)
358+
else:
359+
nodes_seen.add(n)
360+
361+
for o in self.outputs:
362+
n = self.get_node_by_output_in_current_graph(o)
363+
# TODO: below doesn't work for nodes with multiple outputs. A work around, keep those intact.
364+
if n in multi_output_nodes:
365+
continue
354366
new_output_name = port_name(n.name + "_" + utils.make_name("raw_output_"))
355367
n_shapes = n.output_shapes
356368
n_dtypes = n.output_dtypes

tf2onnx/tfonnx.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,11 +250,17 @@ def arg_minmax_op(ctx, node, name, args):
250250

251251
def reduce_op(ctx, node, name, args):
252252
axes_node = node.inputs[1]
253-
input_rank = len(ctx.get_shape(node.input[0]))
254253
axes = axes_node.get_tensor_value()
255254
if np.isscalar(axes):
256255
axes = [axes]
257-
axes = [val + input_rank if val < 0 else val for val in axes]
256+
input_shape = ctx.get_shape(node.input[0])
257+
if input_shape is None:
258+
if any([val < 0 for val in axes]):
259+
raise ValueError("reduce_op: cannot have negative axis because we don't know input rank")
260+
else:
261+
input_rank = len(ctx.get_shape(node.input[0]))
262+
axes = [val + input_rank if val < 0 else val for val in axes]
263+
258264
node.set_attr("axes", axes)
259265
ctx.remove_input(node, node.input[1])
260266
keep_dims = node.get_attr("keep_dims")

0 commit comments

Comments
 (0)