Skip to content

Commit 37d0c28

Browse files
Optimize Transpose->Slice regardless of Slice's axes (#1330)
* Optimize Transpose->Slice regardless of Slice's axes Signed-off-by: Mateusz Tabaka <[email protected]> * Rename test_transpose_slice_10 test case to test_transpose_slice_opset_10 Signed-off-by: Mateusz Tabaka <[email protected]> * Parameterize axes for test_transpose_slice Signed-off-by: Mateusz Tabaka <[email protected]> Co-authored-by: TomWildenhain-Microsoft <[email protected]>
1 parent e3bb930 commit 37d0c28

File tree

2 files changed

+80
-28
lines changed

2 files changed

+80
-28
lines changed

tests/test_optimizers.py

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import numpy as np
1212
from onnx import helper, numpy_helper, TensorProto, OperatorSetIdProto
1313
from parameterized import parameterized
14+
1415
from backend_test_base import Tf2OnnxBackendTestBase
1516
from common import unittest_main, group_nodes_by_type, check_opset_min_version, check_opset_max_version, get_test_config
1617
from tf2onnx import utils, constants
@@ -309,21 +310,31 @@ def test_transpose_dequantize_with_axis(self, shape, perm_input, perm_output):
309310
model_proto, remaining_transpose_num=0)
310311

311312
@parameterized.expand([
312-
((2, 3, 4, 5), [1, 2, 1, 2], (1, 2, 2, 1), [0, 2, 3, 1], [0, 3, 1, 2]),
313-
((2, 3, 4, 5, 6), [1, 2, 1, 2, 1], (1, 1, 2, 1, 2), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
313+
([2, 3, 4, 5], [1, 2, 1, 2], [1], [0, 2, 3, 1], [0, 3, 1, 2]),
314+
([2, 3, 4, 5], [1, 2, 1, 2], [1, 2], [0, 2, 3, 1], [0, 3, 1, 2]),
315+
([2, 3, 4, 5], [1, 2, 1, 2], [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2]),
316+
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [2], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
317+
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [2, 3], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
318+
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [0, 1, 2, 3, 4], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
314319
])
315-
@check_opset_min_version(10, "Slice in opset 10 can accept dymaic 'start' and 'ends'")
316-
def test_transpose_slice(self, input_shape, slice_size, output_shape, perm_input, perm_output):
317-
starts = np.array([0] * len(input_shape), dtype=np.int64)
318-
ends = np.array(slice_size, dtype=np.int64)
319-
axes = np.array(list(range(len(input_shape))), dtype=np.int64)
320+
@check_opset_max_version(9, "Slice in opset 9 and takes 'axes, 'start' and 'ends' as attributes")
321+
def test_transpose_slice(self, input_shape, slice_size, axes, perm_input, perm_output):
322+
axes = np.array(axes, dtype=np.int64)
323+
starts = np.array([0] * axes.size, dtype=np.int64)
324+
ends = []
325+
for i in range(axes.size):
326+
ends.append(slice_size[axes[i]])
327+
ends = np.array(ends, dtype=np.int64)
328+
output_shape = input_shape.copy()
329+
for axis in axes:
330+
output_shape[perm_input[axis]] = slice_size[axis]
320331
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1")
321-
node2 = helper.make_node("Slice", ["Y", "starts", "ends", "axes"], ["Z"], name="relu")
332+
node2 = helper.make_node("Slice", ["Y"], ["Z"], starts=starts, ends=ends, axes=axes, name="slice")
322333
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=perm_output, name="trans_2")
323334

324335
graph = helper.make_graph(
325336
[node1, node2, node3],
326-
"relu-test",
337+
"slice-test",
327338
[helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
328339
[helper.make_tensor_value_info("Z1", TensorProto.FLOAT, output_shape)],
329340
[
@@ -337,6 +348,45 @@ def test_transpose_slice(self, input_shape, slice_size, output_shape, perm_input
337348
self.run_transpose_compare(["Z1"], {"X": np.random.randn(*input_shape).astype(np.float32)},
338349
model_proto, remaining_transpose_num=0)
339350

351+
@parameterized.expand([
352+
([2, 3, 4, 5], [1, 2, 1, 2], [1], [0, 2, 3, 1], [0, 3, 1, 2]),
353+
([2, 3, 4, 5], [1, 2, 1, 2], [1, 2], [0, 2, 3, 1], [0, 3, 1, 2]),
354+
([2, 3, 4, 5], [1, 2, 1, 2], [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2]),
355+
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [2], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
356+
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [2, 3], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
357+
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [0, 1, 2, 3, 4], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
358+
])
359+
@check_opset_min_version(10, "Slice in opset 10 can accept dynamic 'start' and 'ends'")
360+
def test_transpose_slice_opset_10(self, input_shape, slice_size, axes, perm_input, perm_output):
361+
axes = np.array(axes, dtype=np.int32)
362+
starts = np.array([0] * axes.size, dtype=np.int32)
363+
ends = []
364+
for i in range(axes.size):
365+
ends.append(slice_size[axes[i]])
366+
ends = np.array(ends, dtype=np.int32)
367+
output_shape = input_shape.copy()
368+
for axis in axes:
369+
output_shape[perm_input[axis]] = slice_size[axis]
370+
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1")
371+
node2 = helper.make_node("Slice", ["Y", "starts", "ends", "axes"], ["Z"], name="slice")
372+
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=perm_output, name="trans_2")
373+
374+
graph = helper.make_graph(
375+
[node1, node2, node3],
376+
"slice-test",
377+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
378+
[helper.make_tensor_value_info("Z1", TensorProto.FLOAT, output_shape)],
379+
[
380+
helper.make_tensor("starts", TensorProto.INT32, starts.shape, starts),
381+
helper.make_tensor("ends", TensorProto.INT32, ends.shape, ends),
382+
helper.make_tensor("axes", TensorProto.INT32, axes.shape, axes)
383+
]
384+
)
385+
386+
model_proto = self.make_model(graph, producer_name="onnx-tests")
387+
self.run_transpose_compare(["Z1"], {"X": np.random.randn(*input_shape).astype(np.float32)},
388+
model_proto, remaining_transpose_num=0)
389+
340390
@parameterized.expand([
341391
((2, 3, 4, 5), (2, 4, 5, 3), [0, 2, 3, 1], [0, 3, 1, 2]),
342392
((2, 3, 4, 5, 6), (2, 4, 5, 6, 3), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -712,25 +712,27 @@ def _slice_handler(self, trans, node):
712712
if not axes_values:
713713
return False
714714
axes = axes_values.ints
715-
if axes == list(range(trans_rank)):
716-
new_axes = NCHW_TO_NHWC if trans_rank == 4 else NCDHW_TO_NDHWC
717-
node.set_attr("axes", new_axes)
718-
return self._switch_transpose_and_node(node, trans)
719-
else: # in opset 10, axes is input instead of an attribute.
720-
if len(node.inputs) >= 4 and node.inputs[3].is_const():
721-
axes = node.inputs[3].get_tensor_value(as_list=True)
722-
if axes == list(range(trans_rank)):
723-
axes = NCHW_TO_NHWC if trans_rank == 4 else NCDHW_TO_NDHWC
724-
# axes node might be shared
725-
new_axes = np.array(axes, dtype=np.int64)
726-
if self._nodes_has_single_consumer_node([node.inputs[3]]):
727-
node.inputs[3].set_tensor_value(new_axes)
728-
else:
729-
new_axes_const = self._g.make_const(
730-
utils.make_name(node.inputs[3].name), new_axes
731-
)
732-
self._g.replace_input(node, node.input[3], new_axes_const.output[0], 3)
733-
return self._switch_transpose_and_node(node, trans)
715+
perm = NCHW_TO_NHWC if trans_rank == 4 else NCDHW_TO_NDHWC
716+
new_axes = [perm[axes[i]] for i in range(len(axes))]
717+
node.set_attr("axes", new_axes)
718+
return self._switch_transpose_and_node(node, trans)
719+
# in opset 10, axes is input instead of an attribute.
720+
if len(node.inputs) >= 4 and node.inputs[3].is_const():
721+
axes = node.inputs[3].get_tensor_value(as_list=False)
722+
dtype = axes.dtype
723+
axes = axes.tolist()
724+
perm = NCHW_TO_NHWC if trans_rank == 4 else NCDHW_TO_NDHWC
725+
axes = [perm[axes[i]] for i in range(len(axes))]
726+
# axes node might be shared
727+
new_axes = np.array(axes, dtype=dtype)
728+
if self._nodes_has_single_consumer_node([node.inputs[3]]):
729+
node.inputs[3].set_tensor_value(new_axes)
730+
else:
731+
new_axes_const = self._g.make_const(
732+
utils.make_name(node.inputs[3].name), new_axes
733+
)
734+
self._g.replace_input(node, node.input[3], new_axes_const.output[0], 3)
735+
return self._switch_transpose_and_node(node, trans)
734736
return False
735737

736738
def _quantize_handler(self, trans, node):

0 commit comments

Comments
 (0)