Skip to content

Commit 9a35357

Browse files
authored
Merge pull request #488 from nbcsm/test
re-enable some tests
2 parents a4607b3 + bd52027 commit 9a35357

File tree

5 files changed

+12
-28
lines changed

5 files changed

+12
-28
lines changed

tests/test_backend.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,6 @@
4141
_OUTPUT1 = "output1:0"
4242

4343

44-
# pylint: disable=C0111
45-
46-
4744
def make_xval(shape):
4845
x_val = np.arange(np.prod(shape)).astype("float32").reshape(shape)
4946
return x_val
@@ -54,7 +51,7 @@ def get_conv_getdata(kind=1):
5451
# generate all combinations (costly)
5552
dims = [
5653
("padding", ["SAME", "VALID"]),
57-
("input_sizes", [[32, 35, 35, 288], [32, 17, 17, 1248], [1, 28, 28, 3], [32, 8, 8, 2048]]),
54+
("input_sizes", [[32, 35, 35, 3], [32, 17, 17, 3], [1, 28, 28, 3], [32, 8, 8, 3]]),
5855
("filter_sizes", [[1, 3, 3, 1], [1, 2, 2, 1], [1, 5, 5, 1], [1, 1, 1, 1], [1, 5, 2, 1], [1, 2, 5, 1]]),
5956
("strides", [[1, 2, 2, 1], [1, 1, 1, 1]]),
6057
]
@@ -65,23 +62,23 @@ def get_conv_getdata(kind=1):
6562
elif kind == 1:
6663
# some combination to that give decent padding coverage
6764
data = [
68-
('SAME', [32, 35, 35, 288], [1, 3, 3, 1], [1, 2, 2, 1]),
69-
('SAME', [32, 35, 35, 288], [1, 2, 2, 1], [1, 2, 2, 1]),
70-
('SAME', [32, 35, 35, 288], [1, 1, 1, 1], [1, 1, 1, 1]),
71-
('SAME', [32, 35, 35, 288], [1, 5, 2, 1], [1, 2, 2, 1]),
72-
('SAME', [32, 35, 35, 288], [1, 2, 5, 1], [1, 2, 2, 1]),
73-
('SAME', [32, 35, 35, 288], [1, 2, 5, 1], [1, 1, 1, 1]),
65+
('SAME', [32, 35, 35, 3], [1, 3, 3, 1], [1, 2, 2, 1]),
66+
('SAME', [32, 35, 35, 3], [1, 2, 2, 1], [1, 2, 2, 1]),
67+
('SAME', [32, 35, 35, 3], [1, 1, 1, 1], [1, 1, 1, 1]),
68+
('SAME', [32, 35, 35, 3], [1, 5, 2, 1], [1, 2, 2, 1]),
69+
('SAME', [32, 35, 35, 3], [1, 2, 5, 1], [1, 2, 2, 1]),
70+
('SAME', [32, 35, 35, 3], [1, 2, 5, 1], [1, 1, 1, 1]),
7471
('SAME', [1, 28, 28, 3], [1, 3, 3, 1], [1, 2, 2, 1]),
7572
('SAME', [1, 28, 28, 3], [1, 3, 3, 1], [1, 1, 1, 1]),
7673
('SAME', [1, 28, 28, 3], [1, 2, 2, 1], [1, 2, 2, 1]),
7774
('SAME', [1, 28, 28, 3], [1, 2, 2, 1], [1, 1, 1, 1]),
7875
('SAME', [1, 28, 28, 3], [1, 5, 5, 1], [1, 2, 2, 1]),
7976
('SAME', [1, 28, 28, 3], [1, 5, 5, 1], [1, 1, 1, 1]),
8077
('SAME', [1, 28, 28, 3], [1, 5, 2, 1], [1, 2, 2, 1]),
81-
('SAME', [32, 8, 8, 2048], [1, 3, 3, 1], [1, 2, 2, 1]),
82-
('SAME', [32, 8, 8, 2048], [1, 3, 3, 1], [1, 1, 1, 1]),
83-
('VALID', [32, 35, 35, 288], [1, 3, 3, 1], [1, 1, 1, 1]),
84-
('VALID', [32, 35, 35, 288], [1, 2, 2, 1], [1, 2, 2, 1]),
78+
('SAME', [32, 8, 8, 3], [1, 3, 3, 1], [1, 2, 2, 1]),
79+
('SAME', [32, 8, 8, 3], [1, 3, 3, 1], [1, 1, 1, 1]),
80+
('VALID', [32, 35, 35, 3], [1, 3, 3, 1], [1, 1, 1, 1]),
81+
('VALID', [32, 35, 35, 3], [1, 2, 2, 1], [1, 2, 2, 1]),
8582
]
8683
for idx, v in enumerate(data):
8784
yield (idx,) + v
@@ -191,8 +188,6 @@ def test_maxpool(self):
191188
self.logger.debug(str(p))
192189
self._run_test_case([_OUTPUT], {_INPUT: x_val})
193190

194-
@unittest.skipIf(get_test_config().is_onnxruntime_backend and get_test_config().backend_version == "0.2.1",
195-
"onnxruntime bug")
196191
@check_onnxruntime_incompatibility("AveragePool")
197192
def test_avgpool(self):
198193
for tf_shape in ["known", "unknown"]:
@@ -1195,7 +1190,7 @@ def test_randomuniform(self):
11951190
# since results are random, compare the shapes only
11961191
self._run_test_case([_OUTPUT], {}, check_value=False, check_shape=True)
11971192

1198-
@unittest.skip("")
1193+
@unittest.skip("TF RandomUniformInt is not supported")
11991194
def test_randomuniform_int(self):
12001195
shape = tf.constant([2, 3], name="shape")
12011196
x_ = tf.random_uniform(shape, name="rand", dtype=tf.int32, maxval=10)

tests/test_gru.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from __future__ import print_function
99
from __future__ import unicode_literals
1010

11-
import unittest
1211
import numpy as np
1312
import tensorflow as tf
1413

@@ -209,7 +208,6 @@ def test_single_dynamic_gru_ch_zero_state_initializer(self):
209208
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06,
210209
graph_validator=lambda g: check_gru_count(g, 1))
211210

212-
@unittest.skip("FIXME: disable for now for accuracy problem")
213211
def test_single_dynamic_gru_random_weights(self):
214212
hidden_size = 5
215213
batch_size = 1
@@ -238,7 +236,6 @@ def test_single_dynamic_gru_random_weights(self):
238236
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.0001,
239237
graph_validator=lambda g: check_gru_count(g, 1))
240238

241-
@unittest.skip("FIXME: disable for now for accuracy problem")
242239
def test_single_dynamic_gru_random_weights2(self):
243240
hidden_size = 128
244241
batch_size = 1

tests/test_grublock.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from __future__ import print_function
99
from __future__ import unicode_literals
1010

11-
import unittest
1211
import numpy as np
1312
import tensorflow as tf
1413

@@ -194,7 +193,6 @@ def test_single_dynamic_gru_ch_zero_state_initializer(self):
194193
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06,
195194
graph_validator=lambda g: check_gru_count(g, 1))
196195

197-
@unittest.skip("FIXME: disable for now for accuracy problem")
198196
def test_single_dynamic_gru_random_weights(self):
199197
hidden_size = 5
200198
batch_size = 1
@@ -220,7 +218,6 @@ def test_single_dynamic_gru_random_weights(self):
220218
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.0001,
221219
graph_validator=lambda g: check_gru_count(g, 1))
222220

223-
@unittest.skip("FIXME: disable for now for accuracy problem")
224221
def test_single_dynamic_gru_random_weights2(self):
225222
hidden_size = 128
226223
batch_size = 1

tests/test_loops.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from __future__ import print_function
99
from __future__ import unicode_literals
1010

11-
import unittest
1211
import numpy as np
1312
import tensorflow as tf
1413

@@ -103,7 +102,6 @@ def b(i, res, res2):
103102
output_names_with_port = ["i:0", "x:0", "y:0"]
104103
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
105104

106-
@unittest.skip("bug in onnxruntime")
107105
def test_while_loop_with_ta_read_reference_outer_input_directly(self):
108106
i = tf.placeholder(tf.int32, (), name="input_1")
109107
inputs = tf.placeholder(tf.float32, (10,), name="input_2")

tests/test_lstm.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from __future__ import print_function
99
from __future__ import unicode_literals
1010

11-
import unittest
1211
import numpy as np
1312
import tensorflow as tf
1413

@@ -268,7 +267,6 @@ def test_single_dynamic_lstm_consume_one_of_ch_tuple(self):
268267
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
269268
graph_validator=lambda g: check_lstm_count(g, 1))
270269

271-
@unittest.skip("FIXME: disable for now for accuracy problem")
272270
def test_single_dynamic_lstm_random_weights(self, state_is_tuple=True):
273271
hidden_size = 5
274272
batch_size = 6
@@ -298,7 +296,6 @@ def test_single_dynamic_lstm_random_weights(self, state_is_tuple=True):
298296
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=0.0001,
299297
graph_validator=lambda g: check_lstm_count(g, 1))
300298

301-
@unittest.skip("FIXME: disable for now for accuracy problem")
302299
def test_single_dynamic_lstm_random_weights2(self, state_is_tuple=True):
303300
hidden_size = 128
304301
batch_size = 1

0 commit comments

Comments
 (0)