Skip to content

Commit d657622

Browse files
committed
refactor
use NHWC_TO_NCHW instead of [0, 3, 1, 2]
1 parent 7a0903e commit d657622

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ def _convert_since_9(cls, ctx, node, op_type):
538538
# scales is nchw
539539
scales = ctx.make_node("Concat", [const_one_array.output[0], scales_hw.output[0]], {"axis": 0})
540540
# because onnxruntime only supports to scale the last two dims so transpose is inserted
541-
input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": [0, 3, 1, 2]})
541+
input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": constants.NHWC_TO_NCHW})
542542
upsample = ctx.make_node(op_type, [input_nchw.output[0], scales.output[0]], attr={"mode": mode})
543543

544544
shapes = node.output_shapes

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def is_nhwc_transpose(transpose_node):
2424

2525
def is_nchw_transpose(transpose_node):
2626
perm_attr = transpose_node.get_attr('perm')
27-
return transpose_node.type == "Transpose" and perm_attr and perm_attr.ints == [0, 3, 1, 2]
27+
return transpose_node.type == "Transpose" and perm_attr and perm_attr.ints == NHWC_TO_NCHW
2828

2929

3030
def is_useless_transpose(transpose_node):
@@ -87,7 +87,7 @@ def _calculate_new_shape(graph, op):
8787
# reshape requires tha output shape can only contain one -1, if not some extra op needed.
8888
input_shape = graph.make_node("Shape", [op.input[0]]).output[0]
8989
if is_nchw_transpose(op):
90-
indice = graph.make_const(utils.make_name("indice"), np.array([0, 3, 1, 2])).output[0]
90+
indice = graph.make_const(utils.make_name("indice"), np.array(NHWC_TO_NCHW)).output[0]
9191
else:
9292
indice = graph.make_const(utils.make_name("indice"), np.array(NCHW_TO_NHWC)).output[0]
9393

@@ -246,7 +246,7 @@ def _switch_transpose_and_node(self, node, trans):
246246
shape = self._g.get_shape(node.output[0])
247247
if shape:
248248
# only nhwc transpose can reach here
249-
new_shape = [shape[i] for i in [0, 3, 1, 2]]
249+
new_shape = [shape[i] for i in NHWC_TO_NCHW]
250250
self._g.set_shape(node.output[0], new_shape)
251251
return True
252252

@@ -302,7 +302,7 @@ def _create_transpose_pairs_after_node(self, node):
302302
non_nchw_trans_consumers = self._get_non_nchw_transpose_output_nodes(node)
303303
# add Transpose(0, 3, 1, 2) and Transpose(0, 2, 3, 1) before each non_nchw_trans_consumers
304304
for consumer in non_nchw_trans_consumers:
305-
nchw_node = self._g.make_node("Transpose", [node.output[0]], attr={"perm": [0, 3, 1, 2]})
305+
nchw_node = self._g.make_node("Transpose", [node.output[0]], attr={"perm": NHWC_TO_NCHW})
306306
nhwc_node = self._g.make_node("Transpose", [nchw_node.output[0]], attr={"perm": NCHW_TO_NHWC})
307307
self._g.replace_input(consumer, node.output[0], nhwc_node.output[0])
308308

0 commit comments

Comments
 (0)