Skip to content

Commit cd34b38

Browse files
authored
Merge pull request #347 from zhijxu-MS/tmp_branch_for_PR3
replace conversion logic of "select" with a simpler one
2 parents 6d9697a + 9db30fb commit cd34b38

File tree

6 files changed

+64
-259
lines changed

6 files changed

+64
-259
lines changed

tests/test_backend.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1446,8 +1446,7 @@ def test_reverse_sequence_time_major(self):
14461446
_ = tf.identity(x_, name=_TFOUTPUT)
14471447
self._run_test_case([_OUTPUT], {_INPUT: x_val})
14481448

1449-
# @unittest.skipIf(OPSET < 8, "supported with opset 8 or better")
1450-
@unittest.skip("FIXME: the newest onnxruntime wheel hasn't been published to PYPI, so Select op is not supported")
1449+
@check_opset_min_version(7, "where")
14511450
def test_where(self):
14521451
x_val = np.array([1, 2, -3, 4, -5, -6, -7, 8, 9, 0], dtype=np.int32)
14531452
true_result = np.array([111, 222, 333, 444, 555, 666, 777, 888, 999, 1000],
@@ -1459,7 +1458,7 @@ def test_where(self):
14591458
_ = tf.identity(picks, name=_TFOUTPUT)
14601459
self._run_test_case([_OUTPUT], {_INPUT: x_val})
14611460

1462-
@check_opset_min_version(8, "where")
1461+
@check_opset_min_version(7, "where")
14631462
def test_where_with_two_rank_input(self):
14641463
x_val = np.array([1, 2, -3, 4, -5, -6, -7, 8, 9, 0], dtype=np.int32)
14651464
true_result = np.array([[111, 111], [222, 222], [333, 333], [444, 444], [555, 555],
@@ -1475,7 +1474,7 @@ def test_where_with_two_rank_input(self):
14751474

14761475
self._run_test_case([_OUTPUT], {_INPUT: x_val})
14771476

1478-
@check_opset_min_version(8, "where")
1477+
@check_opset_min_version(7, "where")
14791478
def test_where_with_two_rank_condition(self):
14801479
x_val = np.array([[1, 2, -3, 4, -5, -6, -7, 8, 9, 0]], dtype=np.int32)
14811480
true_result = np.array([[111, 222, 333, 444, 555, 666, 777, 888, 999, 1000]],
@@ -1488,7 +1487,7 @@ def test_where_with_two_rank_condition(self):
14881487

14891488
self._run_test_case([_OUTPUT], {_INPUT: x_val})
14901489

1491-
@check_opset_min_version(8, "where")
1490+
@check_opset_min_version(7, "where")
14921491
def test_where_with_three_rank_condition(self):
14931492
x_val = np.array([[[1, 2, -3, 4, -5, -6, -7, 8, 9, 0]]], dtype=np.int32)
14941493
true_result = np.array([[[111, 222, 333, 444, 555, 666, 777, 888, 999, 1000]]],
@@ -1501,7 +1500,7 @@ def test_where_with_three_rank_condition(self):
15011500

15021501
self._run_test_case([_OUTPUT], {_INPUT: x_val})
15031502

1504-
@check_opset_min_version(8, "where")
1503+
@check_opset_min_version(7, "where")
15051504
def test_where_scalar(self):
15061505
x_val = np.array(6, dtype=np.int32)
15071506
true_result = np.array([111, 222, 333, 444, 555, 666, 777, 888, 999, 1000],

tests/test_gru.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_single_dynamic_gru(self):
4747
input_names_with_port = ["input_1:0"]
4848
feed_dict = {"input_1:0": x_val}
4949
output_names_with_port = ["output:0", "cell_state:0"]
50-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03)
50+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06)
5151

5252
def test_multiple_dynamic_gru(self):
5353
units = 5
@@ -93,7 +93,7 @@ def test_multiple_dynamic_gru(self):
9393
feed_dict = {"input_1:0": x_val}
9494
input_names_with_port = ["input_1:0"]
9595
output_names_with_port = ["output:0", "cell_state:0"]
96-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
96+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
9797

9898
def test_single_dynamic_gru_seq_length_is_const(self):
9999
units = 5
@@ -119,7 +119,7 @@ def test_single_dynamic_gru_seq_length_is_const(self):
119119
feed_dict = {"input_1:0": x_val}
120120
input_names_with_port = ["input_1:0"]
121121
output_names_with_port = ["output:0", "cell_state:0"]
122-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
122+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
123123

124124
def test_single_dynamic_gru_seq_length_is_not_const(self):
125125
units = 5
@@ -148,7 +148,7 @@ def test_single_dynamic_gru_seq_length_is_not_const(self):
148148
feed_dict = {"input_1:0": x_val, "input_2:0": y_val}
149149
input_names_with_port = ["input_1:0", "input_2:0"]
150150
output_names_with_port = ["output:0", "cell_state:0"]
151-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03)
151+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06)
152152

153153
def test_single_dynamic_gru_placeholder_input(self):
154154
units = 5
@@ -172,7 +172,7 @@ def test_single_dynamic_gru_placeholder_input(self):
172172
feed_dict = {"input_1:0": x_val}
173173
input_names_with_port = ["input_1:0"]
174174
output_names_with_port = ["output:0", "cell_state:0"]
175-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03)
175+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06)
176176

177177
def test_single_dynamic_gru_ch_zero_state_initializer(self):
178178
units = 5
@@ -201,7 +201,7 @@ def test_single_dynamic_gru_ch_zero_state_initializer(self):
201201
feed_dict = {"input_1:0": x_val}
202202
input_names_with_port = ["input_1:0"]
203203
output_names_with_port = ["output:0", "cell_state:0"]
204-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03)
204+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06)
205205

206206
@unittest.skip("FIXME: disable for now for accuracy problem")
207207
def test_single_dynamic_gru_random_weights(self):
@@ -304,7 +304,7 @@ def test_dynamic_gru_state_consumed_only(self):
304304
feed_dict = {"input_1:0": x_val}
305305
input_names_with_port = ["input_1:0"]
306306
output_names_with_port = ["cell_state:0"]
307-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=0.0001, atol=1e-07)
307+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=0.0001, atol=1e-06)
308308

309309
def test_dynamic_bigru(self):
310310
units = 5
@@ -335,7 +335,7 @@ def test_dynamic_bigru(self):
335335
feed_dict = {"input_1:0": x_val}
336336
input_names_with_port = ["input_1:0"]
337337
output_names_with_port = ["output:0", "cell_state:0"]
338-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
338+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
339339

340340
def test_dynamic_bigru_output_consumed_only(self):
341341
units = 5
@@ -365,7 +365,7 @@ def test_dynamic_bigru_output_consumed_only(self):
365365
feed_dict = {"input_1:0": x_val}
366366
input_names_with_port = ["input_1:0"]
367367
output_names_with_port = ["output:0"]
368-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
368+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
369369

370370
def test_dynamic_bigru_state_consumed_only(self):
371371
units = 5
@@ -395,7 +395,7 @@ def test_dynamic_bigru_state_consumed_only(self):
395395
feed_dict = {"input_1:0": x_val}
396396
input_names_with_port = ["input_1:0"]
397397
output_names_with_port = ["cell_state:0"]
398-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
398+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
399399

400400
def test_dynamic_bidirectional_but_one_gru(self):
401401
units = 5
@@ -423,7 +423,7 @@ def test_dynamic_bidirectional_but_one_gru(self):
423423
feed_dict = {"input_1:0": x_val}
424424
input_names_with_port = ["input_1:0"]
425425
output_names_with_port = ["output:0", "cell_state:0"]
426-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
426+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
427427

428428
def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
429429
units = 5
@@ -448,7 +448,7 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
448448
feed_dict = {"input_1:0": x_val}
449449
input_names_with_port = ["input_1:0"]
450450
output_names_with_port = ["output:0"]
451-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
451+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
452452

453453
def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
454454
units = 5
@@ -473,7 +473,7 @@ def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
473473
feed_dict = {"input_1:0": x_val}
474474
input_names_with_port = ["input_1:0"]
475475
output_names_with_port = ["cell_state:0"]
476-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
476+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
477477

478478

479479
if __name__ == '__main__':

tests/test_grublock.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_single_dynamic_gru(self):
4545
input_names_with_port = ["input_1:0"]
4646
feed_dict = {"input_1:0": x_val}
4747
output_names_with_port = ["output:0", "cell_state:0"]
48-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
48+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
4949

5050
def test_multiple_dynamic_gru(self):
5151
units = 5
@@ -89,7 +89,7 @@ def test_multiple_dynamic_gru(self):
8989
feed_dict = {"input_1:0": x_val}
9090
input_names_with_port = ["input_1:0"]
9191
output_names_with_port = ["output:0", "cell_state:0"]
92-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
92+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
9393

9494
def test_single_dynamic_gru_seq_length_is_const(self):
9595
units = 5
@@ -113,7 +113,7 @@ def test_single_dynamic_gru_seq_length_is_const(self):
113113
feed_dict = {"input_1:0": x_val}
114114
input_names_with_port = ["input_1:0"]
115115
output_names_with_port = ["output:0", "cell_state:0"]
116-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
116+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
117117

118118
def test_single_dynamic_gru_seq_length_is_not_const(self):
119119
units = 5
@@ -140,7 +140,7 @@ def test_single_dynamic_gru_seq_length_is_not_const(self):
140140
feed_dict = {"input_1:0": x_val, "input_2:0": y_val}
141141
input_names_with_port = ["input_1:0", "input_2:0"]
142142
output_names_with_port = ["output:0", "cell_state:0"]
143-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03)
143+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06)
144144

145145
def test_single_dynamic_gru_placeholder_input(self):
146146
units = 5
@@ -162,7 +162,7 @@ def test_single_dynamic_gru_placeholder_input(self):
162162
feed_dict = {"input_1:0": x_val}
163163
input_names_with_port = ["input_1:0"]
164164
output_names_with_port = ["output:0", "cell_state:0"]
165-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
165+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
166166

167167
def test_single_dynamic_gru_ch_zero_state_initializer(self):
168168
units = 5
@@ -188,7 +188,7 @@ def test_single_dynamic_gru_ch_zero_state_initializer(self):
188188
feed_dict = {"input_1:0": x_val}
189189
input_names_with_port = ["input_1:0"]
190190
output_names_with_port = ["output:0", "cell_state:0"]
191-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03)
191+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06)
192192

193193
@unittest.skip("FIXME: disable for now for accuracy problem")
194194
def test_single_dynamic_gru_random_weights(self):
@@ -310,7 +310,7 @@ def test_dynamic_bigru(self):
310310
feed_dict = {"input_1:0": x_val}
311311
input_names_with_port = ["input_1:0"]
312312
output_names_with_port = ["output:0", "cell_state:0"]
313-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
313+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
314314

315315
def test_dynamic_bigru_output_consumed_only(self):
316316
units = 5
@@ -337,7 +337,7 @@ def test_dynamic_bigru_output_consumed_only(self):
337337
feed_dict = {"input_1:0": x_val}
338338
input_names_with_port = ["input_1:0"]
339339
output_names_with_port = ["output:0"]
340-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
340+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
341341

342342
def test_dynamic_bigru_state_consumed_only(self):
343343
units = 5
@@ -364,7 +364,7 @@ def test_dynamic_bigru_state_consumed_only(self):
364364
feed_dict = {"input_1:0": x_val}
365365
input_names_with_port = ["input_1:0"]
366366
output_names_with_port = ["cell_state:0"]
367-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
367+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
368368

369369
def test_dynamic_bidirectional_but_one_gru(self):
370370
units = 5
@@ -390,7 +390,7 @@ def test_dynamic_bidirectional_but_one_gru(self):
390390
feed_dict = {"input_1:0": x_val}
391391
input_names_with_port = ["input_1:0"]
392392
output_names_with_port = ["output:0", "cell_state:0"]
393-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
393+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
394394

395395
def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
396396
units = 5
@@ -440,7 +440,7 @@ def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
440440
feed_dict = {"input_1:0": x_val}
441441
input_names_with_port = ["input_1:0"]
442442
output_names_with_port = ["cell_state:0"]
443-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
443+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
444444

445445

446446
if __name__ == '__main__':

tf2onnx/function/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .gathernd import gathernd_op
1010
from .matrixbandpart import matrixbandpart_op
1111
from .range import range_op7
12-
from .select import select_op8
12+
from .select import select_op7
1313
from .sparse_softmax_cross_entropy_with_logits import sparse_softmax_cross_entropy_with_logits_op
1414

15-
__all__ = ["gathernd_op", "matrixbandpart_op", "range_op7", "select_op8", "sparse_softmax_cross_entropy_with_logits_op"]
15+
__all__ = ["gathernd_op", "matrixbandpart_op", "range_op7", "select_op7", "sparse_softmax_cross_entropy_with_logits_op"]

0 commit comments

Comments
 (0)