Skip to content

Commit ee69f28

Browse files
auto infer shape and dtype for onnx nodes
1 parent d03e469 commit ee69f28

File tree

9 files changed

+585
-27
lines changed

9 files changed

+585
-27
lines changed

tests/backend_test_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def run_onnxruntime(self, model_path, inputs, output_names):
6969
results = m.run(output_names, inputs)
7070
return results
7171

72-
def _run_backend(self, g, outputs, input_dict):
72+
def run_backend(self, g, outputs, input_dict):
7373
model_proto = g.make_model("test")
7474
model_path = self.save_onnx_model(model_proto, input_dict)
7575

@@ -133,7 +133,7 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
133133
g = process_tf_graph(sess.graph, opset=self.config.opset, output_names=output_names_with_port,
134134
target=self.config.target, **process_args)
135135
g = optimizer.optimize_graph(g)
136-
actual = self._run_backend(g, output_names_with_port, onnx_feed_dict)
136+
actual = self.run_backend(g, output_names_with_port, onnx_feed_dict)
137137

138138
for expected_val, actual_val in zip(expected, actual):
139139
if check_value:

tests/test_shape_inference.py

Lines changed: 424 additions & 0 deletions
Large diffs are not rendered by default.

tf2onnx/graph.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,16 @@
1212
import collections
1313
import copy
1414
import logging
15+
import traceback
1516
import six
1617
import numpy as np
1718

18-
from onnx import helper, numpy_helper, shape_inference, OperatorSetIdProto, AttributeProto
19+
from onnx import helper, numpy_helper, shape_inference, OperatorSetIdProto, AttributeProto, TensorProto
20+
from tf2onnx import constants
1921
from tf2onnx import utils, __version__
2022
from tf2onnx.utils import port_name, find_opset
2123
from tf2onnx import optimizer
22-
from tf2onnx.schemas import get_schema
24+
from tf2onnx.schemas import get_schema, infer_onnx_shape_dtype
2325

2426
logger = logging.getLogger(__name__)
2527

@@ -419,7 +421,7 @@ def make_const(self, name, np_val, skip_conversion=False, raw=True):
419421
return node
420422

421423
def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, skip_conversion=True,
422-
op_name_scope=None, name=None, shapes=None, dtypes=None, domain=None):
424+
op_name_scope=None, name=None, shapes=None, dtypes=None, domain=None, auto_infer_shape_dtype=True):
423425
"""Make a new onnx node in the graph"""
424426
if attr is None:
425427
attr = {}
@@ -474,6 +476,9 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
474476
for i in range(output_count):
475477
self.set_dtype(node.output[i], dtypes[i])
476478

479+
if (not shapes or not dtypes) and auto_infer_shape_dtype:
480+
self.update_node_shape_dtype(node, override=True)
481+
477482
self._nodes.append(node)
478483
return node
479484

@@ -533,6 +538,68 @@ def reset_nodes(self, ops):
533538
self._dtypes = remained_dtypes
534539
self._output_shapes = remained_shapes
535540

541+
def update_node_shape_dtype(self, node, override=True):
542+
"""try the best to infer shapes and dtypes for outputs of the node"""
543+
if node.is_const() or node.is_graph_input():
544+
return
545+
# NOTE: only support onnx node for now
546+
if node.domain != constants.ONNX_DOMAIN:
547+
return
548+
549+
logger.debug("Infer shape and dtype for [%s]", node.name)
550+
# NOTE: shape inference for some ops need the input values of the op, e.g., Reshape
551+
# op needs the "Shape" value to infer output shape.
552+
initializer = []
553+
for i, inp in enumerate(node.inputs):
554+
if not inp:
555+
logger.warning("[%s] infer a inexistent node: [%s], please check the code", node.name, node.input[i])
556+
continue
557+
if inp.is_const():
558+
t = inp.get_attr("value")
559+
tensor = helper.get_attribute_value(t)
560+
tensor.name = inp.output[0]
561+
initializer.append(tensor)
562+
563+
input_shapes = [self.get_shape(i) for i in node.input]
564+
input_dtypes = [self.get_dtype(i) for i in node.input]
565+
566+
dtypes = {}
567+
shapes = {}
568+
try:
569+
shapes, dtypes = infer_onnx_shape_dtype(node, input_shapes, input_dtypes, self._opset, initializer)
570+
except Exception:
571+
tb = traceback.format_exc()
572+
logger.warning("ONNX Failed to infer shapes and dtypes for [%s, type: %s]", node.name, node.type)
573+
logger.warning("Inference error: %s", tb)
574+
return
575+
576+
for output in node.output:
577+
dtype = dtypes[output]
578+
shape = shapes[output]
579+
if dtype == TensorProto.UNDEFINED:
580+
logger.debug("Inferred dtype for [%s, type: %s] is UNDEFINED, SKIP", node.name, node.type)
581+
else:
582+
existing_dtype = self.get_dtype(output)
583+
if existing_dtype is not None and existing_dtype != dtype:
584+
if override:
585+
logger.warning("Override dtype of %s from %s to %s", output, existing_dtype, dtype)
586+
else:
587+
dtype = existing_dtype
588+
self.set_dtype(output, dtype)
589+
logger.debug("Set dtype of [%s] to %s", output, dtype)
590+
591+
if shape is None:
592+
logger.debug("Inferred shape for [%s, type: %s] is None, SKIP", node.name, node.type)
593+
else:
594+
existing_shape = self.get_shape(output)
595+
if existing_shape is not None and not utils.are_shapes_equal(existing_shape, shape):
596+
if override:
597+
logger.warning("Override shape of %s from %s to %s", output, existing_shape, shape)
598+
else:
599+
shape = existing_shape
600+
self.set_shape(output, shape)
601+
logger.debug("Set shape of [%s] to %s", output, shape)
602+
536603
def update_proto(self):
537604
"""Update the onnx protobuf from out internal Node structure."""
538605
for node in self._nodes:

tf2onnx/onnx_opset/controlflow.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def get_inputs_for_current_iteration(g, input_id, iter_index):
3434
def create_loop_body_graph(parent_g, gather_input_ids, output_data_type, output_shape, trip_count_input_ids,
3535
rank, loop_name):
3636
g = parent_g.create_new_graph_with_same_config()
37+
g.parent_graph = parent_g
3738
iter_name = utils.make_name("i")
3839
cond_name = utils.make_name("cond")
3940
fake_var_name = utils.make_name("fake_var")
@@ -117,6 +118,7 @@ def create_if_op(g, input_ids, output_data_type, output_shape):
117118

118119
def create_body_graph_for_if_branch(parent_g, data_type, output_shape, chosen_cur_cond_val_out_name, op_name):
119120
g = parent_g.create_new_graph_with_same_config()
121+
g.parent_graph = parent_g
120122
name = utils.make_name("Identity")
121123
g.make_node(
122124
'Identity',
@@ -204,14 +206,15 @@ def make_range_non_const(ctx, start, limit, delta, output, scope_name, shape, dt
204206

205207
# body
206208
g = ctx.create_new_graph_with_same_config()
207-
g.make_node("Identity", ["cond"], outputs=["cond_out"])
208-
g.make_node("Add", ["prev", delta], outputs=["current"], name=utils.make_name("add"))
209-
g.make_node("Identity", ["prev"], outputs=["range"])
210-
209+
g.parent_graph = ctx
211210
g.add_graph_input("i", TensorProto.INT64, [])
212211
g.add_graph_input("cond", TensorProto.BOOL, [])
213212
g.add_graph_input("prev", dtype, [])
214213

214+
g.make_node("Identity", ["cond"], outputs=["cond_out"])
215+
g.make_node("Add", ["prev", delta], outputs=["current"], name=utils.make_name("add"))
216+
g.make_node("Identity", ["prev"], outputs=["range"])
217+
215218
g.add_graph_output("cond_out", TensorProto.BOOL, [])
216219
g.add_graph_output("current", dtype, [])
217220
g.add_graph_output("range", dtype, [])
@@ -274,8 +277,9 @@ def version_8(cls, ctx, node, **kwargs):
274277
input_shape = ctx.get_shape(node.input[0])
275278

276279
g = ctx.create_new_graph_with_same_config()
277-
g.make_node('Identity', ['X'], outputs=['Y'])
280+
g.parent_graph = ctx
278281
g.add_graph_input('X', input_dtype, input_shape[2:])
282+
g.make_node('Identity', ['X'], outputs=['Y'])
279283
g.add_graph_output('Y', input_dtype, input_shape[2:])
280284

281285
node.set_body_graph_as_attr("body", g)

tf2onnx/onnx_opset/nn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
110110
shape_name = utils.make_name(node.name)
111111
ctx.make_const(shape_name, np.array(new_kernel_shape, dtype=np.int64))
112112
input_name = node.input[1]
113-
reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name)
114-
reshape.input.append(shape_name)
113+
reshape = ctx.make_node("Reshape", [input_name, shape_name])
114+
ctx.replace_input(node, input_name, reshape.output[0])
115115
reshape.skip_conversion = True
116116
ctx.set_shape(reshape.output[0], new_kernel_shape)
117117

@@ -348,8 +348,8 @@ def version_7(cls, ctx, node, **kwargs):
348348
shape_name = utils.make_name(node.name)
349349
ctx.make_const(shape_name, np.array(new_broadcast_shape, dtype=np.int64))
350350
op_name = node.input[1]
351-
reshape_node = ctx.insert_new_node_on_input(node, "Reshape", op_name)
352-
reshape_node.input.append(shape_name)
351+
reshape_node = ctx.make_node("Reshape", [op_name, shape_name])
352+
ctx.replace_input(node, op_name, reshape_node.output[0])
353353
ctx.set_shape(reshape_node.output[0], new_broadcast_shape)
354354

355355

tf2onnx/onnx_opset/tensor.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -283,15 +283,16 @@ def _make_gathernd_inner_loop(ctx, params, index, dtype):
283283

284284
# body graph creation
285285
g = ctx.create_new_graph_with_same_config()
286+
g.add_graph_input(trip_name, TensorProto.INT64, [])
287+
g.add_graph_input(cond_name, TensorProto.BOOL, [])
288+
g.add_graph_input(cur_name, dtype, [])
289+
g.parent_graph = ctx
290+
286291
index_i = g.make_node("Gather", [index.output[0], trip_name], attr={"axis": 0})
287292
gather = g.make_node("Gather", [cur_name, index_i.output[0]], attr={"axis": 0})
288293
g.make_node("Squeeze", [gather.output[0]], attr={"axes": [0]}, outputs=[result_name])
289294
g.make_node("Identity", [cond_name], outputs=[cond_out_name])
290295

291-
g.add_graph_input(trip_name, TensorProto.INT64, [])
292-
g.add_graph_input(cond_name, TensorProto.BOOL, [])
293-
g.add_graph_input(cur_name, dtype, [])
294-
295296
g.add_graph_output(cond_out_name, TensorProto.BOOL, [])
296297
g.add_graph_output(result_name, dtype, [])
297298

@@ -337,6 +338,11 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params, shapes, dt
337338
dummy_out_name = utils.make_name("dummy_out")
338339
result_name = utils.make_name("res")
339340

341+
g.add_graph_input(trip_name, TensorProto.INT64, [])
342+
g.add_graph_input(cond_name, TensorProto.BOOL, [])
343+
g.add_graph_input(dummy_name, t_params, [])
344+
g.parent_graph = ctx
345+
340346
index = g.make_node("Gather", [flatten_indices.output[0], trip_name], attr={"axis": 0})
341347
index_squeeze = g.make_node("Squeeze", [index.output[0]], attr={"axes": [0]})
342348
# inner loop to gather result
@@ -345,10 +351,6 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params, shapes, dt
345351
g.make_node("Identity", [dummy_name], outputs=[dummy_out_name])
346352
g.make_node("Identity", [inner_loop.output[0]], outputs=[result_name])
347353

348-
g.add_graph_input(trip_name, TensorProto.INT64, [])
349-
g.add_graph_input(cond_name, TensorProto.BOOL, [])
350-
g.add_graph_input(dummy_name, t_params, [])
351-
352354
g.add_graph_output(cond_out_name, TensorProto.BOOL, [])
353355
g.add_graph_output(dummy_out_name, t_params, [])
354356
g.add_graph_output(result_name, t_params, [])

tf2onnx/schemas.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
from __future__ import print_function
1010
from __future__ import unicode_literals
1111

12+
import copy
1213
from collections import defaultdict, OrderedDict
13-
from onnx import defs
14+
from onnx import defs, helper, TensorProto, OperatorSetIdProto, shape_inference
1415

1516
from . import constants
17+
from . import utils
1618

1719

1820
class OnnxOpSchema(object):
@@ -112,3 +114,55 @@ def get_max_supported_opset_version(domain=None):
112114
"""Get max supported opset version by current onnx package given a domain."""
113115
domain = domain or constants.ONNX_DOMAIN
114116
return _domain_opset_versions.get(domain, None)
117+
118+
119+
def infer_onnx_shape_dtype(node, input_shapes, input_dtypes, opset, initializer=None):
120+
"""
121+
Infer shapes and dtypes for outputs of the node.
122+
Sometimes, shape inference needs the values of node's inputs, so initializers are used.
123+
"""
124+
125+
def build_onnx_op(node):
126+
"""Build onnx op"""
127+
onnx_node = helper.make_node(node.type, node.input, node.output, name=node.name)
128+
# deal with attributes
129+
attr = []
130+
attr_graphs = node.get_body_graphs()
131+
if attr_graphs:
132+
for attr_name, sub_graph in attr_graphs.items():
133+
copied_sub_graph = copy.deepcopy(sub_graph)
134+
graph_proto = copied_sub_graph.make_graph("graph for " + node.name + " " + attr_name)
135+
attr.append(helper.make_attribute(attr_name, graph_proto))
136+
attr.extend([a for a in node.attr_onnx.values()])
137+
if attr:
138+
onnx_node.attribute.extend(attr)
139+
return onnx_node
140+
141+
shapes = {}
142+
dtypes = {}
143+
inputs = []
144+
outputs = []
145+
for inp, shape, dtype in zip(node.input, input_shapes, input_dtypes):
146+
inputs.append(utils.make_onnx_inputs_outputs(inp, dtype, shape))
147+
for output in node.output:
148+
outputs.append(utils.make_onnx_inputs_outputs(output, TensorProto.UNDEFINED, None))
149+
graph_def = helper.make_graph([build_onnx_op(node)], "infer-graph", inputs, outputs, initializer=initializer)
150+
imp = OperatorSetIdProto()
151+
imp.version = opset
152+
model_def = helper.make_model(graph_def, opset_imports=[imp])
153+
154+
inferred_model = shape_inference.infer_shapes(model_def)
155+
for output in inferred_model.graph.output:
156+
tensor_type = output.type.tensor_type
157+
if tensor_type.HasField("elem_type"):
158+
dtypes[output.name] = tensor_type.elem_type
159+
else:
160+
dtypes[output.name] = TensorProto.UNDEFINED
161+
# 0 in shapes of onnx means unknown which is -1 in our convertor
162+
if tensor_type.HasField("shape"):
163+
shapes[output.name] = [
164+
dim.dim_value if dim.dim_value != 0 else utils.ONNX_UNKNOWN_DIMENSION for dim in tensor_type.shape.dim
165+
]
166+
else:
167+
shapes[output.name] = None
168+
return shapes, dtypes

tf2onnx/utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,14 @@ def make_onnx_inputs_outputs(name, elem_type, shape, **kwargs):
244244
elem_type, # type: TensorProto.DataType
245245
shape, # type: Optional[Sequence[int]]
246246
"""
247-
return helper.make_tensor_value_info(name, elem_type, make_onnx_shape(shape), **kwargs)
247+
if elem_type is None:
248+
elem_type = onnx_pb.TensorProto.UNDEFINED
249+
return helper.make_tensor_value_info(
250+
name,
251+
elem_type,
252+
make_onnx_shape(shape),
253+
**kwargs
254+
)
248255

249256

250257
def find_opset(opset):
@@ -308,7 +315,7 @@ def construct_graph_from_nodes(parent_g, nodes, outputs, shapes, dtypes):
308315
all_outputs |= set(op.output)
309316

310317
new_node = g.make_node(op.type, op.input, outputs=op.output, attr=op.attr, name=op.name,
311-
skip_conversion=op.skip_conversion)
318+
skip_conversion=op.skip_conversion, auto_infer_shape_dtype=False)
312319
body_graphs = op.graph.contained_graphs.pop(op.name, None)
313320
if body_graphs:
314321
for attr_name, body_graph in body_graphs.items():
@@ -327,7 +334,7 @@ def construct_graph_from_nodes(parent_g, nodes, outputs, shapes, dtypes):
327334
new_output_names = []
328335
for output, shape, dtype in zip(outputs, shapes, dtypes):
329336
node = g.make_node("Identity", inputs=[output], op_name_scope="sub_graph_ending_node",
330-
shapes=[shape], dtypes=[dtype])
337+
shapes=[shape], dtypes=[dtype], auto_infer_shape_dtype=False)
331338
new_output_names.append(node.output[0])
332339
g.outputs = new_output_names
333340
return g

tf2onnx/verbose_logging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def getLogger(name=None): # pylint: disable=invalid-name, function-redefined
3232
return logger
3333

3434

35-
_SIMPLE_LOG_FORMAT = "%(message)s"
35+
_SIMPLE_LOG_FORMAT = "%(levelname)s: %(message)s"
3636
_VERBOSE_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(name)s: %(message)s"
3737

3838

0 commit comments

Comments
 (0)