Skip to content

Commit 9738fed

Browse files
resolve comments
1 parent 3632e8f commit 9738fed

File tree

5 files changed

+58
-49
lines changed

5 files changed

+58
-49
lines changed

tests/test_shape_inference.py renamed to tests/test_onnx_shape_inference.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,12 @@ class ShapeInferenceTests(Tf2OnnxBackendTestBase):
3131
def _run_test_case(self, graph, feed_dict):
3232
"""Run model with onnxruntime and compare results' shape with internal shape inference."""
3333
outputs = graph.outputs
34-
results = []
35-
raised = False
36-
try:
37-
results = self.run_backend(graph, outputs, feed_dict)
38-
except Exception as ex: # pylint: disable=broad-except
39-
self.logger.error(ex)
40-
raised = True
41-
self.assertFalse(raised)
42-
43-
raised = False
34+
results = self.run_backend(graph, outputs, feed_dict)
35+
4436
for actual, inferred in zip(results, outputs):
4537
actual_shape = actual.shape
4638
inferred_shape = tuple(graph.get_shape(inferred))
47-
try:
48-
utils.merge_shapes(actual_shape, inferred_shape)
49-
except Exception as ex: # pylint: disable=broad-except
50-
self.logger.error(ex)
51-
raised = True
52-
self.assertFalse(raised)
39+
self.assertTrue(utils.are_shapes_compatible(actual_shape, inferred_shape))
5340

5441
actual_dtype = actual.dtype
5542
inferred_dtype = utils.ONNX_TO_NUMPY_DTYPE[graph.get_dtype(inferred)]
@@ -418,6 +405,9 @@ def test_override_shape(self):
418405
output_name = utils.make_name("output")
419406
graph._output_shapes[output_name] = [-1, -1, 2, 3] # pylint: disable=protected-access
420407
node = graph.make_node("Transpose", [INPUT1], attr={"perm": [1, 0, 2, 3]}, outputs=[output_name])
408+
409+
graph.update_node_shape_dtype(node, override=True)
410+
421411
graph.add_graph_output(node.output[0])
422412
self._run_test_case(graph, self._generate_random_inputs(inputs, shapes, dtypes))
423413

tf2onnx/graph.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,10 @@
1212
import collections
1313
import copy
1414
import logging
15-
import traceback
1615
import six
1716
import numpy as np
1817

1918
from onnx import helper, numpy_helper, shape_inference, OperatorSetIdProto, AttributeProto, TensorProto
20-
from tf2onnx import constants
2119
from tf2onnx import utils, __version__
2220
from tf2onnx.utils import port_name, find_opset
2321
from tf2onnx import optimizer
@@ -421,7 +419,7 @@ def make_const(self, name, np_val, skip_conversion=False, raw=True):
421419
return node
422420

423421
def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, skip_conversion=True,
424-
op_name_scope=None, name=None, shapes=None, dtypes=None, domain=None, auto_infer_shape_dtype=True):
422+
op_name_scope=None, name=None, shapes=None, dtypes=None, domain=None, infer_shape_dtype=True):
425423
"""Make a new onnx node in the graph"""
426424
if attr is None:
427425
attr = {}
@@ -476,8 +474,8 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
476474
for i in range(output_count):
477475
self.set_dtype(node.output[i], dtypes[i])
478476

479-
if (not shapes or not dtypes) and auto_infer_shape_dtype:
480-
self.update_node_shape_dtype(node, override=True)
477+
if (not shapes or not dtypes) and infer_shape_dtype:
478+
self.update_node_shape_dtype(node, override=False)
481479

482480
self._nodes.append(node)
483481
return node
@@ -538,44 +536,42 @@ def reset_nodes(self, ops):
538536
self._dtypes = remained_dtypes
539537
self._output_shapes = remained_shapes
540538

541-
def update_node_shape_dtype(self, node, override=True):
542-
"""try the best to infer shapes and dtypes for outputs of the node"""
539+
def update_node_shape_dtype(self, node, override=False):
540+
"""Try the best to infer shapes and dtypes for outputs of the node,
541+
by default, we respect TF shapes and dtypes.
542+
"""
543543
if node.is_const() or node.is_graph_input():
544544
return
545545
# NOTE: only support onnx node for now
546-
if node.domain != constants.ONNX_DOMAIN:
546+
if not utils.is_onnx_domain(node.domain):
547547
return
548548

549549
logger.debug("Infer shape and dtype for [%s]", node.name)
550550
# NOTE: shape inference for some ops need the input values of the op, e.g., Reshape
551551
# op needs the "Shape" value to infer output shape.
552-
initializer = []
552+
initializers = []
553553
for i, inp in enumerate(node.inputs):
554554
if not inp:
555-
logger.warning("[%s] infer a inexistent node: [%s], please check the code", node.name, node.input[i])
555+
if logger.isEnabledFor(logging.VERBOSE):
556+
logger.warning(
557+
"[%s] infer a inexistent node: [%s], please check the code",
558+
node.name, node.input[i]
559+
)
556560
continue
557561
if inp.is_const():
558562
t = inp.get_attr("value")
559563
tensor = helper.get_attribute_value(t)
560564
tensor.name = inp.output[0]
561-
initializer.append(tensor)
565+
initializers.append(tensor)
562566

563567
input_shapes = [self.get_shape(i) for i in node.input]
564568
input_dtypes = [self.get_dtype(i) for i in node.input]
565569

566-
dtypes = {}
567-
shapes = {}
568-
try:
569-
shapes, dtypes = infer_onnx_shape_dtype(node, input_shapes, input_dtypes, self._opset, initializer)
570-
except Exception:
571-
tb = traceback.format_exc()
572-
logger.warning("ONNX Failed to infer shapes and dtypes for [%s, type: %s]", node.name, node.type)
573-
logger.warning("Inference error: %s", tb)
570+
shapes, dtypes = infer_onnx_shape_dtype(node, self._opset, input_shapes, input_dtypes, initializers)
571+
if not shapes or not dtypes:
574572
return
575573

576-
for output in node.output:
577-
dtype = dtypes[output]
578-
shape = shapes[output]
574+
for output, shape, dtype in zip(node.output, shapes, dtypes):
579575
if dtype == TensorProto.UNDEFINED:
580576
logger.debug("Inferred dtype for [%s, type: %s] is UNDEFINED, SKIP", node.name, node.type)
581577
else:

tf2onnx/schemas.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@
99
from __future__ import print_function
1010
from __future__ import unicode_literals
1111

12+
import logging
1213
import copy
1314
from collections import defaultdict, OrderedDict
1415
from onnx import defs, helper, TensorProto, OperatorSetIdProto, shape_inference
1516

1617
from . import constants
1718
from . import utils
1819

20+
logger = logging.getLogger(__name__)
21+
1922

2023
class OnnxOpSchema(object):
2124
"""Wrapper for Onnx schema."""
@@ -116,7 +119,7 @@ def get_max_supported_opset_version(domain=None):
116119
return _domain_opset_versions.get(domain, None)
117120

118121

119-
def infer_onnx_shape_dtype(node, input_shapes, input_dtypes, opset, initializer=None):
122+
def infer_onnx_shape_dtype(node, opset_version, input_shapes, input_dtypes, initializers=None):
120123
"""
121124
Infer shapes and dtypes for outputs of the node.
122125
Sometimes, shape inference needs the values of node's inputs, so initializers are used.
@@ -138,20 +141,29 @@ def build_onnx_op(node):
138141
onnx_node.attribute.extend(attr)
139142
return onnx_node
140143

141-
shapes = {}
142-
dtypes = {}
143144
inputs = []
144145
outputs = []
145146
for inp, shape, dtype in zip(node.input, input_shapes, input_dtypes):
146147
inputs.append(utils.make_onnx_inputs_outputs(inp, dtype, shape))
147148
for output in node.output:
148149
outputs.append(utils.make_onnx_inputs_outputs(output, TensorProto.UNDEFINED, None))
149-
graph_def = helper.make_graph([build_onnx_op(node)], "infer-graph", inputs, outputs, initializer=initializer)
150+
graph_proto = helper.make_graph([build_onnx_op(node)], "infer-graph", inputs, outputs, initializer=initializers)
150151
imp = OperatorSetIdProto()
151-
imp.version = opset
152-
model_def = helper.make_model(graph_def, opset_imports=[imp])
152+
imp.version = opset_version
153+
model_proto = helper.make_model(graph_proto, opset_imports=[imp])
154+
155+
inferred_model = None
156+
try:
157+
inferred_model = shape_inference.infer_shapes(model_proto)
158+
except Exception: # pylint: disable=broad-except
159+
logger.warning(
160+
"ONNX Failed to infer shapes and dtypes for [%s, type: %s]",
161+
node.name, node.type, exc_info=1
162+
)
163+
return None, None
153164

154-
inferred_model = shape_inference.infer_shapes(model_def)
165+
shapes = {}
166+
dtypes = {}
155167
for output in inferred_model.graph.output:
156168
tensor_type = output.type.tensor_type
157169
if tensor_type.HasField("elem_type"):
@@ -165,4 +177,15 @@ def build_onnx_op(node):
165177
]
166178
else:
167179
shapes[output.name] = None
168-
return shapes, dtypes
180+
output_shapes = []
181+
output_dtypes = []
182+
for output in node.output:
183+
if output in shapes:
184+
output_shapes.append(shapes[output])
185+
else:
186+
output_shapes.append(None)
187+
if output in dtypes:
188+
output_dtypes.append(dtypes[output])
189+
else:
190+
output_dtypes.append(TensorProto.UNDEFINED)
191+
return output_shapes, output_dtypes

tf2onnx/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def construct_graph_from_nodes(parent_g, nodes, outputs, shapes, dtypes):
315315
all_outputs |= set(op.output)
316316

317317
new_node = g.make_node(op.type, op.input, outputs=op.output, attr=op.attr, name=op.name,
318-
skip_conversion=op.skip_conversion, auto_infer_shape_dtype=False)
318+
skip_conversion=op.skip_conversion, infer_shape_dtype=False)
319319
body_graphs = op.graph.contained_graphs.pop(op.name, None)
320320
if body_graphs:
321321
for attr_name, body_graph in body_graphs.items():
@@ -334,7 +334,7 @@ def construct_graph_from_nodes(parent_g, nodes, outputs, shapes, dtypes):
334334
new_output_names = []
335335
for output, shape, dtype in zip(outputs, shapes, dtypes):
336336
node = g.make_node("Identity", inputs=[output], op_name_scope="sub_graph_ending_node",
337-
shapes=[shape], dtypes=[dtype], auto_infer_shape_dtype=False)
337+
shapes=[shape], dtypes=[dtype], infer_shape_dtype=False)
338338
new_output_names.append(node.output[0])
339339
g.outputs = new_output_names
340340
return g

tf2onnx/verbose_logging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def getLogger(name=None): # pylint: disable=invalid-name, function-redefined
3232
return logger
3333

3434

35-
_SIMPLE_LOG_FORMAT = "%(levelname)s: %(message)s"
35+
_SIMPLE_LOG_FORMAT = "%(message)s"
3636
_VERBOSE_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(name)s: %(message)s"
3737

3838

0 commit comments

Comments
 (0)