|
8 | 8 |
|
9 | 9 | import numpy as np
|
10 | 10 |
|
| 11 | +from tf2onnx.constants import NCHW_TO_NHWC, NHWC_TO_NCHW |
11 | 12 | from .. import utils
|
12 | 13 | from .optimizer_base import GraphOptimizerBase
|
13 | 14 |
|
|
18 | 19 |
|
19 | 20 | def is_nhwc_transpose(transpose_node):
|
20 | 21 | perm_attr = transpose_node.get_attr('perm')
|
21 |
| - return transpose_node.type == "Transpose" and perm_attr and perm_attr.ints == [0, 2, 3, 1] |
| 22 | + return transpose_node.type == "Transpose" and perm_attr and perm_attr.ints == NCHW_TO_NHWC |
22 | 23 |
|
23 | 24 |
|
24 | 25 | def is_nchw_transpose(transpose_node):
|
25 | 26 | perm_attr = transpose_node.get_attr('perm')
|
26 |
| - return transpose_node.type == "Transpose" and perm_attr and perm_attr.ints == [0, 3, 1, 2] |
| 27 | + return transpose_node.type == "Transpose" and perm_attr and perm_attr.ints == NHWC_TO_NCHW |
27 | 28 |
|
28 | 29 |
|
29 | 30 | def is_useless_transpose(transpose_node):
|
@@ -86,9 +87,9 @@ def _calculate_new_shape(graph, op):
|
86 | 87 | # reshape requires tha output shape can only contain one -1, if not some extra op needed.
|
87 | 88 | input_shape = graph.make_node("Shape", [op.input[0]]).output[0]
|
88 | 89 | if is_nchw_transpose(op):
|
89 |
| - indice = graph.make_const(utils.make_name("indice"), np.array([0, 3, 1, 2])).output[0] |
| 90 | + indice = graph.make_const(utils.make_name("indice"), np.array(NHWC_TO_NCHW)).output[0] |
90 | 91 | else:
|
91 |
| - indice = graph.make_const(utils.make_name("indice"), np.array([0, 2, 3, 1])).output[0] |
| 92 | + indice = graph.make_const(utils.make_name("indice"), np.array(NCHW_TO_NHWC)).output[0] |
92 | 93 |
|
93 | 94 | return graph.make_node("Gather", [input_shape, indice]).output[0]
|
94 | 95 |
|
@@ -245,7 +246,7 @@ def _switch_transpose_and_node(self, node, trans):
|
245 | 246 | shape = self._g.get_shape(node.output[0])
|
246 | 247 | if shape:
|
247 | 248 | # only nhwc transpose can reach here
|
248 |
| - new_shape = [shape[i] for i in [0, 3, 1, 2]] |
| 249 | + new_shape = [shape[i] for i in NHWC_TO_NCHW] |
249 | 250 | self._g.set_shape(node.output[0], new_shape)
|
250 | 251 | return True
|
251 | 252 |
|
@@ -301,8 +302,8 @@ def _create_transpose_pairs_after_node(self, node):
|
301 | 302 | non_nchw_trans_consumers = self._get_non_nchw_transpose_output_nodes(node)
|
302 | 303 | # add Transpose(0, 3, 1, 2) and Transpose(0, 2, 3, 1) before each non_nchw_trans_consumers
|
303 | 304 | for consumer in non_nchw_trans_consumers:
|
304 |
| - nchw_node = self._g.make_node("Transpose", [node.output[0]], attr={"perm": [0, 3, 1, 2]}) |
305 |
| - nhwc_node = self._g.make_node("Transpose", [nchw_node.output[0]], attr={"perm": [0, 2, 3, 1]}) |
| 305 | + nchw_node = self._g.make_node("Transpose", [node.output[0]], attr={"perm": NHWC_TO_NCHW}) |
| 306 | + nhwc_node = self._g.make_node("Transpose", [nchw_node.output[0]], attr={"perm": NCHW_TO_NHWC}) |
306 | 307 | self._g.replace_input(consumer, node.output[0], nhwc_node.output[0])
|
307 | 308 |
|
308 | 309 | def _create_transpose_pairs_before_node(self, node):
|
@@ -425,7 +426,10 @@ def _identity_handler(self, trans, node):
|
425 | 426 |
|
426 | 427 | def _concat_handler(self, trans, node):
|
427 | 428 | if self._handle_node_having_branches(node):
|
428 |
| - node.set_attr("axis", 1) |
| 429 | + perm = trans.get_attr_value("perm") |
| 430 | + axis = node.get_attr_value("axis", 0) |
| 431 | + new_axis = perm[axis] |
| 432 | + node.set_attr("axis", new_axis) |
429 | 433 | return True
|
430 | 434 | return False
|
431 | 435 |
|
@@ -505,7 +509,7 @@ def _slice_handler(self, trans, node):
|
505 | 509 | axes = axes_node.get_tensor_value(as_list=True)
|
506 | 510 |
|
507 | 511 | if axes == [0, 1, 2, 3]:
|
508 |
| - node.set_attr("axes", [0, 2, 3, 1]) |
| 512 | + node.set_attr("axes", NCHW_TO_NHWC) |
509 | 513 | return self._switch_transpose_and_node(node, trans)
|
510 | 514 | return False
|
511 | 515 |
|
|
0 commit comments