
Commit e896723

Fix transpose split optimize attr when opset >=13 (#1996)
* add transposeOptimizer split for opset >= 13 (Signed-off-by: Deyu Huang <[email protected]>)
* add test (Signed-off-by: Deyu Huang <[email protected]>)
* fix comments (Signed-off-by: Deyu Huang <[email protected]>)
1 parent 9ce72be commit e896723
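
For context, the decorator messages in the test diff and the new branch in _split_handler both refer to the same ONNX change: through opset 12 the Split operator carries its split sizes in a "split" attribute, while from opset 13 onward they are passed as an optional second input. A minimal sketch of the two node forms with onnx.helper (tensor and node names here are illustrative, not taken from this commit):

import numpy as np
from onnx import helper, numpy_helper

# Opset <= 12: split sizes live in the "split" attribute of the node.
split_v12 = helper.make_node("Split", ["Z"], ["A"], axis=0, split=[1], name="split")

# Opset >= 13: split sizes arrive as a second input, usually a constant tensor.
split_sizes = numpy_helper.from_array(np.array([1], dtype=np.int64), name="split_sizes")
split_v13 = helper.make_node("Split", ["Z", "split_sizes"], ["A"], axis=0, name="split")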


2 files changed: +48 -15 lines changed


tests/test_optimizers.py

Lines changed: 38 additions & 13 deletions
@@ -145,7 +145,7 @@ def test_transpose_with_split(self, input_shape, perm, inner_perm):
         ((1, -1), (1, 1710), (1710,), [1, 0]),
         ((3, 1, 1, 5, -1), (3, 1, 1, 5, 6), (3, 5, 6), [0, 2, 3, 4, 1]),
     ])
-    @check_opset_max_version(12, "split attribute changed to input in opset 13")
+    @check_opset_max_version(12, "split attribute changed to input since opset 13")
     def test_transpose_with_split_dynamic_shape(self, input_shape, specific_input, output_shape, perm):
         node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
         node2 = helper.make_node("Split", ["Y"], ["Z"], axis=1, split=[1], name="split")
@@ -162,6 +162,31 @@ def test_transpose_with_split_dynamic_shape(self, input_shape, specific_input, output_shape, perm):
         self.run_transpose_compare(["B"], {"X": np.random.randn(*specific_input).astype(np.float32)},
                                    model_proto, remaining_transpose_num=0)
 
+    @parameterized.expand([
+        ((3, 1, 1), (1, 1, 3), (1), [0, 2, 3, 1]),
+        ((256, 1, 1), (1, 1, 256), (1), [0, 2, 3, 1])
+    ])
+    @check_opset_min_version(13, "split attribute changed to input since opset 13")
+    def test_transpose_with_split_opset13(self, input_shape, output_shape, split_val, perm):
+        unsqueeze_axes = self._make_onnx_const(np.array([0], dtype=np.int64), "axes1")
+        unsqueeze = helper.make_node("Unsqueeze", ["X", "axes1"], ["Y"], name="unsqueeze")
+        trans = helper.make_node("Transpose", ["Y"], ["Z"], perm=perm, name="trans")
+        split_attr = self._make_onnx_const(np.array([split_val], dtype=np.int64), "split_attr")
+        split = helper.make_node("Split", ["Z", "split_attr"], ["A"], axis=0, name="split")
+        squeeze_axes = self._make_onnx_const(np.array([1], dtype=np.int64), "axes2")
+        squeeze = helper.make_node("Squeeze", ["A", "axes2"], ["B"], name="squeeze")
+
+        graph = helper.make_graph(
+            [unsqueeze_axes, unsqueeze, trans, split_attr, split, squeeze_axes, squeeze],
+            "test_transpose_with_split_opset13",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
+            [helper.make_tensor_value_info("B", TensorProto.FLOAT, output_shape)],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_transpose_compare(["B"], {"X": np.random.randn(*input_shape).astype(np.float32)},
+                                   model_proto, remaining_transpose_num=0)
+
     @parameterized.expand([
         ((2, 3, 4), [2, 0, 1], [1, 2, 0]),
         ((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
@@ -717,7 +742,7 @@ def test_transpose_sqrt(self, shape, perm_input, perm_output):
         ((1, 3, 4, 5), (4, 5, 3), [0, 2, 3, 1], [1, 2, 0]),
         ((1, 3, 4, 5, 6), (4, 5, 6, 3), [0, 2, 3, 4, 1], [1, 2, 3, 0]),
     ])
-    @check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
+    @check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13")
     def test_transpose_with_squeeze1(self, input_shape, output_shape, perm, expected_perm):
         # squeeze the first dim
         node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
@@ -768,7 +793,7 @@ def test_transpose_with_unsqueeze(self, input_shape, output_shape, perm, axes_va
         ((1, 3, 4, 5), (4, 5, 3), [0, 2, 3, 1], [1, 2, 0]),
         ((1, 3, 4, 5, 6), (4, 5, 6, 3), [0, 2, 3, 4, 1], [1, 2, 3, 0]),
     ])
-    @check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
+    @check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13")
     def test_transpose_with_squeeze1_13(self, input_shape, output_shape, perm, expected_perm):
         # squeeze the first dim
         node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
@@ -791,7 +816,7 @@ def test_transpose_with_squeeze1_13(self, input_shape, output_shape, perm, expected_perm):
         ((3, 4, 1, 5), (3, 5, 4), [0, 2, 3, 1], [0, 2, 1]),
         ((3, 4, 1, 5, 6), (3, 5, 6, 4), [0, 2, 3, 4, 1], [0, 2, 3, 1]),
     ])
-    @check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
+    @check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13")
     def test_transpose_with_squeeze2(self, input_shape, output_shape, perm, expected_perm):
         # squeeze the second dim
         node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
@@ -813,7 +838,7 @@ def test_transpose_with_squeeze2(self, input_shape, output_shape, perm, expected_perm):
         ((3, 4, 1, 5), (3, 5, 4), [0, 2, 3, 1], [0, 2, 1]),
         ((3, 4, 1, 5, 6), (3, 5, 6, 4), [0, 2, 3, 4, 1], [0, 2, 3, 1]),
     ])
-    @check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
+    @check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13")
     def test_transpose_with_squeeze2_13(self, input_shape, output_shape, perm, expected_perm):
         # squeeze the second dim
         node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
@@ -836,7 +861,7 @@ def test_transpose_with_squeeze2_13(self, input_shape, output_shape, perm, expected_perm):
         ((3, 1, 4, 5), (3, 4, 5), [0, 2, 3, 1]),
         ((3, 1, 4, 5, 6), (3, 4, 5, 6), [0, 2, 3, 4, 1]),
     ])
-    @check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
+    @check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13")
     def test_transpose_with_squeeze3(self, input_shape, output_shape, perm):
         # squeeze the last dim
         node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
@@ -857,7 +882,7 @@ def test_transpose_with_squeeze3(self, input_shape, output_shape, perm):
         ((3, 1, 4, 5), (3, 4, 5), [0, 2, 3, 1]),
         ((3, 1, 4, 5, 6), (3, 4, 5, 6), [0, 2, 3, 4, 1]),
     ])
-    @check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
+    @check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13")
     def test_transpose_with_squeeze3_13(self, input_shape, output_shape, perm):
         # squeeze the last dim
         node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
@@ -879,7 +904,7 @@ def test_transpose_with_squeeze3_13(self, input_shape, output_shape, perm):
         ((3, 1, 1, 5), (3, 5), [0, 2, 3, 1]),
         ((3, 1, 1, 5, 4), (3, 5, 4), [0, 2, 3, 4, 1]),
     ])
-    @check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
+    @check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13")
     def test_transpose_with_squeeze4(self, input_shape, output_shape, perm):
         # squeeze the two dims
         node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
@@ -900,7 +925,7 @@ def test_transpose_with_squeeze4(self, input_shape, output_shape, perm):
         ((3, 1, 1, 5), (3, 5), [0, 2, 3, 1]),
         ((3, 1, 1, 5, 4), (3, 5, 4), [0, 2, 3, 4, 1]),
     ])
-    @check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
+    @check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13")
     def test_transpose_with_squeeze4_13(self, input_shape, output_shape, perm):
         # squeeze the two dims
         node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
@@ -2156,7 +2181,7 @@ def test_const_fold_concat(self):
         self.run_and_compare(["res"], {"inp": np.random.randn(6, 12).astype(np.float32)}, model_proto,
                              "Concat", 0)
 
-    @check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
+    @check_opset_max_version(12, "Squeeze/Unsqueeze changed since opset 13")
     def test_const_fold_unsqueeze_with_const(self):
         shape = (6, 6)
         const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
@@ -2176,7 +2201,7 @@ def test_const_fold_unsqueeze_with_const(self):
         self.run_and_compare(["res"], {"X": np.random.randn(1).astype(np.float32)}, model_proto,
                              "Unsqueeze", 0)
 
-    @check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
+    @check_opset_min_version(13, "Squeeze/Unsqueeze changed since opset 13")
     def test_const_fold_unsqueeze_with_const_13(self):
         shape = (6, 6)
         const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
@@ -2254,7 +2279,7 @@ def test_const_fold_split_one(self):
         self.run_and_compare(["out4"], {"inp": np.random.randn(2, 6, 1).astype(np.float32)}, model_proto,
                              "Split", 0)
 
-    @check_opset_min_version(13, "Split changed in opset 13")
+    @check_opset_min_version(13, "Split changed since opset 13")
     def test_const_fold_split_const_splits_13(self):
         shape = (2, 6, 1)
         const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
@@ -2277,7 +2302,7 @@ def test_const_fold_split_const_splits_13(self):
         self.run_and_compare(["out4"], {"inp": np.random.randn(2, 3, 1).astype(np.float32)}, model_proto,
                              "Split", 0)
 
-    @check_opset_max_version(12, "Split changed in opset 13")
+    @check_opset_max_version(12, "Split changed since opset 13")
     def test_const_fold_split_const_splits(self):
         shape = (2, 6, 1)
         const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 10 additions & 2 deletions
@@ -671,11 +671,19 @@ def _concat_handler(self, trans, node):
 
     def _split_handler(self, trans, node):
         # Todo: need handle cases where Split node has more than 1 outputs.
+        split = None
+        if self._g.opset >= 13 and len(node.input) > 1 and node.inputs[1].is_const():
+            # split is an input not attr since opset 13
+            split = node.inputs[1].get_tensor_value(as_list=True)
         if self._handle_node_having_branches(trans, node):
             perm = trans.get_attr_value("perm")
             axis = node.get_attr_value("axis", 0)
             new_axis = perm[axis]
             node.set_attr("axis", new_axis)
+            if split:
+                new_axes_np = np.array(split, dtype=np.int64)
+                new_axes_const = self._g.make_const(utils.make_name(node.inputs[1].name), new_axes_np)
+                self._g.replace_inputs(node, [node.input[0], new_axes_const.output[0]])
             return True
         return False

@@ -747,7 +755,7 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes):
             shape_after_trans = [input_shape[i] for i in ori_perm]
             output_shape = [shape_after_trans[i] for i in range(n) if i not in ori_squeeze_axes]
             # calculate new_perm
-            # after switch, the output shape should be same, using this condtion we can figure the new perm
+            # after switch, the output shape should be same, using this condition we can figure the new perm
             shape_after_squeeze = [input_shape[i] for i in range(n) if i not in new_squeeze_axes]
             new_perm = [shape_after_squeeze.index(i) for i in output_shape]

@@ -757,7 +765,7 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes):
            return False
 
         axes = None
-        # in opset 13, axes is an input not attr
+        # axes is an input not attr since opset 13
         if node.get_attr("axes"):
             axes = node.get_attr("axes").ints
         if len(node.input) > 1 and node.inputs[1].is_const():
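
The _split_handler change above covers the case where the optimizer pushes a Transpose below a Split: the split axis is remapped through the Transpose perm (new_axis = perm[axis]), and on opset 13 the constant split-sizes input is re-created for the rewritten node. A small numpy sanity check of the underlying identity, using illustrative shapes and perm that are not taken from this commit:

import numpy as np

# Swapping Transpose(perm) and Split(axis): splitting the transposed tensor along
# `axis` gives the same pieces as splitting the original along perm[axis] and
# transposing each resulting piece.
x = np.random.randn(2, 6, 4).astype(np.float32)
perm, axis = [0, 2, 1], 2   # illustrative values only

lhs = np.split(np.transpose(x, perm), 2, axis=axis)
rhs = [np.transpose(p, perm) for p in np.split(x, 2, axis=perm[axis])]
assert all(np.array_equal(a, b) for a, b in zip(lhs, rhs))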
