Skip to content

Commit e3bb930

Browse files
Optimize Transpose->Pad when 'pads' input is not a constant (#1331)
When the second input is not a constant, let's shuffle it with a Split followed by a Concat. There are examples of models where this non-constant input gets constant folded anyway by a framework. Even if that's not the case, a Split+Concat of an 8 (or 10) element tensor should be a good trade for a Transpose pair. Signed-off-by: Mateusz Tabaka <[email protected]>
1 parent 8aa1127 commit e3bb930

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

tests/test_optimizers.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,34 @@ def test_transpose_pad11(self, input_shape, output_shape, pads, perm_input, perm
922922
self.run_transpose_compare(["res"], {"X": np.random.randn(*input_shape).astype(np.float32)},
923923
model_proto, remaining_transpose_num=0)
924924

925+
@parameterized.expand([
926+
((1, 3, 4, 5), (2, 6, 4, 8), [1, 0, 1, 3, 0, 0, 2, 0], [0, 2, 3, 1], [0, 3, 1, 2]),
927+
((1, 3, 4, 5, 6), (2, 5, 6, 8, 10), [1, 0, 1, 3, 1, 0, 2, 2, 1, 1], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
928+
])
929+
@check_opset_min_version(11, "pad")
930+
def test_transpose_pad11_non_const_pads(self, input_shape, output_shape, pads, perm_input, perm_output):
931+
932+
pads_val = np.array(pads, dtype=np.int64)
933+
node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1")
934+
node1 = helper.make_node("Pad", ["Y", "Pads"], ["Z"], name="pad")
935+
node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=perm_output, name="trans_2")
936+
937+
graph = helper.make_graph(
938+
[node0, node1, node2],
939+
"transpose-pad-test",
940+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape),
941+
helper.make_tensor_value_info("Pads", TensorProto.INT64, pads_val.shape)],
942+
[helper.make_tensor_value_info("res", TensorProto.FLOAT, output_shape)],
943+
)
944+
945+
model_proto = self.make_model(graph, producer_name="onnx-tests")
946+
self.run_transpose_compare(["res"],
947+
{
948+
"X": np.random.randn(*input_shape).astype(np.float32),
949+
"Pads": pads_val
950+
},
951+
model_proto, remaining_transpose_num=0)
952+
925953
@parameterized.expand([
926954
((1, 3, 4, 5), (1, 3, 1, 1), [0, 2, 3, 1], [0, 3, 1, 2]),
927955
((1, 3, 4, 5, 6), (1, 3, 1, 1, 1), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,21 @@ def _pad_handler(self, trans, node):
676676
input1.set_tensor_value(new_pads)
677677
input1.data_format = "NCHW"
678678
return self._switch_transpose_and_node(node, trans)
679-
return False
679+
# when the second input is not a constant, let's shuffle it with Split followed by Concat
680+
# there are examples of models, where this non-constant input
681+
# gets constant folded anyway by a framework.
682+
split = self._g.make_node("Split", inputs=[node.input[1]], attr={}, output_count=trans_rank * 2)
683+
pads = split.output
684+
if trans_rank == 4:
685+
new_pads = self._g.make_node("Concat", [pads[0], pads[3], pads[1], pads[2],
686+
pads[4], pads[7], pads[5], pads[6]],
687+
{'axis': 0})
688+
else:
689+
new_pads = self._g.make_node("Concat", [pads[0], pads[4], pads[1], pads[2], pads[3],
690+
pads[5], pads[9], pads[6], pads[7], pads[8]],
691+
{'axis': 0})
692+
self._g.replace_input(node, node.input[1], new_pads.output[0], 1)
693+
return self._switch_transpose_and_node(node, trans)
680694

681695
def _reducemean_handler(self, trans, node):
682696
axes = node.get_attr("axes").ints

0 commit comments

Comments
 (0)