Skip to content

Commit d2930f8

Browse files
committed
support transpose opt with add when input shape is None
1 parent 2dcfc75 commit d2930f8

File tree

2 files changed

+47
-12
lines changed

2 files changed

+47
-12
lines changed

tests/test_optimizers.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,26 +90,47 @@ def test_transpose_with_concat(self):
9090
}
9191
self.run_transpose_compare(["res"], feed_dict, model_proto, remaining_transpose_num=1)
9292

93-
def test_transpose_with_add(self):
93+
def test_transpose_with_add1(self):
9494
# when transpose follows with a broadcasting op
9595
# reshape is needed when switching transpose with this op and op need broadcast its inputs
9696
node1 = helper.make_node("Transpose", ["input_data1"], ["Y"], perm=[0, 2, 3, 1], name="trans")
9797
node2 = helper.make_node("Add", ["Y", "input_data2"], ["Z"], name="add")
98+
node3 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans2")
9899

99100
graph = helper.make_graph(
100-
[node1, node2],
101+
[node1, node2, node3],
101102
"transpose_with_shape",
102103
[helper.make_tensor_value_info("input_data1", TensorProto.FLOAT, (2, 3, 4, 5)),
103104
helper.make_tensor_value_info("input_data2", TensorProto.FLOAT, (3,)),
104105
],
105-
[helper.make_tensor_value_info("Z", TensorProto.FLOAT, [2, 4, 5, 3])],
106+
[helper.make_tensor_value_info("res", TensorProto.FLOAT, (2, 3, 4, 5))],
106107
)
107108

108109
model_proto = helper.make_model(graph, producer_name="onnx-tests")
109110
feed_dict = {"input_data1": np.random.randn(2, 3, 4, 5).astype(np.float32),
110111
"input_data2": np.random.randn(3).astype(np.float32),
111112
}
112-
self.run_transpose_compare(["Z"], feed_dict, model_proto, remaining_transpose_num=1)
113+
self.run_transpose_compare(["res"], feed_dict, model_proto, remaining_transpose_num=0)
114+
115+
def test_transpose_with_add2(self):
116+
node1 = helper.make_node("Transpose", ["input_data1"], ["Y"], perm=[0, 2, 3, 1], name="trans")
117+
node2 = helper.make_node("Add", ["Y", "input_data2"], ["Z"], name="add")
118+
node3 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans2")
119+
120+
graph = helper.make_graph(
121+
[node1, node2, node3],
122+
"transpose_with_shape",
123+
[helper.make_tensor_value_info("input_data1", TensorProto.FLOAT, (2, 3, 4, 5)),
124+
helper.make_tensor_value_info("input_data2", TensorProto.FLOAT, (2, 4, 5, 3)),
125+
],
126+
[helper.make_tensor_value_info("res", TensorProto.FLOAT, (2, 3, 4, 5))],
127+
)
128+
129+
model_proto = helper.make_model(graph, producer_name="onnx-tests")
130+
feed_dict = {"input_data1": np.random.randn(2, 3, 4, 5).astype(np.float32),
131+
"input_data2": np.random.randn(2, 4, 5, 3).astype(np.float32),
132+
}
133+
self.run_transpose_compare(["res"], feed_dict, model_proto, remaining_transpose_num=1)
113134

114135
def test_transpose_relu(self):
115136
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import numpy as np
1010

1111
from tf2onnx.constants import NCHW_TO_NHWC, NHWC_TO_NCHW
12+
import onnx
1213
from .. import utils
1314
from .optimizer_base import GraphOptimizerBase
1415

@@ -308,7 +309,8 @@ def _create_transpose_pairs_after_node(self, node):
308309

309310
def _create_transpose_pairs_before_node(self, node):
310311
def shape_after_expand(ori_shape):
311-
# according to broadcasting rule to expand shape to 4D
312+
# according to broadcasting rule to expand shape to 4D while not tile the tensor here
313+
# still count on the broadcasting op to tile the tensor
312314
utils.make_sure(ori_shape.count(-1) <= 1, "shape can contain one -1 at most")
313315
ori_rank = len(ori_shape)
314316
utils.make_sure(ori_rank <= 4, "ONNX only supports 4D data")
@@ -324,16 +326,28 @@ def shape_after_expand(ori_shape):
324326

325327
# add Transpose(0, 3, 1, 2) and Transpose(0, 2, 3, 1) before each non_nhwc_trans_consumers
326328
for input_id, n in non_nhwc_trans_inputs:
327-
shape = self._g.get_shape(n.output[0])
328-
if len(shape) == 4:
329-
nchw_node = self._g.make_node("Transpose", [input_id], attr={"perm": [0, 3, 1, 2]})
329+
shape = self._g.get_shape(input_id)
330+
# if rank of n is not 4, then we need to insert a reshape op before inserting a transpose
331+
# for example shape of n is [x, y], then output shape of reshape will be [1, 1, x, y]
332+
if shape is None:
333+
const_4 = self._g.make_const(utils.make_name("const_4"), np.array([4], np.int64)).output[0]
334+
tensor_1 = onnx.helper.make_tensor("value", onnx.TensorProto.INT64, [1], [1])
335+
shape_node = self._g.make_node("Shape", [input_id]).output[0]
336+
rank_node = self._g.make_node("Shape", [shape_node]).output[0]
337+
expand_rank = self._g.make_node("Sub", [const_4, rank_node]).output[0]
338+
array_fill_1 = self._g.make_node("ConstantOfShape", [expand_rank], attr={"value": tensor_1}).output[0]
339+
new_shape = self._g.make_node("Concat", [array_fill_1, shape_node], attr={"axis": 0}).output[0]
340+
reshape = self._g.make_node("Reshape", [input_id, new_shape]).output[0]
341+
input_of_new_trans = reshape
342+
elif len(shape) == 4:
343+
input_of_new_trans = input_id
330344
else:
331345
shape_4d = shape_after_expand(shape)
332-
shape_const = self._g.make_const(utils.make_name("reshape_shape"),
333-
np_val=np.array(shape_4d, np.int64)).output[0]
334-
reshape = self._g.make_node("Reshape", [input_id, shape_const]).output[0]
335-
nchw_node = self._g.make_node("Transpose", [reshape], attr={"perm": [0, 3, 1, 2]})
346+
const = self._g.make_const(utils.make_name("reshape_shape"), np.array(shape_4d, np.int64)).output[0]
347+
reshape = self._g.make_node("Reshape", [input_id, const]).output[0]
348+
input_of_new_trans = reshape
336349

350+
nchw_node = self._g.make_node("Transpose", [input_of_new_trans], attr={"perm": [0, 3, 1, 2]})
337351
nhwc_node = self._g.make_node("Transpose", [nchw_node.output[0]], attr={"perm": [0, 2, 3, 1]})
338352
self._g.replace_input(node, input_id, nhwc_node.output[0])
339353

0 commit comments

Comments
 (0)