Skip to content

Commit a2c46bd

Browse files
committed
refactor
1 parent c0581f6 commit a2c46bd

File tree

3 files changed

+7
-10
lines changed

3 files changed

+7
-10
lines changed

tests/common.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
""" test common utilities."""
55

66
import argparse
7-
import numpy as np
87
import os
98
import sys
109
import unittest
1110
from collections import defaultdict
1211

1312
from distutils.version import LooseVersion
1413
from parameterized import parameterized
14+
import numpy as np
1515
from tf2onnx import constants, logging, utils
1616

1717
__all__ = [
@@ -281,11 +281,8 @@ def check_onnxruntime_incompatibility(op):
281281
def validate_const_node(node, expected_val):
282282
if node.is_const():
283283
node_val = node.get_tensor_value()
284-
if (isinstance(expected_val, list) and isinstance(expected_val[0], float)) \
285-
or isinstance(expected_val, float):
286-
np.testing.assert_allclose(expected_val, node_val)
287-
return True
288-
return node_val == expected_val
284+
np.testing.assert_allclose(expected_val, node_val)
285+
return True
289286
return False
290287

291288

tf2onnx/onnx_opset/nn.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,7 @@ def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2):
150150
input_shape = spatial_map(input_shape, constants.NHWC_TO_NCHW)
151151
output_shape = spatial_map(output_shape, constants.NHWC_TO_NCHW)
152152
# calculate pads
153-
if any(input_shape[i + 2] == -1 or output_shape[i + 2] == -1 for i in range(spatial)) \
154-
or any(output_shape[i + 2] == -1 for i in range(spatial)):
153+
if any(input_shape[i + 2] == -1 or output_shape[i + 2] == -1 for i in range(spatial)):
155154
logger.debug("node %s has unknown dim %d for pads calculation, fallback to auto_pad",
156155
node.name, input_shape)
157156
node.set_attr("auto_pad", "SAME_UPPER")

tf2onnx/shape_inference.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ def infer_shape_for_node(g, node):
121121
# https://www.tensorflow.org/api_docs/python/tf/gather
122122
shape_params = g.get_shape(node.input[0])
123123
shape_indices = g.get_shape(node.input[1])
124-
# in lower tf version, gather only has 2 inputs
124+
# gather can only have 2 inputs
125+
# https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/gather.html
125126
if len(node.input) == 3:
126127
axis = node.input[2].get_tensor_value()
127128
else:
@@ -218,7 +219,7 @@ def infer_output_shapes_with_partial_inputs(g, node):
218219
data_inputs = node.input[:-1]
219220
input_shapes = [g.get_shape(node) for node in data_inputs]
220221
input_shapes = [shape for shape in input_shapes if shape is not None]
221-
if len(input_shapes) == 0:
222+
if not input_shapes:
222223
logger.debug("all input shapes of concat node %s are None, can't infer its output shape", node.name)
223224
return False
224225

0 commit comments

Comments
 (0)