Skip to content

Commit 7a0903e

Browse files
committed
refactor
use NCHW_TO_NHWC instead of [0, 2, 3, 1]
1 parent a415293 commit 7a0903e

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ def _convert_since_9(cls, ctx, node, op_type):
544544
shapes = node.output_shapes
545545
dtypes = node.output_dtypes
546546
ctx.remove_node(node.name)
547-
ctx.make_node("Transpose", upsample.output, {"perm": [0, 2, 3, 1]},
547+
ctx.make_node("Transpose", upsample.output, {"perm": constants.NCHW_TO_NHWC},
548548
name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
549549

550550

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import numpy as np
1010

11+
from tf2onnx.constants import NCHW_TO_NHWC, NHWC_TO_NCHW
1112
from .. import utils
1213
from .optimizer_base import GraphOptimizerBase
1314

@@ -18,7 +19,7 @@
1819

1920
def is_nhwc_transpose(transpose_node):
2021
perm_attr = transpose_node.get_attr('perm')
21-
return transpose_node.type == "Transpose" and perm_attr and perm_attr.ints == [0, 2, 3, 1]
22+
return transpose_node.type == "Transpose" and perm_attr and perm_attr.ints == NCHW_TO_NHWC
2223

2324

2425
def is_nchw_transpose(transpose_node):
@@ -88,7 +89,7 @@ def _calculate_new_shape(graph, op):
8889
if is_nchw_transpose(op):
8990
indice = graph.make_const(utils.make_name("indice"), np.array([0, 3, 1, 2])).output[0]
9091
else:
91-
indice = graph.make_const(utils.make_name("indice"), np.array([0, 2, 3, 1])).output[0]
92+
indice = graph.make_const(utils.make_name("indice"), np.array(NCHW_TO_NHWC)).output[0]
9293

9394
return graph.make_node("Gather", [input_shape, indice]).output[0]
9495

@@ -302,7 +303,7 @@ def _create_transpose_pairs_after_node(self, node):
302303
# add Transpose(0, 3, 1, 2) and Transpose(0, 2, 3, 1) before each non_nchw_trans_consumers
303304
for consumer in non_nchw_trans_consumers:
304305
nchw_node = self._g.make_node("Transpose", [node.output[0]], attr={"perm": [0, 3, 1, 2]})
305-
nhwc_node = self._g.make_node("Transpose", [nchw_node.output[0]], attr={"perm": [0, 2, 3, 1]})
306+
nhwc_node = self._g.make_node("Transpose", [nchw_node.output[0]], attr={"perm": NCHW_TO_NHWC})
306307
self._g.replace_input(consumer, node.output[0], nhwc_node.output[0])
307308

308309
def _create_transpose_pairs_before_node(self, node):
@@ -505,7 +506,7 @@ def _slice_handler(self, trans, node):
505506
axes = axes_node.get_tensor_value(as_list=True)
506507

507508
if axes == [0, 1, 2, 3]:
508-
node.set_attr("axes", [0, 2, 3, 1])
509+
node.set_attr("axes", NCHW_TO_NHWC)
509510
return self._switch_transpose_and_node(node, trans)
510511
return False
511512

0 commit comments

Comments
 (0)