Commit fe0614f

Merge pull request #607 from zhijxu-MS/refactor
refactor and fix bug
2 parents a415293 + f214a07 commit fe0614f

File tree: 3 files changed (+42, -11 lines)
tests/test_optimizers.py

Lines changed: 27 additions & 0 deletions
@@ -63,6 +63,33 @@ def check_transpose_perm(self, model_proto, expected_perm):
         perm = list(node.attribute[0].ints)
         self.assertEqual(perm, expected_perm)
 
+    def test_transpose_with_concat(self):
+        input_shape = (2, 3, 4, 5)
+        perm = [0, 3, 1, 2]
+        input_shape_with_trans = [input_shape[i] for i in perm]
+        for axis in [0, 1, 2, 3]:
+            output_before_trans = list(input_shape)
+            output_before_trans[axis] *= 2
+            output_shape = [output_before_trans[i] for i in [0, 3, 1, 2]]
+            node1 = helper.make_node("Transpose", ["input_data1"], ["Y"], perm=[0, 2, 3, 1], name="trans")
+            node2 = helper.make_node("Concat", ["Y", "input_data2"], ["Z"], axis=axis, name="concat")
+            node3 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans2")
+
+            graph = helper.make_graph(
+                [node1, node2, node3],
+                "test_transpose_with_concat",
+                [helper.make_tensor_value_info("input_data1", TensorProto.FLOAT, input_shape_with_trans),
+                 helper.make_tensor_value_info("input_data2", TensorProto.FLOAT, input_shape),
+                 ],
+                [helper.make_tensor_value_info("res", TensorProto.FLOAT, output_shape)],
+            )
+
+            model_proto = helper.make_model(graph, producer_name="onnx-tests")
+            feed_dict = {"input_data1": np.random.randn(*input_shape_with_trans).astype(np.float32),
+                         "input_data2": np.random.randn(*input_shape).astype(np.float32),
+                         }
+            self.run_transpose_compare(["res"], feed_dict, model_proto, remaining_transpose_num=1)
+
     def test_transpose_relu(self):
         node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
         node2 = helper.make_node("Relu", ["Y"], ["Z"], name="relu")
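
For one pass through the loop in the new test (axis=3), the shape bookkeeping works out as below; this is a plain-Python trace of the expressions above, shown for illustration only and not part of the commit.

    # Shape arithmetic of test_transpose_with_concat for the axis=3 iteration (illustrative).
    input_shape = (2, 3, 4, 5)
    perm = [0, 3, 1, 2]
    input_shape_with_trans = [input_shape[i] for i in perm]        # [2, 5, 3, 4], fed to "trans"
    axis = 3
    output_before_trans = list(input_shape)                        # [2, 3, 4, 5]
    output_before_trans[axis] *= 2                                 # [2, 3, 4, 10] after Concat doubles axis 3
    output_shape = [output_before_trans[i] for i in [0, 3, 1, 2]]  # [2, 10, 3, 4] after the final Transpose
    print(input_shape_with_trans, output_before_trans, output_shape)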

tf2onnx/onnx_opset/nn.py

Lines changed: 2 additions & 2 deletions
@@ -538,13 +538,13 @@ def _convert_since_9(cls, ctx, node, op_type):
         # scales is nchw
         scales = ctx.make_node("Concat", [const_one_array.output[0], scales_hw.output[0]], {"axis": 0})
         # because onnxruntime only supports to scale the last two dims so transpose is inserted
-        input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": [0, 3, 1, 2]})
+        input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": constants.NHWC_TO_NCHW})
         upsample = ctx.make_node(op_type, [input_nchw.output[0], scales.output[0]], attr={"mode": mode})
 
         shapes = node.output_shapes
         dtypes = node.output_dtypes
         ctx.remove_node(node.name)
-        ctx.make_node("Transpose", upsample.output, {"perm": [0, 2, 3, 1]},
+        ctx.make_node("Transpose", upsample.output, {"perm": constants.NCHW_TO_NHWC},
                       name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
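
The substitutions above are one-for-one, so the named constants in tf2onnx.constants are expected to hold exactly the literals they replace. A sketch of what the module presumably provides, inferred from this diff rather than quoted from constants.py:

    # Inferred from the one-to-one replacements in this commit; not quoted from tf2onnx/constants.py.
    NHWC_TO_NCHW = [0, 3, 1, 2]  # permutation taking an NHWC tensor to NCHW
    NCHW_TO_NHWC = [0, 2, 3, 1]  # permutation taking an NCHW tensor back to NHWC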

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 13 additions & 9 deletions
@@ -8,6 +8,7 @@
 
 import numpy as np
 
+from tf2onnx.constants import NCHW_TO_NHWC, NHWC_TO_NCHW
 from .. import utils
 from .optimizer_base import GraphOptimizerBase
 
@@ -18,12 +19,12 @@
 
 def is_nhwc_transpose(transpose_node):
     perm_attr = transpose_node.get_attr('perm')
-    return transpose_node.type == "Transpose" and perm_attr and perm_attr.ints == [0, 2, 3, 1]
+    return transpose_node.type == "Transpose" and perm_attr and perm_attr.ints == NCHW_TO_NHWC
 
 
 def is_nchw_transpose(transpose_node):
     perm_attr = transpose_node.get_attr('perm')
-    return transpose_node.type == "Transpose" and perm_attr and perm_attr.ints == [0, 3, 1, 2]
+    return transpose_node.type == "Transpose" and perm_attr and perm_attr.ints == NHWC_TO_NCHW
 
 
 def is_useless_transpose(transpose_node):
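
The two permutations are inverses of each other, which is what lets the optimizer cancel an NCHW-to-NHWC transpose against an NHWC-to-NCHW one. A quick numpy check of that claim, shown for illustration only and not part of the commit:

    import numpy as np

    NCHW_TO_NHWC = [0, 2, 3, 1]
    NHWC_TO_NCHW = [0, 3, 1, 2]

    x = np.random.randn(2, 3, 4, 5)
    # Applying one perm and then the other restores the original tensor.
    assert np.array_equal(x.transpose(NCHW_TO_NHWC).transpose(NHWC_TO_NCHW), x)
    assert np.array_equal(x.transpose(NHWC_TO_NCHW).transpose(NCHW_TO_NHWC), x)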
@@ -86,9 +87,9 @@ def _calculate_new_shape(graph, op):
     # reshape requires tha output shape can only contain one -1, if not some extra op needed.
     input_shape = graph.make_node("Shape", [op.input[0]]).output[0]
     if is_nchw_transpose(op):
-        indice = graph.make_const(utils.make_name("indice"), np.array([0, 3, 1, 2])).output[0]
+        indice = graph.make_const(utils.make_name("indice"), np.array(NHWC_TO_NCHW)).output[0]
     else:
-        indice = graph.make_const(utils.make_name("indice"), np.array([0, 2, 3, 1])).output[0]
+        indice = graph.make_const(utils.make_name("indice"), np.array(NCHW_TO_NHWC)).output[0]
 
     return graph.make_node("Gather", [input_shape, indice]).output[0]
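
At runtime the Shape/Gather pair built here computes the same thing as the static-shape case [shape[i] for i in perm], just without needing the dimensions to be known ahead of time. Roughly, with numpy standing in for the two ONNX ops (illustration only):

    import numpy as np

    NHWC_TO_NCHW = [0, 3, 1, 2]
    data = np.zeros((2, 3, 4, 5))          # tensor whose shape may only be known at runtime

    input_shape = np.array(data.shape)     # stand-in for the ONNX Shape node
    new_shape = input_shape[NHWC_TO_NCHW]  # stand-in for Gather with the "indice" constant
    print(new_shape)                       # [2 5 3 4], same as [data.shape[i] for i in NHWC_TO_NCHW]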

@@ -245,7 +246,7 @@ def _switch_transpose_and_node(self, node, trans):
         shape = self._g.get_shape(node.output[0])
         if shape:
             # only nhwc transpose can reach here
-            new_shape = [shape[i] for i in [0, 3, 1, 2]]
+            new_shape = [shape[i] for i in NHWC_TO_NCHW]
             self._g.set_shape(node.output[0], new_shape)
         return True

@@ -301,8 +302,8 @@ def _create_transpose_pairs_after_node(self, node):
         non_nchw_trans_consumers = self._get_non_nchw_transpose_output_nodes(node)
         # add Transpose(0, 3, 1, 2) and Transpose(0, 2, 3, 1) before each non_nchw_trans_consumers
         for consumer in non_nchw_trans_consumers:
-            nchw_node = self._g.make_node("Transpose", [node.output[0]], attr={"perm": [0, 3, 1, 2]})
-            nhwc_node = self._g.make_node("Transpose", [nchw_node.output[0]], attr={"perm": [0, 2, 3, 1]})
+            nchw_node = self._g.make_node("Transpose", [node.output[0]], attr={"perm": NHWC_TO_NCHW})
+            nhwc_node = self._g.make_node("Transpose", [nchw_node.output[0]], attr={"perm": NCHW_TO_NHWC})
             self._g.replace_input(consumer, node.output[0], nhwc_node.output[0])
 
     def _create_transpose_pairs_before_node(self, node):
@@ -425,7 +426,10 @@ def _identity_handler(self, trans, node):
 
     def _concat_handler(self, trans, node):
        if self._handle_node_having_branches(node):
-            node.set_attr("axis", 1)
+            perm = trans.get_attr_value("perm")
+            axis = node.get_attr_value("axis", 0)
+            new_axis = perm[axis]
+            node.set_attr("axis", new_axis)
            return True
        return False
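
This appears to be the bug fix referenced in the PR title: the old handler hard-coded the Concat axis to 1, which is only correct when the original concatenation ran over what becomes the channel dimension. The general identity the new code relies on, checked with numpy (illustration only, not part of the commit):

    import numpy as np

    # Concatenating transposed inputs on axis a equals concatenating the originals
    # on axis perm[a] and transposing afterwards, hence new_axis = perm[axis].
    perm = [0, 2, 3, 1]                  # the NHWC transpose being pushed past the Concat
    x = np.random.randn(2, 3, 4, 5)
    y = np.random.randn(2, 3, 4, 5)
    for axis in range(4):
        before = np.concatenate([x.transpose(perm), y.transpose(perm)], axis=axis)
        after = np.concatenate([x, y], axis=perm[axis]).transpose(perm)
        assert np.array_equal(before, after)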

@@ -505,7 +509,7 @@ def _slice_handler(self, trans, node):
             axes = axes_node.get_tensor_value(as_list=True)
 
             if axes == [0, 1, 2, 3]:
-                node.set_attr("axes", [0, 2, 3, 1])
+                node.set_attr("axes", NCHW_TO_NHWC)
                 return self._switch_transpose_and_node(node, trans)
         return False
