Commit ade6a0a

Add argmax and reduction ops to transpose optimizer (#1383)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent: 257b45f

2 files changed: +94 -23 lines

tests/test_optimizers.py

Lines changed: 45 additions & 0 deletions
@@ -1130,6 +1130,51 @@ def test_transpose_reducesum(self, input_shape, output_shape, axes, perm_input,
         self.run_transpose_compare(["res"], {"X": np.random.randn(*input_shape).astype(np.float32)},
                                    model_proto, remaining_transpose_num=0)
 
+    @parameterized.expand([
+        ((1, 3, 4, 5), (1, 3, 4), [2], [0, 2, 3, 1], [0, 2, 1]),
+        ((1, 3, 4, 5), (1, 3), [1, 2], [0, 2, 3, 1], [0, 1]),
+        ((1, 3, 4, 5), (), [0, 1, 2, 3], [0, 2, 3, 1], []),
+        ((1, 3, 4, 5, 6), (1, 3, 5, 6), [1], [0, 2, 3, 4, 1], [0, 3, 1, 2]),
+        ((1, 3, 4, 5, 6), (1, 3), [1, 2, 3], [0, 2, 3, 4, 1], [0, 1]),
+        ((1, 3, 4, 5, 6), (), [0, 1, 2, 3, 4], [0, 2, 3, 4, 1], []),
+    ])
+    def test_transpose_reducemax(self, input_shape, output_shape, axes, perm_input, perm_output):
+        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1")
+        node1 = helper.make_node("ReduceMax", ["Y"], ["Z"], axes=axes,
+                                 keepdims=0, name="reducemax")
+        if perm_output:
+            node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=perm_output, name="trans_2")
+        else:
+            node2 = helper.make_node("Identity", ["Z"], ["res"], name="trans_2")
+
+        graph = helper.make_graph(
+            [node0, node1, node2],
+            "transpose-reducemax-test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, output_shape)],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_transpose_compare(["res"], {"X": np.random.randn(*input_shape).astype(np.float32)},
+                                   model_proto, remaining_transpose_num=0)
+
+    def test_transpose_argmax(self):
+        input_shape = [1, 2, 3, 4]
+        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
+        node1 = helper.make_node("ArgMax", ["Y"], ["Z"], axis=3, keepdims=0, name="argmax")
+        node2 = helper.make_node("Cast", ["Z"], ["res"], to=TensorProto.INT32, name="cast")
+
+        graph = helper.make_graph(
+            [node0, node1, node2],
+            "transpose-argmax-test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
+            [helper.make_tensor_value_info("res", TensorProto.INT32, [1, 3, 4])],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_transpose_compare(["res"], {"X": np.random.randn(*input_shape).astype(np.float32)},
+                                   model_proto, remaining_transpose_num=0)
+
     @parameterized.expand([
         ((1, 3, 4, 5), (1, 3, 4, 1), [2], [0, 2, 3, 1], [0, 3, 1, 2]),
         ((1, 3, 4, 5), (1, 3, 1, 1), [1, 2], [0, 2, 3, 1], [0, 3, 1, 2]),
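
The identity these tests exercise can be sanity-checked in plain numpy. Below is a standalone sketch (not from the commit) of the first ReduceMax case; the real tests compare full ONNX models via run_transpose_compare:

import numpy as np

x = np.random.randn(1, 3, 4, 5).astype(np.float32)

# Pattern before optimization: Transpose -> ReduceMax(axes=[2], keepdims=0) -> Transpose
before = np.transpose(np.transpose(x, [0, 2, 3, 1]).max(axis=2), [0, 2, 1])

# Pattern after optimization: the reduce runs directly on x over the remapped
# axis perm[2] == 3, and both Transpose nodes cancel out.
after = x.max(axis=3)

assert np.array_equal(before, after)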

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 49 additions & 23 deletions
@@ -55,7 +55,7 @@ def nodes(self):
 
     def pre_optimize_action(self):
         # make Reshape into a const, which then can be fused into Conv's weight for mobilenet_v1_75_192
-        self._output_names = [name.split(":")[0] for name in self._g.outputs]
+        self._output_names = [self._g.get_node_by_output(out).name for out in self._g.outputs]
         ops = self.nodes
         constable_reshape_ops = [n for n in ops
                                  if (n.type == "Reshape"
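
The old comprehension assumed every graph output tensor is named "<node_name>:<index>"; asking the graph for the producing node holds up when that convention does not. A tiny illustration (the tensor name is hypothetical):

out_name = "model_output"          # hypothetical output with no ":<index>" suffix
prefix = out_name.split(":")[0]    # -> "model_output", which may match no node name
# self._g.get_node_by_output(out_name).name resolves the actual producer instead.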
@@ -179,6 +179,8 @@ def _optimize_at_current_graph_level(self, graph):
     def _initialize_handlers(self):
         self._handler_map = {
             "Add": self._add_handler,
+            "ArgMax": self._arg_min_max_handler,
+            "ArgMin": self._arg_min_max_handler,
             "Cast": self._simple_through_handler,
             "Clip": self._simple_through_handler,
             "Concat": self._concat_handler,
@@ -192,8 +194,14 @@ def _initialize_handlers(self):
             "Mul": self._mul_handler,
             "Pad": self._pad_handler,
             "Reciprocal": self._simple_through_handler,
-            "ReduceMean": self._reducemean_handler,
+            "ReduceLogSum": self._reduce_handler,
+            "ReduceLogSumExp": self._reduce_handler,
+            "ReduceMax": self._reduce_handler,
+            "ReduceMean": self._reduce_handler,
+            "ReduceMin": self._reduce_handler,
+            "ReduceProd": self._reduce_handler,
             "ReduceSum": self._reducesum_handler,
+            "ReduceSumSquare": self._reduce_handler,
             "Relu": self._simple_through_handler,
             "Shape": self._shape_handler,
             "Sigmoid": self._simple_through_handler,
@@ -258,7 +266,7 @@ def _get_input_index_for_trans(self, node, trans):
         return input_index
 
     # the assumption is: both node and trans have only 1 output
-    def _switch_transpose_and_node(self, node, trans):
+    def _switch_transpose_and_node(self, node, trans, update_shape=True):
         if not self._nodes_has_single_consumer_node([trans]):
             return False
 
@@ -271,7 +279,7 @@ def _switch_transpose_and_node(self, node, trans):
         # need to transpose node shape in backward direction as well after switch
         # otherwise, reshape added in post_optimize_action may not work correctly
         shape = self._g.get_shape(node.output[0])
-        if shape:
+        if update_shape and shape:
             # only nhwc transpose can reach here
             new_shape = [shape[i] for i in NHWC_TO_NCHW]
             self._g.set_shape(node.output[0], new_shape)
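
The new guard matters because NHWC_TO_NCHW is a fixed rank-4 permutation: once a reduce with keepdims=0 has dropped dimensions, permuting the stored shape with it would index out of range, so _reduce_handler passes update_shape=False and repairs the shape itself. A sketch with made-up shapes:

NHWC_TO_NCHW = [0, 3, 1, 2]
shape_keepdims1 = [1, 4, 5, 3]                     # rank preserved: safe to permute
nchw = [shape_keepdims1[i] for i in NHWC_TO_NCHW]  # [1, 3, 4, 5]
shape_keepdims0 = [1, 4, 3]                        # rank 3: shape_keepdims0[3] would raise IndexError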
@@ -700,31 +708,49 @@ def _pad_handler(self, trans, node):
         self._g.replace_input(node, node.input[1], new_pads.output[0], 1)
         return self._switch_transpose_and_node(node, trans)
 
-    def _reducemean_handler(self, trans, node):
-        axes = node.get_attr("axes").ints
-        keepdims = node.get_attr("keepdims")
+    def _arg_min_max_handler(self, trans, node):
+        axis = node.get_attr_value("axis", 0)
+        node.set_attr("axes", [axis])
+        result = self._reduce_handler(trans, node)
+        new_axis = node.get_attr_value("axes")[0]
+        node.set_attr("axis", new_axis)
+        del node.attr["axes"]
+        return result
+
+    def _reduce_handler(self, trans, node):
+        keepdims = node.get_attr_value("keepdims", 1)
         trans_rank = get_transpose_rank(trans)
-        # make sure keepdims is 1, then we can do the swap, otherwise, please don't, because
-        # once keepdims is not set, original dims are lost, so transpose back won't work well.
-        # by default, if keepdims is not specified, it is 1
-        if axes == list(range(1, trans_rank - 1)) and ((keepdims and keepdims.i == 1) or (not keepdims)):
-            node.set_attr("axes", list(range(2, trans_rank)))
-            return self._switch_transpose_and_node(node, trans)
-        return False
+        axes = node.get_attr_value("axes", list(range(trans_rank)))
+        perm = trans.get_attr("perm").ints
+        axes = [a + trans_rank if a < 0 else a for a in axes]
+        new_axes = [perm[a] for a in axes]
+        update_shape = keepdims == 1
+        shape = self._g.get_shape(node.output[0])
+        if not self._switch_transpose_and_node(node, trans, update_shape):
+            return False
+        node.set_attr("axes", new_axes)
+        if keepdims == 0:
+            remaining_axes = []
+            j = 0
+            for i in range(trans_rank):
+                if i in new_axes:
+                    remaining_axes.append(None)
+                else:
+                    remaining_axes.append(j)
+                    j += 1
+            new_perm = [remaining_axes[p] for p in perm if remaining_axes[p] is not None]
+            if shape:
+                new_shape = [shape[new_perm.index(i)] for i in range(len(new_perm))]
+                self._g.set_shape(node.output[0], new_shape)
+            trans.set_attr("perm", new_perm)
+        return True
 
     def _reducesum_handler(self, trans, node):
         keepdims = node.get_attr("keepdims")
-        # make sure keepdims is 1, then we can do the swap, otherwise, please don't, because
-        # once keepdims is not set, original dims are lost, so transpose back won't work well.
-        # by default, if keepdims is not specified, it is 1
+        if self._g.opset <= 12:
+            return self._reduce_handler(trans, node)
         if keepdims and keepdims.i == 0:
             return False
-        if self._g.opset <= 12:
-            axes = node.get_attr("axes").ints
-            perm = trans.get_attr('perm').ints
-            new_axes = [perm[axis] for axis in axes]
-            node.set_attr("axes", new_axes)
-            return self._switch_transpose_and_node(node, trans)
         if node.inputs[1].is_const():
            axes = node.inputs[1].get_tensor_value()
            perm = trans.get_attr('perm').ints
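
The subtle step in _reduce_handler is rebuilding the downstream transpose's perm once keepdims=0 has removed the reduced dimensions. That remapping can be isolated as a small pure function (a hypothetical helper mirroring the loop in the diff):

def remap_perm_after_reduce(perm, new_axes):
    """Drop reduced axes from a transpose perm and renumber the survivors."""
    remaining_axes = []
    j = 0
    for i in range(len(perm)):
        if i in new_axes:
            remaining_axes.append(None)   # axis reduced away
        else:
            remaining_axes.append(j)      # dense renumbering of survivors
            j += 1
    return [remaining_axes[p] for p in perm if remaining_axes[p] is not None]

# First ReduceMax test case: perm [0, 2, 3, 1] with remapped axis 3 reduced
# leaves perm [0, 2, 1], which cancels exactly against trans_2.
assert remap_perm_after_reduce([0, 2, 3, 1], [3]) == [0, 2, 1]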
