
Commit 78172d1

Handle ReduceSum in TransposeOptimizer (#1342)
* Handle ReduceSum in TransposeOptimizer
  Signed-off-by: Mateusz Tabaka <[email protected]>
* Remove unused variable
  Signed-off-by: Mateusz Tabaka <[email protected]>
1 parent 4224f40 commit 78172d1

File tree: 2 files changed (+83 −1 lines)
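The rewrite this commit enables pushes a keepdims=1 ReduceSum through a preceding Transpose by remapping the reduction axes through the transpose's perm, which lets the surrounding Transpose pair cancel. A minimal NumPy sketch of that identity (illustration only, not part of the commit):

import numpy as np

x = np.random.randn(1, 3, 4, 5)
perm = [0, 2, 3, 1]
axes = [1, 2]
new_axes = [perm[a] for a in axes]  # the remapping _reducesum_handler performs

# ReduceSum after Transpose ...
a = np.sum(np.transpose(x, perm), axis=tuple(axes), keepdims=True)
# ... equals Transpose after ReduceSum on the remapped axes.
b = np.transpose(np.sum(x, axis=tuple(new_axes), keepdims=True), perm)
assert np.allclose(a, b)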

tests/test_optimizers.py

Lines changed: 54 additions & 1 deletion
@@ -1062,7 +1062,6 @@ def test_transpose_reciprocal(self, shape, perm_input, perm_output):
         self.run_transpose_compare(["OUT"], {"X": np.random.randn(*shape).astype(np.float32)},
                                    model_proto, remaining_transpose_num=0)
 
-
     @parameterized.expand([
         ((1, 3, 4, 5), (1, 3, 1, 1), [0, 2, 3, 1], [0, 3, 1, 2]),
         ((1, 3, 4, 5, 6), (1, 3, 1, 1, 1), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
@@ -1084,6 +1083,60 @@ def test_transpose_reducemean(self, input_shape, output_shape, perm_input, perm_output):
         self.run_transpose_compare(["res"], {"X": np.random.randn(*input_shape).astype(np.float32)},
                                    model_proto, remaining_transpose_num=0)
 
+    @parameterized.expand([
+        ((1, 3, 4, 5), (1, 3, 4, 1), [2], [0, 2, 3, 1], [0, 3, 1, 2]),
+        ((1, 3, 4, 5), (1, 3, 1, 1), [1, 2], [0, 2, 3, 1], [0, 3, 1, 2]),
+        ((1, 3, 4, 5), (1, 1, 1, 1), [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2]),
+        ((1, 3, 4, 5, 6), (1, 3, 1, 5, 6), [1], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
+        ((1, 3, 4, 5, 6), (1, 3, 1, 1, 1), [1, 2, 3], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
+        ((1, 3, 4, 5, 6), (1, 1, 1, 1, 1), [0, 1, 2, 3, 4], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
+    ])
+    @check_opset_max_version(12, "ReduceSum from opset <= 12 has axes as an attribute")
+    def test_transpose_reducesum(self, input_shape, output_shape, axes, perm_input, perm_output):
+        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1")
+        node1 = helper.make_node("ReduceSum", ["Y"], ["Z"], axes=axes,
+                                 keepdims=1, name="reducesum")
+        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=perm_output, name="trans_2")
+
+        graph = helper.make_graph(
+            [node0, node1, node2],
+            "transpose-reducesum-test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, output_shape)],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_transpose_compare(["res"], {"X": np.random.randn(*input_shape).astype(np.float32)},
+                                   model_proto, remaining_transpose_num=0)
+
+    @parameterized.expand([
+        ((1, 3, 4, 5), (1, 3, 4, 1), [2], [0, 2, 3, 1], [0, 3, 1, 2]),
+        ((1, 3, 4, 5), (1, 3, 1, 1), [1, 2], [0, 2, 3, 1], [0, 3, 1, 2]),
+        ((1, 3, 4, 5), (1, 1, 1, 1), [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2]),
+        ((1, 3, 4, 5, 6), (1, 3, 1, 5, 6), [1], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
+        ((1, 3, 4, 5, 6), (1, 3, 1, 1, 1), [1, 2, 3], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
+        ((1, 3, 4, 5, 6), (1, 1, 1, 1, 1), [0, 1, 2, 3, 4], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
+    ])
+    @check_opset_min_version(13, "ReduceSum from opset >= 13 has axes as an input")
+    def test_transpose_reducesum_opset_13(self, input_shape, output_shape, axes, perm_input, perm_output):
+        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1")
+        node1 = helper.make_node("ReduceSum", ["Y", "axes"], ["Z"], keepdims=1, name="reducesum")
+        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=perm_output, name="trans_2")
+
+        axes = np.array(axes, dtype=np.int64)
+
+        graph = helper.make_graph(
+            [node0, node1, node2],
+            "transpose-reducesum-test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, output_shape)],
+            [helper.make_tensor("axes", TensorProto.INT64, axes.shape, axes)],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_transpose_compare(["res"], {"X": np.random.randn(*input_shape).astype(np.float32)},
+                                   model_proto, remaining_transpose_num=0)
+
     @parameterized.expand([
         ((2, 3, 4, 5), (2, 4, 5, 3), [0, 2, 3, 1]),
         ((2, 3, 4, 5, 6), (2, 4, 5, 6, 3), [0, 2, 3, 4, 1]),
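The two tests differ only in how the reduction axes are supplied, mirroring the ONNX opset change: through opset 12 ReduceSum carries axes as an attribute, while from opset 13 it takes them as a second input tensor (here a graph initializer). A condensed side-by-side, for illustration only:

import numpy as np
from onnx import TensorProto, helper

# opset <= 12: axes is an attribute on the node
rs_old = helper.make_node("ReduceSum", ["Y"], ["Z"], axes=[1, 2], keepdims=1)

# opset >= 13: axes is a second input, typically a constant initializer
axes = np.array([1, 2], dtype=np.int64)
axes_init = helper.make_tensor("axes", TensorProto.INT64, axes.shape, axes)
rs_new = helper.make_node("ReduceSum", ["Y", "axes"], ["Z"], keepdims=1)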

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 29 additions & 0 deletions
@@ -193,6 +193,7 @@ def _initialize_handlers(self):
             "Pad": self._pad_handler,
             "Reciprocal": self._simple_through_handler,
             "ReduceMean": self._reducemean_handler,
+            "ReduceSum": self._reducesum_handler,
             "Relu": self._simple_through_handler,
             "Shape": self._shape_handler,
             "Sigmoid": self._simple_through_handler,
@@ -712,6 +713,34 @@ def _reducemean_handler(self, trans, node):
             return self._switch_transpose_and_node(node, trans)
         return False
 
+    def _reducesum_handler(self, trans, node):
+        keepdims = node.get_attr("keepdims")
+        # Only swap when keepdims is 1 (the default when the attribute is
+        # absent): with keepdims=0 the reduced dims are dropped, so a
+        # transpose after the reduction cannot restore the original layout.
+        if keepdims and keepdims.i == 0:
+            return False
+        if self._g.opset <= 12:
+            axes = node.get_attr("axes").ints
+            perm = trans.get_attr('perm').ints
+            new_axes = [perm[axis] for axis in axes]
+            node.set_attr("axes", new_axes)
+            return self._switch_transpose_and_node(node, trans)
+        if node.inputs[1].is_const():
+            axes = node.inputs[1].get_tensor_value()
+            perm = trans.get_attr('perm').ints
+            axes = [perm[axes[i]] for i in range(len(axes))]
+            new_axes = np.array(axes, dtype=np.int64)
+            if self._nodes_has_single_consumer_node([node.inputs[1]]):
+                node.inputs[1].set_tensor_value(new_axes)
+            else:
+                new_axes_const = self._g.make_const(
+                    utils.make_name(node.inputs[1].name), new_axes
+                )
+                self._g.replace_input(node, node.input[1], new_axes_const.output[0], 1)
+            return self._switch_transpose_and_node(node, trans)
+        return False
+
     def _slice_handler(self, trans, node):
         trans_rank = get_transpose_rank(trans)
         axes = None
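The keepdims guard exists because the swap is only shape-safe when the reduced dims are kept. A small NumPy illustration of what goes wrong otherwise (assumed example, not from the commit):

import numpy as np

x = np.random.randn(1, 3, 4, 5)
perm = [0, 2, 3, 1]

y = np.sum(np.transpose(x, perm), axis=(1, 2), keepdims=True)   # shape (1, 1, 1, 3)
restored = np.transpose(y, [0, 3, 1, 2])  # fine: rank 4 preserved, downstream perm applies

z = np.sum(np.transpose(x, perm), axis=(1, 2), keepdims=False)  # shape (1, 3): rank dropped
# np.transpose(z, [0, 3, 1, 2]) would raise ValueError: axes don't match array

The single-consumer check in the opset-13 branch matters for the same reason any in-place constant rewrite does: if the axes initializer also feeds other nodes, mutating it would silently change their behavior, so the handler makes a fresh constant and repoints only this ReduceSum's second input.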
