diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py
index 0b46d98a4..bafa77556 100644
--- a/tf2onnx/onnx_opset/tensor.py
+++ b/tf2onnx/onnx_opset/tensor.py
@@ -716,9 +716,13 @@ def version_1(cls, ctx, node, **kwargs):
         # T output = Split(int32 split_dim, T value, @int num_split)
         # T outputs = Split(T input, @INT axis, @INTS split)
         split_dims = node.inputs[0].get_tensor_value()
+        # normalize a negative split axis against the rank (for Split, output rank == input rank)
+        new_split_dims = split_dims + len(node.output_shapes[0]) if split_dims < 0 else split_dims
+        # remap split axis 3 (NHWC channels) to 1 (its NCHW position)
+        new_split_dims = 1 if new_split_dims == 3 else new_split_dims
         ctx.remove_input(node, node.input[0], 0)
         node.set_attr("num_outputs", node.get_attr_int("num_split"))
-        node.set_attr("axis", split_dims)
+        node.set_attr("axis", new_split_dims)
 
     @classmethod
     def version_2(cls, ctx, node, **kwargs):
diff --git a/tf2onnx/optimizer/transpose_optimizer.py b/tf2onnx/optimizer/transpose_optimizer.py
index 76f897828..9d96ad617 100644
--- a/tf2onnx/optimizer/transpose_optimizer.py
+++ b/tf2onnx/optimizer/transpose_optimizer.py
@@ -311,7 +311,34 @@ def _switch_transpose_and_node(self, node, trans, update_shape=True):
             self._g.set_shape(node.output[0], new_shape)
             self._g.set_shape(trans.output[0], shape)
         return True
-
+    # same switch as above, for the case where the node has multiple outputs, e.g. a Split node
+    def _switch_transpose_and_node_with_multiple_outputs(self, node, trans, update_shape=True):
+        if not self._nodes_has_single_consumer_node([trans]):
+            return False
+        input_index = self._get_input_index_for_trans(node, trans)
+        for idx, _output in enumerate(node.output):
+            shape = self._g.get_shape(_output)
+            nxt_nodes = self._g.find_output_consumers(_output)
+            if idx == 0:
+                # reuse the existing transpose: detach it from the node's input
+                # and re-attach it behind the first output
+                transpose = trans
+                self._g.replace_input(node, node.input[input_index], transpose.input[0], input_index)
+                self._g.replace_input(trans, trans.input[0], _output, 0)
+            else:
+                # every additional output gets a fresh Transpose with the same perm
+                transpose = self._g.make_node("Transpose", [_output], attr={"perm": trans.get_attr_value("perm")})
+            # rewire the original consumers of this output to read from the transpose
+            for nxt_node in nxt_nodes:
+                self._g.replace_input(nxt_node, _output, transpose.output[0])
+
+            if update_shape and shape:
+                perm_inv = invert_perm(transpose.get_attr_value("perm"))
+                new_shape = [shape[i] for i in perm_inv]
+                self._g.set_shape(_output, new_shape)
+                self._g.set_shape(transpose.output[0], shape)
+        return True
+
     # if return value is True, then it means Transpose is handled as designed
     # otherwise, it means that we skip handling since it is not in our support set
     def _handle_nhwc_tranpose(self, trans):
@@ -694,6 +721,23 @@ def _split_handler(self, trans, node):
             new_axes_const = self._g.make_const(utils.make_name(node.inputs[1].name), new_axes_np)
             self._g.replace_inputs(node, [node.input[0], new_axes_const.output[0]])
             return True
+        # handle Split nodes that have more than one output (branches)
+        if len(node.output) > 1:
+            trans_rank = get_transpose_rank(trans)
+            axis = node.get_attr_value("axis", 0)
+            perm = trans.get_attr("perm").ints
+            axis = axis + trans_rank if axis < 0 else axis
+            if split:
+                # re-create the split-sizes input; the sizes are unchanged, only the axis moves
+                new_sizes_np = np.array(split, dtype=np.int64)
+                new_sizes_const = self._g.make_const(utils.make_name(node.inputs[1].name), new_sizes_np)
+                self._g.replace_inputs(node, [node.input[0], new_sizes_const.output[0]])
+            # [Transpose -> Split -> next_nodes] becomes [Split -> Transpose(s) -> next_nodes]
+            if not self._switch_transpose_and_node_with_multiple_outputs(node, trans, True):
+                return False
+            # the split axis in the pre-transpose layout is perm[axis]
+            node.set_attr("axis", perm[axis])
+            return True
         return False
 
     def _unsqueeze_handler(self, trans, node):
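
Note on the tensor.py change: TF's Split takes split_dim as a tensor input and it may be negative, so the converter first folds it into [0, rank) before the axis == 3 remap. A minimal sketch of that normalization, with normalize_axis as a hypothetical stand-in for the inline conditional (not part of the patch):

    def normalize_axis(axis, rank):
        # mirrors: split_dims + len(node.output_shapes[0]) if split_dims < 0 else split_dims
        return axis + rank if axis < 0 else axis

    assert normalize_axis(-1, 4) == 3  # last axis of a rank-4 tensor
    assert normalize_axis(1, 4) == 1   # non-negative axes pass through unchanged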
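
The shape bookkeeping in _switch_transpose_and_node_with_multiple_outputs follows the single-output helper above it: once the Transpose moves behind the Split, each output carries the pre-transpose layout, recovered by indexing the old shape with the inverted permutation. A self-contained sketch, where invert_perm_ref is a local stand-in for the module's invert_perm helper:

    def invert_perm_ref(perm):
        # inverse permutation: if y = transpose(x, perm), then x = transpose(y, inv)
        inv = [0] * len(perm)
        for i, p in enumerate(perm):
            inv[p] = i
        return inv

    perm = [0, 2, 3, 1]                          # e.g. NCHW -> NHWC
    shape = [2, 4, 5, 6]                         # output shape in the transposed (NHWC) layout
    pre_transpose_shape = [shape[i] for i in invert_perm_ref(perm)]
    assert pre_transpose_shape == [2, 6, 4, 5]   # the original NCHW shape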
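
The _split_handler rewrite is justified by the identity it exploits: splitting a transposed tensor at axis a yields the same pieces as splitting the original tensor at axis perm[a] and transposing each piece. A quick numpy check of that identity (assumes only numpy; not part of the patch):

    import numpy as np

    perm = [0, 2, 3, 1]
    x = np.random.rand(2, 6, 4, 5)
    axis = 3                               # split axis as seen after the transpose

    before = np.split(np.transpose(x, perm), 3, axis=axis)                    # Transpose -> Split
    after = [np.transpose(p, perm) for p in np.split(x, 3, axis=perm[axis])]  # Split -> Transpose
    assert all(np.array_equal(b, a) for b, a in zip(before, after))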