Skip to content

Commit f3c77e9

Browse files
committed
feat : add switch transpose and node in multiple output case
Signed-off-by: dongryeol.lee <[email protected]>
1 parent 9d65c68 commit f3c77e9

File tree

1 file changed

+36
-1
lines changed

1 file changed

+36
-1
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,27 @@ def _switch_transpose_and_node(self, node, trans, update_shape=True):
311311
self._g.set_shape(node.output[0], new_shape)
312312
self._g.set_shape(trans.output[0], shape)
313313
return True
314-
314+
# this is for the case where node has multiple outputs. e.g. split node.
315+
def _switch_transpose_and_node_with_multiple_outputs(self, node, trans, update_shape=True):
316+
input_index = self._get_input_index_for_trans(node, trans)
317+
for idx,_output in enumerate(node.output):
318+
shape = self._g.get_shape(_output)
319+
nxt_nodes = self._g.find_output_consumers(_output)
320+
if idx == 0:
321+
transpose = trans
322+
self._g.replace_input(node, node.input[input_index], transpose.input[0], input_index)
323+
self._g.replace_input(trans, trans.input[0], _output, 0)
324+
else:
325+
transpose = self._g.make_node("Transpose", [_output], attr={"perm": trans.get_attr_value("perm")})
326+
for nxt_node in nxt_nodes:
327+
self._g.replace_input(nxt_node, _output, transpose.output[0])
328+
329+
if update_shape and shape:
330+
perm_inv = invert_perm(transpose.get_attr_value("perm"))
331+
new_shape = [shape[i] for i in perm_inv]
332+
self._g.set_shape(_output, new_shape)
333+
self._g.set_shape(transpose.output[0], shape)
334+
return True
315335
# if return value is True, then it means Transpose is handled as designed
316336
# otherwise, it means that we skip handling since it is not in our support set
317337
def _handle_nhwc_tranpose(self, trans):
@@ -694,6 +714,21 @@ def _split_handler(self, trans, node):
694714
new_axes_const = self._g.make_const(utils.make_name(node.inputs[1].name), new_axes_np)
695715
self._g.replace_inputs(node, [node.input[0], new_axes_const.output[0]])
696716
return True
717+
# handling having branches
718+
if len(node.output) > 1:
719+
trans_rank = get_transpose_rank(trans)
720+
axes = node.get_attr_value("axis", 0)
721+
perm = trans.get_attr("perm").ints
722+
axes = [axes + trans_rank if axes < 0 else axes]
723+
if split:
724+
new_axes_np = np.array(split, dtype=np.int64)
725+
new_axes_const = self._g.make_const(utils.make_name(node.inputs[1].name), new_axes_np)
726+
# [Transpose -> Split -> next_nodes] -> [Split -> Transpose -> next_nodes]
727+
if not self._switch_transpose_and_node_with_multiple_outputs(node, trans, 1):
728+
return False
729+
new_axes = [perm[a] for a in axes]
730+
node.set_attr("axes", new_axes)
731+
return True
697732
return False
698733

699734
def _unsqueeze_handler(self, trans, node):

0 commit comments

Comments
 (0)