Skip to content

Commit 11c7e39

Browse files
committed
fix regression for bert, we broke if multiple model outputs came from the same node
1 parent cec8368 commit 11c7e39

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

tf2onnx/graph.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,8 +338,20 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
338338
self.reset_nodes(ops)
339339

340340
# add identity node after each output, in case it is renamed during conversion.
341+
nodes_seen = set()
342+
multi_output_nodes = set()
341343
for o in self.outputs:
342344
n = self.get_node_by_output_in_current_graph(o)
345+
if n in nodes_seen:
346+
multi_output_nodes.add(n)
347+
else:
348+
nodes_seen.add(n)
349+
350+
for o in self.outputs:
351+
n = self.get_node_by_output_in_current_graph(o)
352+
# TODO: below doesn't work for nodes with multiple outputs. A work around, keep those intact.
353+
if n in multi_output_nodes:
354+
continue
343355
new_output_name = port_name(n.name + "_" + utils.make_name("raw_output_"))
344356
n_shapes = n.output_shapes
345357
n_dtypes = n.output_dtypes

tf2onnx/tfonnx.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,11 +250,18 @@ def arg_minmax_op(ctx, node, name, args):
250250

251251
def reduce_op(ctx, node, name, args):
252252
axes_node = node.inputs[1]
253-
input_rank = len(ctx.get_shape(node.input[0]))
254253
axes = axes_node.get_tensor_value()
255254
if np.isscalar(axes):
256255
axes = [axes]
257-
axes = [val + input_rank if val < 0 else val for val in axes]
256+
input_shape = ctx.get_shape(node.input[0])
257+
if input_shape is None:
258+
for val in axes:
259+
if val < 0:
260+
raise ValueError("reduce_op: can have negative axis because we don't know input rank")
261+
else:
262+
input_rank = len(ctx.get_shape(node.input[0]))
263+
axes = [val + input_rank if val < 0 else val for val in axes]
264+
258265
node.set_attr("axes", axes)
259266
ctx.remove_input(node, node.input[1])
260267
keep_dims = node.get_attr("keep_dims")

0 commit comments

Comments
 (0)