Skip to content

Commit 0f69340

Browse files
author
wayuanho
authored
Merge pull request #620 from lei-Qiao/improve_test_coverage
Improve test coverage
2 parents 9c48cfa + 4de47ee commit 0f69340

File tree

3 files changed

+23
-26
lines changed

3 files changed

+23
-26
lines changed

tests/test_backend.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -108,24 +108,17 @@ def _run_test_case(self, output_names_with_port, feed_dict, **kwargs):
108108
kwargs["constant_fold"] = False
109109
return self.run_test_case(feed_dict, [], output_names_with_port, **kwargs)
110110

111-
def _test_expand_dims(self, idx):
111+
def _test_expand_dims_known_rank(self, idx):
112112
tf.reset_default_graph()
113113
x_val = make_xval([3, 4])
114114
x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
115115
op = tf.expand_dims(x, idx)
116116
_ = tf.identity(op, name=_TFOUTPUT)
117117
self._run_test_case([_OUTPUT], {_INPUT: x_val})
118118

119-
def test_expand_dims(self):
119+
def test_expand_dims_known_rank(self):
120120
for i in [-1, 0, 1, -2]:
121-
self._test_expand_dims(i)
122-
123-
def test_expand_dims_dynamic_inputs(self):
124-
x_val = make_xval([3, 4])
125-
x = tf.placeholder(tf.float32, shape=[None, None], name=_TFINPUT)
126-
op = tf.expand_dims(x, 0)
127-
_ = tf.identity(op, name=_TFOUTPUT)
128-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
121+
self._test_expand_dims_known_rank(i)
129122

130123
def test_expand_dims_one_unknown_rank(self):
131124
tf.reset_default_graph()
@@ -135,14 +128,18 @@ def test_expand_dims_one_unknown_rank(self):
135128
_ = tf.identity(op, name=_TFOUTPUT)
136129
self._run_test_case([_OUTPUT], {_INPUT: x_val})
137130

138-
def test_expand_dims_more_unknown_rank(self):
131+
def _test_expand_dims_more_unknown_rank(self, idx):
139132
tf.reset_default_graph()
140133
x_val = make_xval([3, 4])
141134
x = tf.placeholder(tf.float32, shape=[None, None], name=_TFINPUT)
142-
op = tf.expand_dims(x, 0)
135+
op = tf.expand_dims(x, idx)
143136
_ = tf.identity(op, name=_TFOUTPUT)
144137
self._run_test_case([_OUTPUT], {_INPUT: x_val})
145138

139+
def test_expand_dims_more_unknown_rank(self):
140+
for i in [-1, 0, 1, -2]:
141+
self._test_expand_dims_more_unknown_rank(i)
142+
146143
@check_opset_min_version(9, "ConstantOfShape")
147144
def test_eye_non_const1(self):
148145
# tf.eye(num_rows), num_rows is not const here
@@ -1073,6 +1070,15 @@ def test_slice(self):
10731070
_ = tf.identity(x_, name=_TFOUTPUT)
10741071
self._run_test_case([_OUTPUT], {_INPUT: x_val})
10751072

1073+
def test_slice_neg_size(self):
1074+
x_val = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32)
1075+
t1 = tf.constant([0, 1], dtype=tf.int32)
1076+
t2 = tf.constant([-1, 2], dtype=tf.int32)
1077+
x0 = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1078+
x_ = tf.slice(x0, t1, t2)
1079+
_ = tf.identity(x_, name=_TFOUTPUT)
1080+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1081+
10761082
@check_opset_min_version(10, "Slice in opset 10 can accept dynamic 'start' and 'ends'")
10771083
def test_slice_with_non_const(self):
10781084
x_val = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32)
@@ -2253,7 +2259,6 @@ def test_sparse_softmax_cross_entropy_with_logits_large_class(self):
22532259

22542260
self._run_test_case([_OUTPUT], {_INPUT: label_val, _INPUT1: logits_val}, rtol=1e-6)
22552261

2256-
@skip_onnxruntime_backend("onnxruntime Slice did not supported BOOL")
22572262
def test_matrix_band_part(self):
22582263
input_val = np.random.randint(0, 666, (10, 15)).astype(np.int32)
22592264
input_x = tf.placeholder(dtype=tf.int32, shape=[None, None], name=_TFINPUT)
@@ -2263,7 +2268,6 @@ def test_matrix_band_part(self):
22632268
_ = tf.identity(res1, name=_TFOUTPUT1)
22642269
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: input_val})
22652270

2266-
@skip_onnxruntime_backend("onnxruntime Slice did not supported BOOL.")
22672271
def test_matrix_band_part_2(self):
22682272
input_val = np.random.randint(0, 666, (1, 1)).astype(np.int32)
22692273
input_x = tf.placeholder(dtype=tf.int32, shape=[None, None], name=_TFINPUT)
@@ -2432,7 +2436,7 @@ def test_softsign(self):
24322436

24332437
def test_batch_to_spacend(self):
24342438
block_size = [2, 2]
2435-
crop = [[0, 1], [2, 1]]
2439+
crop = [[1, 0], [2, 1]]
24362440

24372441
input_val = np.random.random_sample([40, 3, 5, 100]).astype(np.float32)
24382442
input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT) # NHWC

tf2onnx/onnx_opset/math.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -251,12 +251,9 @@ def version_1(cls, ctx, node, **kwargs):
251251
# ONNX: Each input value is divided by (bias+(alpha/size)*sum(xi^2 for every xi in the local region))^beta
252252
# TF: sqr_sum[a, b, c, d] = sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2)
253253
# output = input / (bias + alpha * sqr_sum) ** beta
254-
depth_radius = node.get_attr("depth_radius")
255-
if depth_radius:
256-
size = depth_radius.i * 2 + 1
257-
else:
258-
# by default, depth_radius is 5 in tensorflow
259-
size = 5 * 2 + 1
254+
255+
# by default, depth_radius is 5 in tensorflow
256+
size = node.get_attr_value("depth_radius", 5) * 2 + 1
260257

261258
node.set_attr("size", size)
262259
node.set_attr("alpha", size * node.get_attr("alpha").f)

tf2onnx/onnx_opset/tensor.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,11 +1278,7 @@ def version_10(cls, ctx, node, **kwargs):
12781278
seq_dim = node.get_attr("seq_dim")
12791279
utils.make_sure(seq_dim is not None, "sequence dim must be given in {}".format(node.name))
12801280
seq_dim = seq_dim.i
1281-
batch_dim = node.get_attr("batch_dim")
1282-
if batch_dim is not None:
1283-
batch_dim = batch_dim.i
1284-
else:
1285-
batch_dim = 0
1281+
batch_dim = node.get_attr_value("batch_dim", 0)
12861282

12871283
ctx.remove_node(node.name)
12881284
node = ctx.make_node(

0 commit comments

Comments
 (0)