Skip to content

Commit 8c47f22

Browse files
authored
Fit split axes same with transpose after optimizer (#1918)
Signed-off-by: Deyu Huang <[email protected]>
1 parent 2d29fbd commit 8c47f22

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _calculate_new_shape(graph, op):
105105
new_shape = [input_shape[p] for p in perm]
106106
return graph.make_const(utils.make_name("new_shape"), np.array(new_shape, dtype=np.int64)).output[0]
107107

108-
# reshape requires tha output shape can only contain one -1, if not some extra op needed.
108+
# reshape requires the output shape can only contain one -1, if not some extra op needed.
109109
input_shape = graph.make_node("Shape", [op.input[0]]).output[0]
110110
indice = graph.make_const(utils.make_name("indice"), np.array(perm, np.int64)).output[0]
111111

@@ -668,9 +668,12 @@ def _concat_handler(self, trans, node):
668668
return False
669669

670670
def _split_handler(self, trans, node):
671-
# Todo: need handle cases where Slit node has more than 1 outputs.
671+
# Todo: need handle cases where Split node has more than 1 outputs.
672672
if self._handle_node_having_branches(trans, node):
673-
node.set_attr("axis", 1)
673+
perm = trans.get_attr_value("perm")
674+
axis = node.get_attr_value("axis", 0)
675+
new_axis = perm[axis]
676+
node.set_attr("axis", new_axis)
674677
return True
675678
return False
676679

0 commit comments

Comments
 (0)