
Commit 98546b1

Add target for converting to channels last format (#1581)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent d4be5af commit 98546b1

8 files changed: +183 −14 lines changed

tf2onnx/constants.py

Lines changed: 3 additions & 1 deletion
@@ -29,7 +29,9 @@
 TARGET_RS6 = "rs6"
 TARGET_CAFFE2 = "caffe2"
 TARGET_TENSORRT = "tensorrt"
-POSSIBLE_TARGETS = [TARGET_RS4, TARGET_RS5, TARGET_RS6, TARGET_CAFFE2, TARGET_TENSORRT]
+TARGET_CHANNELS_LAST = "nhwc"
+TARGET_CHANNELS_FIRST = "nchw"
+POSSIBLE_TARGETS = [TARGET_RS4, TARGET_RS5, TARGET_RS6, TARGET_CAFFE2, TARGET_TENSORRT, TARGET_CHANNELS_LAST]
 DEFAULT_TARGET = []

 NCHW_TO_NHWC = [0, 2, 3, 1]

tf2onnx/graph.py

Lines changed: 26 additions & 6 deletions
@@ -1124,13 +1124,29 @@ def make_graph(self, doc, graph_name=None, external_tensor_storage=None):
         # create output_tensor_values
         output_tensor_values = self.make_onnx_graph_io(self.outputs)

+        tensor_value_info = []
+
+        for op in ops:
+            if op.domain in [constants.ONNX_DOMAIN, constants.AI_ONNX_ML_DOMAIN]:
+                continue
+            # We still don't 100% trust the accuracy of all the shapes in graph.py, but for custom ops they are
+            # almost certainly accurate and onnx has no other way of knowing them.
+            for out in op.output:
+                if out == '' or out in self.outputs:
+                    continue
+                dtype = self.get_dtype(out)
+                shape = self.get_shape(out)
+                v = utils.make_onnx_inputs_outputs(out, dtype, shape)
+                tensor_value_info.append(v)
+
         # create graph proto
         graph = helper.make_graph([op.op for op in ops],
                                   graph_name,
                                   input_tensor_values,
                                   output_tensor_values,
                                   initializer=initializers,
-                                  doc_string=doc)
+                                  doc_string=doc,
+                                  value_info=tensor_value_info)

         return graph

@@ -1628,10 +1644,11 @@ def get_onnx_model_properties(onnx_model_proto):
         return kwargs

     @staticmethod
-    def create_graph_from_onnx_model(onnx_model_proto):
+    def create_graph_from_onnx_model(onnx_model_proto, target=None):
         """Create Graph loading onnx model proto."""
         # apply shape inference on the model
         inferred_model = shape_inference.infer_shapes(onnx_model_proto)
+        utils.initialize_name_counter(inferred_model)
         graph_proto = inferred_model.graph

         opset_version = None
@@ -1644,11 +1661,11 @@ def create_graph_from_onnx_model(onnx_model_proto):
                 extra_opset.append(opset)

         utils.make_sure(opset_version is not None, "opset version is not specified for onnx domain")
-        main_graph = GraphUtil.create_graph_from_onnx_graph(graph_proto, opset_version, extra_opset)
+        main_graph = GraphUtil.create_graph_from_onnx_graph(graph_proto, opset_version, extra_opset, target)
         return main_graph

     @staticmethod
-    def create_graph_from_onnx_graph(graph_proto, opset_version=None, extra_opset=None):
+    def create_graph_from_onnx_graph(graph_proto, opset_version=None, extra_opset=None, target=None):
         """Create Graph loading onnx graph proto."""
         output_shapes = {}
         output_dtypes = {}
@@ -1675,7 +1692,7 @@ def create_graph_from_onnx_graph(graph_proto, opset_version=None, extra_opset=None):
         for n in graph_proto.output:
             output_names.append(n.name)

-        g = Graph(nodes_to_append, output_shapes, output_dtypes, None, opset_version, extra_opset, None, output_names)
+        g = Graph(nodes_to_append, output_shapes, output_dtypes, target, opset_version, extra_opset, None, output_names)
         const_nodes = GraphUtil._parse_graph_initializer(g, graph_proto)
         GraphUtil._parse_graph_input(g, graph_proto, [n.name for n in const_nodes])

@@ -1702,6 +1719,10 @@ def _parse_shape_and_type_from_value_infos(value_infos):
         for shape_info in value_infos:
             type_proto = shape_info.type
             elem_type = type_proto.tensor_type.elem_type
+            output_dtypes[shape_info.name] = elem_type
+            if not type_proto.tensor_type.HasField("shape"):
+                output_shapes[shape_info.name] = None
+                continue
             shape = type_proto.tensor_type.shape
             tuned_shape = []
             for d in shape.dim:
@@ -1713,7 +1734,6 @@ def _parse_shape_and_type_from_value_infos(value_infos):
                     # it is found, some unknown dims is missing after inference.
                     tuned_shape.append(-1)
             output_shapes[shape_info.name] = tuned_shape
-            output_dtypes[shape_info.name] = elem_type

         return output_shapes, output_dtypes
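
The value_info change means intermediate outputs of contrib-domain ops now carry explicit shape and dtype hints in the exported GraphProto, since standard ONNX shape inference cannot derive them for custom domains. A minimal sketch of the mechanism using onnx.helper directly rather than tf2onnx's utils.make_onnx_inputs_outputs wrapper (the tensor name and shape below are invented for illustration):

    from onnx import helper, TensorProto

    # Hypothetical intermediate tensor produced by a contrib op. Shape inference
    # cannot see inside the custom domain, so the hint is recorded explicitly.
    vi = helper.make_tensor_value_info("conv_out__1", TensorProto.FLOAT, [1, 8, 28, 28])

    # Passing it via value_info attaches the hint to the graph, which is what the
    # make_graph change above does for every non-ONNX-domain output.
    graph = helper.make_graph(nodes=[], name="g", inputs=[], outputs=[], value_info=[vi])
    print(graph.value_info[0].name)  # conv_out__1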

tf2onnx/late_rewriters/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+# SPDX-License-Identifier: Apache-2.0
+
+"""tf2onnx.late_rewriters module."""
+
+from tf2onnx.late_rewriters.channel_order_rewriters import rewrite_channels_first, rewrite_channels_last
+
+
+__all__ = [
+    "rewrite_channels_first",
+    "rewrite_channels_last",
+]

tf2onnx/late_rewriters/channel_order_rewriters.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+# SPDX-License-Identifier: Apache-2.0
+
+
+"""
+tf2onnx.late_rewriters.channel_order_rewriters - contains rewriters for replacing ops with channel first/last versions
+"""
+
+from tf2onnx import utils, constants
+
+# pylint: disable=invalid-name,unused-argument,missing-docstring, unused-variable
+
+_CHANNELS_FIRST_OPS = [
+    "AveragePool",
+    "BatchNormalization",
+    "Conv",
+    "ConvInteger",
+    "ConvTranspose",
+    "GlobalAveragePool",
+    "GlobalLpPool",
+    "GlobalMaxPool",
+    "InstanceNormalization",
+    "LpPool",
+    "LRN",
+    "MaxPool",
+    "MaxRoiPool",
+    "MaxUnpool",
+    "QLinearConv",
+]
+
+
+def channel_last_to_first_perm(rank):
+    return [0, rank - 1] + list(range(1, rank - 1))
+
+
+def channel_first_to_last_perm(rank):
+    return [0] + list(range(2, rank)) + [1]
+
+
+def _to_channel_last_handler(g, op):
+    # For now, all ops can use the same handlers (input[0] and output[0] are always correct)
+    rank = g.get_rank(op.output[0])
+    utils.make_sure(rank is not None, "Cannot convert %s node %s with unknown rank to channels last", op.type, op.name)
+    op.type = "ChannelsLast" + op.type
+    op.domain = constants.CONTRIB_OPS_DOMAIN
+    inp_perm = channel_first_to_last_perm(rank)
+    out_perm = channel_last_to_first_perm(rank)
+    output_shape = g.get_shape(op.output[0])
+    if output_shape is not None:
+        output_shape = [output_shape[i] for i in inp_perm]
+        g.set_shape(op.output[0], output_shape)
+
+    g.insert_new_node_on_input(op, "Transpose", op.input[0], input_index=0, perm=inp_perm)
+    g.insert_new_node_on_output("Transpose", op.output[0], perm=out_perm)
+
+
+def _to_channel_first_handler(g, op):
+    rank = g.get_rank(op.output[0])
+    utils.make_sure(rank is not None, "Cannot convert %s node %s with unknown rank to channels last", op.type, op.name)
+    op.type = op.type.replace("ChannelsLast", "")
+    op.domain = constants.ONNX_DOMAIN
+    inp_perm = channel_last_to_first_perm(rank)
+    out_perm = channel_first_to_last_perm(rank)
+    output_shape = g.get_shape(op.output[0])
+    if output_shape is not None:
+        output_shape = [output_shape[i] for i in inp_perm]
+        g.set_shape(op.output[0], output_shape)
+
+    g.insert_new_node_on_input(op, "Transpose", op.input[0], input_index=0, perm=inp_perm)
+    g.insert_new_node_on_output("Transpose", op.output[0], perm=out_perm)
+
+
+def get_channels_first_ops(opset=None):
+    # opset doesn't matter for now
+    return set(_CHANNELS_FIRST_OPS)
+
+
+def rewrite_channels_last(g, ops):
+    channel_first_ops = get_channels_first_ops(g.opset)
+    for op in ops:
+        if op.type in channel_first_ops:
+            _to_channel_last_handler(g, op)
+    return g.get_nodes()
+
+
+def rewrite_channels_first(g, ops):
+    for op in ops:
+        if op.type.startswith("ChannelsLast"):
+            _to_channel_first_handler(g, op)
+    return g.get_nodes()
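
The permutation helpers carry the whole layout logic: channel_first_to_last_perm(rank) moves the channel axis to the end, channel_last_to_first_perm(rank) moves it back, and each handler wraps a matched op in a Transpose pair (input transposed to channels-last, output transposed back to channels-first) while renaming it to its ChannelsLast* contrib-domain variant. A quick standalone check of the permutation math, independent of the tf2onnx graph API:

    def channel_first_to_last_perm(rank):
        return [0] + list(range(2, rank)) + [1]

    def channel_last_to_first_perm(rank):
        return [0, rank - 1] + list(range(1, rank - 1))

    # The rank-4 case reproduces the NCHW_TO_NHWC / NHWC_TO_NCHW constants.
    assert channel_first_to_last_perm(4) == [0, 2, 3, 1]
    assert channel_last_to_first_perm(4) == [0, 3, 1, 2]

    # The two perms are inverses, so the inserted Transpose pair preserves semantics.
    axes = ["N", "C", "H", "W"]
    nhwc = [axes[i] for i in channel_first_to_last_perm(4)]   # ['N', 'H', 'W', 'C']
    back = [nhwc[i] for i in channel_last_to_first_perm(4)]   # ['N', 'C', 'H', 'W']
    assert back == axes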

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 8 additions & 3 deletions
@@ -7,7 +7,7 @@

 import numpy as np
 import onnx
-from tf2onnx.constants import NCHW_TO_NHWC, NHWC_TO_NCHW, NCDHW_TO_NDHWC, NDHWC_TO_NCDHW
+from tf2onnx.constants import NCHW_TO_NHWC, NHWC_TO_NCHW, NCDHW_TO_NDHWC, NDHWC_TO_NCDHW, TARGET_CHANNELS_LAST
 from .. import utils
 from .optimizer_base import GraphOptimizerBase

@@ -362,14 +362,19 @@ def _should_push_transpose(self, trans, node):
         perm = trans.get_attr_value("perm")
         optimization_gains = 0
         removed_nchws = 0
+        perm_to_push_down = [NCHW_TO_NHWC, NCDHW_TO_NDHWC]
+        perm_to_push_up = [NHWC_TO_NCHW, NDHWC_TO_NCDHW]
+        if self._g.is_target(TARGET_CHANNELS_LAST):
+            perm_to_push_down, perm_to_push_up = perm_to_push_up, perm_to_push_down
+
         for n, inp_id in zip(node.inputs, node.input):
             if is_tranpose_of_type(n, perm):
                 optimization_gains += self._cost_to_transpose(n.inputs[0], n.input[0])
-                if perm in [NCHW_TO_NHWC, NCDHW_TO_NDHWC]:
+                if perm in perm_to_push_down:
                     removed_nchws += 1
             else:
                 optimization_gains -= self._cost_to_transpose(n, inp_id)
-                if perm in [NHWC_TO_NCHW, NDHWC_TO_NCDHW]:
+                if perm in perm_to_push_up:
                     removed_nchws -= 1
         if removed_nchws != 0:
             # Always push nchw transposes if possible
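
By default the optimizer prioritizes removing NCHW-to-NHWC transposes, which steers the surviving graph toward channels-first, ONNX's native layout for conv-style ops; with the nhwc target the two lists are swapped so the preference runs the other way and the graph converges on channels-last. A sketch of the swap in isolation (perm values as defined in tf2onnx.constants):

    NCHW_TO_NHWC, NHWC_TO_NCHW = [0, 2, 3, 1], [0, 3, 1, 2]
    NCDHW_TO_NDHWC, NDHWC_TO_NCDHW = [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]

    def transpose_preference(channels_last_target):
        # Mirrors the swap added to _should_push_transpose above.
        perm_to_push_down = [NCHW_TO_NHWC, NCDHW_TO_NDHWC]
        perm_to_push_up = [NHWC_TO_NCHW, NDHWC_TO_NCDHW]
        if channels_last_target:
            perm_to_push_down, perm_to_push_up = perm_to_push_up, perm_to_push_down
        return perm_to_push_down, perm_to_push_up

    print(transpose_preference(False)[0])  # [[0, 2, 3, 1], [0, 2, 3, 4, 1]]
    print(transpose_preference(True)[0])   # [[0, 3, 1, 2], [0, 4, 1, 2, 3]]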

tf2onnx/tfonnx.py

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,7 @@
 from tf2onnx.graph import Graph
 from tf2onnx.rewriter import *  # pylint: disable=wildcard-import
 from tf2onnx.tflite_rewriters import *  # pylint: disable=wildcard-import
+from tf2onnx.late_rewriters import rewrite_channels_last
 from tf2onnx.shape_inference import infer_shape
 from tf2onnx.tf_loader import is_function, resolve_functions, set_function
 from tf2onnx.tf_utils import tensorflow_to_onnx, get_tf_version, compute_const_folding_using_tf
@@ -640,6 +641,8 @@ def compat_handler(ctx, node, **kwargs):
         late_rewriters.append(rewrite_incomplete_type_support_rs5)
     if constants.TARGET_RS6 in target:
         late_rewriters.append(rewrite_incomplete_type_support_rs6)
+    if constants.TARGET_CHANNELS_LAST in target:
+        late_rewriters.append(rewrite_channels_last)
     if late_rewriters:
         run_rewriters(g, late_rewriters, continue_on_error)
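
With the constant registered in constants.py and the late rewriter wired in here, a conversion can request channels-last output through the existing target mechanism, for example --target nhwc on the converter command line. A hedged sketch of the Python API path; it assumes tf2onnx.convert.from_keras forwards a target list the same way the CLI flag does, and the model is a throwaway placeholder:

    import tensorflow as tf
    from tf2onnx import convert

    model = tf.keras.Sequential([tf.keras.layers.Conv2D(8, 3, input_shape=(28, 28, 3))])
    spec = (tf.TensorSpec((None, 28, 28, 3), tf.float32, name="input"),)

    # "nhwc" is the new TARGET_CHANNELS_LAST value; conv-style ops end up as
    # ChannelsLast* contrib ops via the late rewriter registered above.
    model_proto, _ = convert.from_keras(
        model, input_signature=spec, target=["nhwc"], output_path="model_nhwc.onnx")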

tf2onnx/utils.py

Lines changed: 29 additions & 0 deletions
@@ -184,6 +184,35 @@ def find_opset(opset):
     return opset


+def get_subgraphs_from_onnx(model_proto):
+    """Returns an iterator over the graphs/subgraphs of a model (using dfs)"""
+    stack = [model_proto.graph]
+    while stack:
+        g = stack.pop()
+        yield g
+        for node in g.node:
+            for attr in node.attribute:
+                if hasattr(attr, "g"):
+                    stack.append(attr.g)
+                if hasattr(attr, "graphs"):
+                    stack.extend(attr.graphs)
+
+
+def initialize_name_counter(model_proto):
+    """Avoid name conflicts by initializing the counter used by make_name based on the provided model"""
+    suffix_regex = re.compile(r"__(\d+)(:\d+)?$")
+    def avoid_name(name):
+        global INTERNAL_NAME
+        suffix = suffix_regex.search(name)
+        if suffix:
+            INTERNAL_NAME = max(INTERNAL_NAME, int(suffix.group(1)) + 1)
+    for g in get_subgraphs_from_onnx(model_proto):
+        for n in g.node:
+            avoid_name(n.name)
+            for out in n.output:
+                avoid_name(out)
+
+
 def save_onnx_model(save_path_root, onnx_file_name, feed_dict, model_proto, include_test_data=False, as_text=False,
                     external_tensor_storage=None):
     """Save onnx model as file. Save a pbtxt file as well if as_text is True"""

tools/onnx-optimize.py

Lines changed: 14 additions & 4 deletions
@@ -16,7 +16,8 @@
 from onnx import helper

 from tf2onnx.graph import GraphUtil
-from tf2onnx import logging, optimizer
+from tf2onnx import logging, optimizer, constants
+from tf2onnx.late_rewriters import rewrite_channels_first, rewrite_channels_last


 logging.basicConfig(level=logging.INFO)
@@ -28,23 +29,32 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--input", required=True, help="onnx input model file")
     parser.add_argument("--output", help="output model file")
+    target_options = [constants.TARGET_CHANNELS_LAST, constants.TARGET_CHANNELS_FIRST]
+    parser.add_argument("--target", default=",".join(constants.DEFAULT_TARGET), choices=target_options,
+                        help="target platform")
     args = parser.parse_args()
+    args.target = args.target.split(",")
     return args


-def load_graph(fname):
+def load_graph(fname, target):
     model_proto = onnx.ModelProto()
     with open(fname, "rb") as f:
         data = f.read()
         model_proto.ParseFromString(data)
-    g = GraphUtil.create_graph_from_onnx_model(model_proto)
+    g = GraphUtil.create_graph_from_onnx_model(model_proto, target)
     return g, model_proto


 def main():
     args = get_args()

-    g, org_model_proto = load_graph(args.input)
+    g, org_model_proto = load_graph(args.input, args.target)
+
+    if g.is_target(constants.TARGET_CHANNELS_FIRST):
+        g.reset_nodes(rewrite_channels_first(g, g.get_nodes()))
+    if g.is_target(constants.TARGET_CHANNELS_LAST):
+        g.reset_nodes(rewrite_channels_last(g, g.get_nodes()))

     g = optimizer.optimize_graph(g)
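
With the new flag, tools/onnx-optimize.py can also flip an existing ONNX model between layouts: --target nchw rewrites any ChannelsLast* contrib ops back to the standard channels-first ops, --target nhwc rewrites standard conv-style ops to their ChannelsLast* variants, and the optimizer pass that follows cleans up the inserted Transposes. A typical invocation, with placeholder file names:

    python tools/onnx-optimize.py --input model.onnx --output model_nhwc.onnx --target nhwc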
