Skip to content

Commit a6b141b

Browse files
Fix bug with BatchMatMulV2 when it has the same input twice (#1483)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent d6385a1 commit a6b141b

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

tf2onnx/graph.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,7 +1258,7 @@ def remove_input(self, node, to_be_removed, input_index=None):
12581258

12591259
# don't remove output from parent since others might depend on it
12601260

1261-
def insert_new_node_on_input(self, node, op_type, input_name, name=None, domain=None, **kwargs):
1261+
def insert_new_node_on_input(self, node, op_type, input_name, name=None, domain=None, input_index=None, **kwargs):
12621262
"""Create and insert a new node into the graph.
12631263
Args:
12641264
node: we want to replace the input for this node
@@ -1279,10 +1279,13 @@ def insert_new_node_on_input(self, node, op_type, input_name, name=None, domain=
12791279
input_name = [input_name]
12801280

12811281
new_node = self.make_node(op_type, input_name, attr=kwargs, outputs=[new_output], name=name, domain=domain)
1282-
for i, n in enumerate(node.input):
1283-
if n == input_name[0]:
1284-
self.replace_input(node, node.input[i], new_output, i)
1285-
break
1282+
if input_index is None:
1283+
for i, n in enumerate(node.input):
1284+
if n == input_name[0]:
1285+
self.replace_input(node, node.input[i], new_output, i)
1286+
break
1287+
else:
1288+
self.replace_input(node, node.input[input_index], new_output, input_index)
12861289
return new_node
12871290

12881291
def insert_node_on_output(self, node, output_name=None):

tf2onnx/onnx_opset/math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def version_1(cls, ctx, node, **kwargs):
374374
tmp = perm[-1]
375375
perm[-1] = perm[-2]
376376
perm[-2] = tmp
377-
ctx.insert_new_node_on_input(node, "Transpose", node.input[0], perm=perm)
377+
ctx.insert_new_node_on_input(node, "Transpose", node.input[0], input_index=0, perm=perm)
378378

379379
if transpose_b != 0:
380380
shape = ctx.get_shape(node.input[1])
@@ -383,7 +383,7 @@ def version_1(cls, ctx, node, **kwargs):
383383
tmp = perm[-1]
384384
perm[-1] = perm[-2]
385385
perm[-2] = tmp
386-
ctx.insert_new_node_on_input(node, "Transpose", node.input[1], perm=perm)
386+
ctx.insert_new_node_on_input(node, "Transpose", node.input[1], input_index=1, perm=perm)
387387

388388
unsupported = ["a_is_sparse", "b_is_sparse"]
389389
for i in unsupported:

0 commit comments

Comments
 (0)