Commit e8ec761

Merge pull request #432 from onnx/gs/fix-squeeze
make squeeze op more robust, turn down logging for some warnings
2 parents: 2f577d5 + 760e4ab

2 files changed: 9 additions, 8 deletions

tf2onnx/graph.py

Lines changed: 2 additions & 3 deletions

@@ -90,9 +90,8 @@ def attr_onnx(self):
         """Return onnx valid attributes"""
         schema = get_schema(self.type, self.graph.opset, self.domain)
         if schema is None and not (self.is_const() or self.is_graph_input()):
-            log.warning("Node %s uses non-stardard onnx op <%s, %s>, skip attribute check", self.name, self.domain,
-                        self.type)
-
+            log.debug("Node %s uses non-stardard onnx op <%s, %s>, skip attribute check",
+                      self.name, self.domain, self.type)
         onnx_attrs = {}
         for a in self._attr.values():
             if schema is None or schema.has_attribute(a.name):
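
This hunk only lowers the log level from WARNING to DEBUG; the attribute filtering itself is unchanged. As a standalone illustration (hypothetical helper names, not tf2onnx's actual API), the behaviour visible in the context lines is roughly:

import logging

log = logging.getLogger("tf2onnx")

def filter_onnx_attrs(attrs, schema, node_name, domain, op_type):
    # With no registered ONNX schema, keep every attribute and only note the
    # skipped check at DEBUG level (previously a WARNING).
    if schema is None:
        log.debug("Node %s uses non-standard onnx op <%s, %s>, skip attribute check",
                  node_name, domain, op_type)
        return dict(attrs)
    # With a schema, pass through only the attributes the schema declares.
    return {name: val for name, val in attrs.items() if schema.has_attribute(name)}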

tf2onnx/tfonnx.py

Lines changed: 7 additions & 5 deletions

@@ -269,12 +269,15 @@ def squeeze_op(ctx, node, name, args):
         del node.attr["axis"]
 
     shape = ctx.get_shape(node.input[0])
-    utils.make_sure(shape is not None, "squeeze input shape cannot be None")
-    shape_len = len(shape)
     if axis and axis.ints:
         axis = axis.ints
-        axis = [a + shape_len if a < 0 else a for a in axis]
+        neg_axis = any([val < 0 for val in axis])
+        if neg_axis:
+            utils.make_sure(shape is not None, "squeeze input shape cannot be None")
+            shape_len = len(shape)
+            axis = [a + shape_len if a < 0 else a for a in axis]
     else:
+        utils.make_sure(shape is not None, "squeeze input shape cannot be None")
         axis = [i for i, j in enumerate(shape) if j == 1]
     node.set_attr("axes", axis)
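
The point of this hunk is that the input shape is now only required when it is actually needed. A minimal sketch of the resulting axis handling (standalone illustration, not the converter code itself):

def resolve_squeeze_axes(axes, shape):
    # axes given and all non-negative: the input shape may be unknown (None)
    # axes given with negatives: shape is needed to wrap them to positive values
    # no axes: shape is needed to find the dimensions of size 1
    if axes:
        axes = list(axes)
        if any(a < 0 for a in axes):
            assert shape is not None, "squeeze input shape cannot be None"
            rank = len(shape)
            return [a + rank if a < 0 else a for a in axes]
        return axes
    assert shape is not None, "squeeze input shape cannot be None"
    return [i for i, dim in enumerate(shape) if dim == 1]

print(resolve_squeeze_axes([0, 2], None))      # [0, 2] -- no shape needed any more
print(resolve_squeeze_axes([-1], [1, 3, 1]))   # [2]
print(resolve_squeeze_axes(None, [1, 3, 1]))   # [0, 2]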

@@ -450,7 +453,7 @@ def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2):
                 output_shape = spatial_map(output_shape, NHWC_TO_NCHW)
             # calculate pads
             if any(input_shape[i + 2] == -1 for i in range(spatial)):
-                log.warning("node %s has unknown dim %s for pads calculation, fallback to auto_pad" % (
+                log.debug("node %s has unknown dim %s for pads calculation, fallback to auto_pad" % (
                     node.name, str(input_shape)))
                 node.set_attr("auto_pad", "SAME_UPPER")
             else:
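
Here, too, only the log level changes. For context, a hedged sketch of a SAME-padding computation with the same fallback (assumed NCHW shapes and helper name, not the actual add_padding code): when any spatial input dimension is unknown (-1), explicit pads cannot be derived, so the node keeps auto_pad="SAME_UPPER" instead.

def same_pads_or_auto(input_shape, output_shape, kernel_shape, strides, spatial=2):
    # Unknown spatial dim: give up on explicit pads; caller sets auto_pad="SAME_UPPER".
    if any(input_shape[i + 2] == -1 for i in range(spatial)):
        return None
    pads = [0] * (spatial * 2)
    for i in range(spatial):
        total = (output_shape[i + 2] - 1) * strides[i] + kernel_shape[i] - input_shape[i + 2]
        total = max(total, 0)
        pads[i] = total // 2                    # padding at the start of the dimension
        pads[i + spatial] = total - total // 2  # SAME_UPPER puts the extra at the end
    return pads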

@@ -1207,7 +1210,6 @@ def minmax_op(ctx, node, name, args):
         # get a tensor with zeros (since there is no Fill op as of opset8)
         sub_node = ctx.make_node("Sub", [has_correct_shape, has_correct_shape],
                                  op_name_scope=input_node.name)
-
         # use add as 'broadcast' op
         add_node = ctx.make_node("Add", [input_node.output[0], sub_node.output[0]],
                                  op_name_scope=input_node.name)
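
The deleted line is only a blank line, but the surrounding trick is worth spelling out: as of opset 8 there is no Fill op, so a zero tensor of the desired shape is built with Sub(x, x), and Add is then used purely as a broadcast. A numpy illustration (assumed values, not the converter code):

import numpy as np

has_correct_shape = np.ones((2, 3), dtype=np.float32)  # any tensor with the target shape
zeros = has_correct_shape - has_correct_shape           # Sub(x, x) -> zeros of that shape
clip_value = np.float32(6.0)                            # e.g. a scalar min/max limit
broadcast = clip_value + zeros                          # Add broadcasts the scalar to (2, 3)
print(broadcast)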
