Skip to content

Commit 4db17b8

Browse files
author
wayuanho
committed
fix transpose slice bug in opset 10
1 parent f7a95c5 commit 4db17b8

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -552,15 +552,23 @@ def _slice_handler(self, trans, node):
552552
axes = None
553553
if self._g.opset < 10:
554554
axes = node.get_attr("axes").ints
555+
if axes == [0, 1, 2, 3]:
556+
node.set_attr("axes", NCHW_TO_NHWC)
557+
return self._switch_transpose_and_node(node, trans)
555558
else: # in opset 10, axes is input instead of an attribute.
556-
if len(node.inputs) >= 4:
557-
axes_node = node.inputs[3]
558-
if axes_node.is_const():
559-
axes = axes_node.get_tensor_value(as_list=True)
560-
561-
if axes == [0, 1, 2, 3]:
562-
node.set_attr("axes", NCHW_TO_NHWC)
563-
return self._switch_transpose_and_node(node, trans)
559+
if len(node.inputs) >= 4 and node.inputs[3].is_const():
560+
axes = node.inputs[3].get_tensor_value(as_list=True)
561+
if axes == [0, 1, 2, 3]:
562+
# axes node might be shared
563+
new_axes = np.array(NCHW_TO_NHWC, dtype=np.int64)
564+
if self._nodes_has_single_consumer_node([node]):
565+
node.inputs[3].set_tensor_value(new_axes)
566+
else:
567+
new_axes_const = self._g.make_const(
568+
utils.make_name(node.inputs[3].name), new_axes
569+
)
570+
self._g.replace_input(node, node.input[3], new_axes_const.output[0])
571+
return self._switch_transpose_and_node(node, trans)
564572
return False
565573

566574
def _simple_through_handler(self, trans, node):

0 commit comments

Comments
 (0)