|
8 | 8 |
|
9 | 9 | import numpy as np
|
10 | 10 |
|
| 11 | +from tf2onnx.constants import NCHW_TO_NHWC, NHWC_TO_NCHW |
11 | 12 | from .. import utils
|
12 | 13 | from .optimizer_base import GraphOptimizerBase
|
13 | 14 |
|
|
18 | 19 |
|
19 | 20 | def is_nhwc_transpose(transpose_node):
|
20 | 21 | perm_attr = transpose_node.get_attr('perm')
|
21 |
| - return transpose_node.type == "Transpose" and perm_attr and perm_attr.ints == [0, 2, 3, 1] |
| 22 | + return transpose_node.type == "Transpose" and perm_attr and perm_attr.ints == NCHW_TO_NHWC |
22 | 23 |
|
23 | 24 |
|
24 | 25 | def is_nchw_transpose(transpose_node):
|
@@ -88,7 +89,7 @@ def _calculate_new_shape(graph, op):
|
88 | 89 | if is_nchw_transpose(op):
|
89 | 90 | indice = graph.make_const(utils.make_name("indice"), np.array([0, 3, 1, 2])).output[0]
|
90 | 91 | else:
|
91 |
| - indice = graph.make_const(utils.make_name("indice"), np.array([0, 2, 3, 1])).output[0] |
| 92 | + indice = graph.make_const(utils.make_name("indice"), np.array(NCHW_TO_NHWC)).output[0] |
92 | 93 |
|
93 | 94 | return graph.make_node("Gather", [input_shape, indice]).output[0]
|
94 | 95 |
|
@@ -302,7 +303,7 @@ def _create_transpose_pairs_after_node(self, node):
|
302 | 303 | # add Transpose(0, 3, 1, 2) and Transpose(0, 2, 3, 1) before each non_nchw_trans_consumers
|
303 | 304 | for consumer in non_nchw_trans_consumers:
|
304 | 305 | nchw_node = self._g.make_node("Transpose", [node.output[0]], attr={"perm": [0, 3, 1, 2]})
|
305 |
| - nhwc_node = self._g.make_node("Transpose", [nchw_node.output[0]], attr={"perm": [0, 2, 3, 1]}) |
| 306 | + nhwc_node = self._g.make_node("Transpose", [nchw_node.output[0]], attr={"perm": NCHW_TO_NHWC}) |
306 | 307 | self._g.replace_input(consumer, node.output[0], nhwc_node.output[0])
|
307 | 308 |
|
308 | 309 | def _create_transpose_pairs_before_node(self, node):
|
@@ -505,7 +506,7 @@ def _slice_handler(self, trans, node):
|
505 | 506 | axes = axes_node.get_tensor_value(as_list=True)
|
506 | 507 |
|
507 | 508 | if axes == [0, 1, 2, 3]:
|
508 |
| - node.set_attr("axes", [0, 2, 3, 1]) |
| 509 | + node.set_attr("axes", NCHW_TO_NHWC) |
509 | 510 | return self._switch_transpose_and_node(node, trans)
|
510 | 511 | return False
|
511 | 512 |
|
|
0 commit comments