Skip to content

Commit 368843a

Browse files
committed
release accuracy check of gru tests by setting atol to 1e-6
1 parent 0b4ce93 commit 368843a

File tree

2 files changed

+24
-24
lines changed

2 files changed

+24
-24
lines changed

tests/test_gru.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_single_dynamic_gru(self):
4747
input_names_with_port = ["input_1:0"]
4848
feed_dict = {"input_1:0": x_val}
4949
output_names_with_port = ["output:0", "cell_state:0"]
50-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03)
50+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06)
5151

5252
def test_multiple_dynamic_gru(self):
5353
units = 5
@@ -93,7 +93,7 @@ def test_multiple_dynamic_gru(self):
9393
feed_dict = {"input_1:0": x_val}
9494
input_names_with_port = ["input_1:0"]
9595
output_names_with_port = ["output:0", "cell_state:0"]
96-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
96+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
9797

9898
def test_single_dynamic_gru_seq_length_is_const(self):
9999
units = 5
@@ -119,7 +119,7 @@ def test_single_dynamic_gru_seq_length_is_const(self):
119119
feed_dict = {"input_1:0": x_val}
120120
input_names_with_port = ["input_1:0"]
121121
output_names_with_port = ["output:0", "cell_state:0"]
122-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
122+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
123123

124124
def test_single_dynamic_gru_seq_length_is_not_const(self):
125125
units = 5
@@ -148,7 +148,7 @@ def test_single_dynamic_gru_seq_length_is_not_const(self):
148148
feed_dict = {"input_1:0": x_val, "input_2:0": y_val}
149149
input_names_with_port = ["input_1:0", "input_2:0"]
150150
output_names_with_port = ["output:0", "cell_state:0"]
151-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03)
151+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06)
152152

153153
def test_single_dynamic_gru_placeholder_input(self):
154154
units = 5
@@ -172,7 +172,7 @@ def test_single_dynamic_gru_placeholder_input(self):
172172
feed_dict = {"input_1:0": x_val}
173173
input_names_with_port = ["input_1:0"]
174174
output_names_with_port = ["output:0", "cell_state:0"]
175-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03)
175+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06)
176176

177177
def test_single_dynamic_gru_ch_zero_state_initializer(self):
178178
units = 5
@@ -201,7 +201,7 @@ def test_single_dynamic_gru_ch_zero_state_initializer(self):
201201
feed_dict = {"input_1:0": x_val}
202202
input_names_with_port = ["input_1:0"]
203203
output_names_with_port = ["output:0", "cell_state:0"]
204-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03)
204+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06)
205205

206206
@unittest.skip("FIXME: disable for now for accuracy problem")
207207
def test_single_dynamic_gru_random_weights(self):
@@ -304,7 +304,7 @@ def test_dynamic_gru_state_consumed_only(self):
304304
feed_dict = {"input_1:0": x_val}
305305
input_names_with_port = ["input_1:0"]
306306
output_names_with_port = ["cell_state:0"]
307-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=0.0001, atol=1e-07)
307+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=0.0001, atol=1e-06)
308308

309309
def test_dynamic_bigru(self):
310310
units = 5
@@ -335,7 +335,7 @@ def test_dynamic_bigru(self):
335335
feed_dict = {"input_1:0": x_val}
336336
input_names_with_port = ["input_1:0"]
337337
output_names_with_port = ["output:0", "cell_state:0"]
338-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
338+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
339339

340340
def test_dynamic_bigru_output_consumed_only(self):
341341
units = 5
@@ -365,7 +365,7 @@ def test_dynamic_bigru_output_consumed_only(self):
365365
feed_dict = {"input_1:0": x_val}
366366
input_names_with_port = ["input_1:0"]
367367
output_names_with_port = ["output:0"]
368-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
368+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
369369

370370
def test_dynamic_bigru_state_consumed_only(self):
371371
units = 5
@@ -395,7 +395,7 @@ def test_dynamic_bigru_state_consumed_only(self):
395395
feed_dict = {"input_1:0": x_val}
396396
input_names_with_port = ["input_1:0"]
397397
output_names_with_port = ["cell_state:0"]
398-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
398+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
399399

400400
def test_dynamic_bidirectional_but_one_gru(self):
401401
units = 5
@@ -423,7 +423,7 @@ def test_dynamic_bidirectional_but_one_gru(self):
423423
feed_dict = {"input_1:0": x_val}
424424
input_names_with_port = ["input_1:0"]
425425
output_names_with_port = ["output:0", "cell_state:0"]
426-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
426+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
427427

428428
def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
429429
units = 5
@@ -448,7 +448,7 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
448448
feed_dict = {"input_1:0": x_val}
449449
input_names_with_port = ["input_1:0"]
450450
output_names_with_port = ["output:0"]
451-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
451+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
452452

453453
def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
454454
units = 5
@@ -473,7 +473,7 @@ def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
473473
feed_dict = {"input_1:0": x_val}
474474
input_names_with_port = ["input_1:0"]
475475
output_names_with_port = ["cell_state:0"]
476-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
476+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
477477

478478

479479
if __name__ == '__main__':

tests/test_grublock.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_single_dynamic_gru(self):
4545
input_names_with_port = ["input_1:0"]
4646
feed_dict = {"input_1:0": x_val}
4747
output_names_with_port = ["output:0", "cell_state:0"]
48-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
48+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
4949

5050
def test_multiple_dynamic_gru(self):
5151
units = 5
@@ -89,7 +89,7 @@ def test_multiple_dynamic_gru(self):
8989
feed_dict = {"input_1:0": x_val}
9090
input_names_with_port = ["input_1:0"]
9191
output_names_with_port = ["output:0", "cell_state:0"]
92-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
92+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
9393

9494
def test_single_dynamic_gru_seq_length_is_const(self):
9595
units = 5
@@ -113,7 +113,7 @@ def test_single_dynamic_gru_seq_length_is_const(self):
113113
feed_dict = {"input_1:0": x_val}
114114
input_names_with_port = ["input_1:0"]
115115
output_names_with_port = ["output:0", "cell_state:0"]
116-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
116+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
117117

118118
def test_single_dynamic_gru_seq_length_is_not_const(self):
119119
units = 5
@@ -140,7 +140,7 @@ def test_single_dynamic_gru_seq_length_is_not_const(self):
140140
feed_dict = {"input_1:0": x_val, "input_2:0": y_val}
141141
input_names_with_port = ["input_1:0", "input_2:0"]
142142
output_names_with_port = ["output:0", "cell_state:0"]
143-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03)
143+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06)
144144

145145
def test_single_dynamic_gru_placeholder_input(self):
146146
units = 5
@@ -162,7 +162,7 @@ def test_single_dynamic_gru_placeholder_input(self):
162162
feed_dict = {"input_1:0": x_val}
163163
input_names_with_port = ["input_1:0"]
164164
output_names_with_port = ["output:0", "cell_state:0"]
165-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
165+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
166166

167167
def test_single_dynamic_gru_ch_zero_state_initializer(self):
168168
units = 5
@@ -188,7 +188,7 @@ def test_single_dynamic_gru_ch_zero_state_initializer(self):
188188
feed_dict = {"input_1:0": x_val}
189189
input_names_with_port = ["input_1:0"]
190190
output_names_with_port = ["output:0", "cell_state:0"]
191-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03)
191+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06)
192192

193193
@unittest.skip("FIXME: disable for now for accuracy problem")
194194
def test_single_dynamic_gru_random_weights(self):
@@ -310,7 +310,7 @@ def test_dynamic_bigru(self):
310310
feed_dict = {"input_1:0": x_val}
311311
input_names_with_port = ["input_1:0"]
312312
output_names_with_port = ["output:0", "cell_state:0"]
313-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
313+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
314314

315315
def test_dynamic_bigru_output_consumed_only(self):
316316
units = 5
@@ -337,7 +337,7 @@ def test_dynamic_bigru_output_consumed_only(self):
337337
feed_dict = {"input_1:0": x_val}
338338
input_names_with_port = ["input_1:0"]
339339
output_names_with_port = ["output:0"]
340-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
340+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
341341

342342
def test_dynamic_bigru_state_consumed_only(self):
343343
units = 5
@@ -364,7 +364,7 @@ def test_dynamic_bigru_state_consumed_only(self):
364364
feed_dict = {"input_1:0": x_val}
365365
input_names_with_port = ["input_1:0"]
366366
output_names_with_port = ["cell_state:0"]
367-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
367+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
368368

369369
def test_dynamic_bidirectional_but_one_gru(self):
370370
units = 5
@@ -390,7 +390,7 @@ def test_dynamic_bidirectional_but_one_gru(self):
390390
feed_dict = {"input_1:0": x_val}
391391
input_names_with_port = ["input_1:0"]
392392
output_names_with_port = ["output:0", "cell_state:0"]
393-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
393+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
394394

395395
def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
396396
units = 5
@@ -440,7 +440,7 @@ def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
440440
feed_dict = {"input_1:0": x_val}
441441
input_names_with_port = ["input_1:0"]
442442
output_names_with_port = ["cell_state:0"]
443-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
443+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
444444

445445

446446
if __name__ == '__main__':

0 commit comments

Comments
 (0)