Skip to content

Commit 38c7798

Browse files
author
wayuanho
committed
check specific op (lstm, gru) in converted graph
1 parent 09e839c commit 38c7798

File tree

6 files changed

+173
-52
lines changed

6 files changed

+173
-52
lines changed

tests/backend_test_base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def _run_backend(self, g, outputs, input_dict):
9393

9494
def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=0.,
9595
convert_var_to_const=True, constant_fold=True, check_value=True, check_shape=False,
96-
check_dtype=True, process_args=None, onnx_feed_dict=None):
96+
check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None):
9797
# optional - passed to process_tf_graph
9898
if process_args is None:
9999
process_args = {}
@@ -150,6 +150,9 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
150150
if check_shape:
151151
self.assertEqual(expected_val.shape, actual_val.shape)
152152

153+
if graph_validator:
154+
self.assertTrue(graph_validator(g))
155+
153156
return g
154157

155158
def save_onnx_model(self, model_proto, feed_dict, postfix=""):

tests/common.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,15 @@ def group_nodes_by_type(graph):
224224
for node in graph.get_nodes():
225225
res[node.type].append(node)
226226
return res
227+
228+
229+
def check_op_count(graph, op_type, expected_count):
230+
return len(group_nodes_by_type(graph)[op_type]) == expected_count
231+
232+
233+
def check_lstm_count(graph, expected_count):
234+
return check_op_count(graph, "LSTM", expected_count)
235+
236+
237+
def check_gru_count(graph, expected_count):
238+
return check_op_count(graph, "GRU", expected_count)

tests/test_gru.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from tensorflow.python.ops import init_ops
1717
from tensorflow.python.ops import variable_scope
1818
from backend_test_base import Tf2OnnxBackendTestBase
19-
from common import unittest_main
19+
from common import unittest_main, check_gru_count
2020

2121

2222
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -47,7 +47,8 @@ def test_single_dynamic_gru(self):
4747
input_names_with_port = ["input_1:0"]
4848
feed_dict = {"input_1:0": x_val}
4949
output_names_with_port = ["output:0", "cell_state:0"]
50-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06)
50+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06,
51+
graph_validator=lambda g: check_gru_count(g, 1))
5152

5253
def test_multiple_dynamic_gru(self):
5354
units = 5
@@ -93,7 +94,8 @@ def test_multiple_dynamic_gru(self):
9394
feed_dict = {"input_1:0": x_val}
9495
input_names_with_port = ["input_1:0"]
9596
output_names_with_port = ["output:0", "cell_state:0"]
96-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
97+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
98+
graph_validator=lambda g: check_gru_count(g, 2))
9799

98100
def test_single_dynamic_gru_seq_length_is_const(self):
99101
units = 5
@@ -119,7 +121,8 @@ def test_single_dynamic_gru_seq_length_is_const(self):
119121
feed_dict = {"input_1:0": x_val}
120122
input_names_with_port = ["input_1:0"]
121123
output_names_with_port = ["output:0", "cell_state:0"]
122-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
124+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
125+
graph_validator=lambda g: check_gru_count(g, 1))
123126

124127
def test_single_dynamic_gru_seq_length_is_not_const(self):
125128
units = 5
@@ -148,7 +151,8 @@ def test_single_dynamic_gru_seq_length_is_not_const(self):
148151
feed_dict = {"input_1:0": x_val, "input_2:0": y_val}
149152
input_names_with_port = ["input_1:0", "input_2:0"]
150153
output_names_with_port = ["output:0", "cell_state:0"]
151-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06)
154+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06,
155+
graph_validator=lambda g: check_gru_count(g, 1))
152156

153157
def test_single_dynamic_gru_placeholder_input(self):
154158
units = 5
@@ -172,7 +176,8 @@ def test_single_dynamic_gru_placeholder_input(self):
172176
feed_dict = {"input_1:0": x_val}
173177
input_names_with_port = ["input_1:0"]
174178
output_names_with_port = ["output:0", "cell_state:0"]
175-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06)
179+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06,
180+
graph_validator=lambda g: check_gru_count(g, 1))
176181

177182
def test_single_dynamic_gru_ch_zero_state_initializer(self):
178183
units = 5
@@ -201,7 +206,8 @@ def test_single_dynamic_gru_ch_zero_state_initializer(self):
201206
feed_dict = {"input_1:0": x_val}
202207
input_names_with_port = ["input_1:0"]
203208
output_names_with_port = ["output:0", "cell_state:0"]
204-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06)
209+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06,
210+
graph_validator=lambda g: check_gru_count(g, 1))
205211

206212
@unittest.skip("FIXME: disable for now for accuracy problem")
207213
def test_single_dynamic_gru_random_weights(self):
@@ -229,7 +235,8 @@ def test_single_dynamic_gru_random_weights(self):
229235
feed_dict = {"input_1:0": x_val}
230236
input_names_with_port = ["input_1:0"]
231237
output_names_with_port = ["output:0", "cell_state:0"]
232-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.0001)
238+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.0001,
239+
graph_validator=lambda g: check_gru_count(g, 1))
233240

234241
@unittest.skip("FIXME: disable for now for accuracy problem")
235242
def test_single_dynamic_gru_random_weights2(self):
@@ -256,7 +263,8 @@ def test_single_dynamic_gru_random_weights2(self):
256263
feed_dict = {"input_1:0": x_val}
257264
input_names_with_port = ["input_1:0"]
258265
output_names_with_port = ["output:0", "cell_state:0"]
259-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.01)
266+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.01,
267+
graph_validator=lambda g: check_gru_count(g, 1))
260268

261269
def test_dynamic_gru_output_consumed_only(self):
262270
units = 5
@@ -280,7 +288,8 @@ def test_dynamic_gru_output_consumed_only(self):
280288
feed_dict = {"input_1:0": x_val}
281289
input_names_with_port = ["input_1:0"]
282290
output_names_with_port = ["output:0"]
283-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.0001)
291+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.0001,
292+
graph_validator=lambda g: check_gru_count(g, 1))
284293

285294
def test_dynamic_gru_state_consumed_only(self):
286295
units = 5
@@ -304,7 +313,8 @@ def test_dynamic_gru_state_consumed_only(self):
304313
feed_dict = {"input_1:0": x_val}
305314
input_names_with_port = ["input_1:0"]
306315
output_names_with_port = ["cell_state:0"]
307-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=0.0001, atol=1e-06)
316+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=0.0001, atol=1e-06,
317+
graph_validator=lambda g: check_gru_count(g, 1))
308318

309319
def test_dynamic_bigru(self):
310320
units = 5
@@ -335,7 +345,8 @@ def test_dynamic_bigru(self):
335345
feed_dict = {"input_1:0": x_val}
336346
input_names_with_port = ["input_1:0"]
337347
output_names_with_port = ["output:0", "cell_state:0"]
338-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
348+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
349+
graph_validator=lambda g: check_gru_count(g, 1))
339350

340351
def test_dynamic_bigru_output_consumed_only(self):
341352
units = 5
@@ -365,7 +376,8 @@ def test_dynamic_bigru_output_consumed_only(self):
365376
feed_dict = {"input_1:0": x_val}
366377
input_names_with_port = ["input_1:0"]
367378
output_names_with_port = ["output:0"]
368-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
379+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
380+
graph_validator=lambda g: check_gru_count(g, 1))
369381

370382
def test_dynamic_bigru_state_consumed_only(self):
371383
units = 5
@@ -395,7 +407,8 @@ def test_dynamic_bigru_state_consumed_only(self):
395407
feed_dict = {"input_1:0": x_val}
396408
input_names_with_port = ["input_1:0"]
397409
output_names_with_port = ["cell_state:0"]
398-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
410+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
411+
graph_validator=lambda g: check_gru_count(g, 1))
399412

400413
def test_dynamic_bidirectional_but_one_gru(self):
401414
units = 5
@@ -423,7 +436,8 @@ def test_dynamic_bidirectional_but_one_gru(self):
423436
feed_dict = {"input_1:0": x_val}
424437
input_names_with_port = ["input_1:0"]
425438
output_names_with_port = ["output:0", "cell_state:0"]
426-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
439+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
440+
graph_validator=lambda g: check_gru_count(g, 1))
427441

428442
def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
429443
units = 5
@@ -448,7 +462,8 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
448462
feed_dict = {"input_1:0": x_val}
449463
input_names_with_port = ["input_1:0"]
450464
output_names_with_port = ["output:0"]
451-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
465+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
466+
graph_validator=lambda g: check_gru_count(g, 1))
452467

453468
def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
454469
units = 5
@@ -473,7 +488,8 @@ def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
473488
feed_dict = {"input_1:0": x_val}
474489
input_names_with_port = ["input_1:0"]
475490
output_names_with_port = ["cell_state:0"]
476-
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
491+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
492+
graph_validator=lambda g: check_gru_count(g, 1))
477493

478494

479495
if __name__ == '__main__':

0 commit comments

Comments
 (0)