Skip to content

Commit deae044

Browse files
committed
make all test class derived from Tf2OnnxBackendTestBase
1 parent 8b0d5cd commit deae044

File tree

3 files changed

+8
-28
lines changed

3 files changed

+8
-28
lines changed

tests/backend_test_base.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,16 @@
2626
class Tf2OnnxBackendTestBase(unittest.TestCase):
2727
def setUp(self):
2828
self.config = get_test_config()
29-
self.maxDiff = None
3029
tf.reset_default_graph()
3130
# reset name generation on every test
3231
utils.INTERNAL_NAME = 1
3332
np.random.seed(1) # Make it reproducible.
3433

3534
self.log = logging.getLogger("tf2onnx.unitest." + str(type(self)))
36-
if self.config.is_debug_mode:
37-
self.log.setLevel(logging.DEBUG)
38-
else:
35+
if not self.config.is_debug_mode:
3936
# suppress log info of tensorflow so that result of test can be seen much easier
4037
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
4138
tf.logging.set_verbosity(tf.logging.WARN)
42-
self.log.setLevel(logging.INFO)
4339

4440
def tearDown(self):
4541
if not self.config.is_debug_mode:

tests/test_graph.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
from __future__ import print_function
88
from __future__ import unicode_literals
99

10-
import os
11-
import unittest
1210
from collections import namedtuple
1311

1412
import numpy as np
@@ -17,13 +15,13 @@
1715
import tensorflow as tf
1816
from onnx import helper
1917

20-
import tf2onnx
2118
from tf2onnx import constants, utils
2219
from tf2onnx.graph import GraphUtil
2320
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2421
from tf2onnx.tfonnx import process_tf_graph
2522
from tf2onnx.handler import tf_op
2623

24+
from backend_test_base import Tf2OnnxBackendTestBase
2725
from common import get_test_config, unittest_main, check_tf_min_version, check_tf_max_version
2826

2927

@@ -91,21 +89,12 @@ def onnx_pretty(g, args=None):
9189
return helper.printable_graph(model_proto.graph)
9290

9391

94-
class Tf2OnnxGraphTests(unittest.TestCase):
92+
class Tf2OnnxGraphTests(Tf2OnnxBackendTestBase):
9593
"""Test cases."""
9694
maxDiff = None
9795

9896
def setUp(self):
99-
"""Setup test."""
100-
# reset name generation on every test
101-
# suppress log info of tensorflow so that result of test can be seen much easier
102-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
103-
tf.logging.set_verbosity(tf.logging.WARN)
104-
105-
self.config = get_test_config()
106-
107-
tf2onnx.utils.INTERNAL_NAME = 1
108-
tf.reset_default_graph()
97+
super().setUp()
10998
arg = namedtuple("Arg", "input inputs outputs verbose continue_on_error")
11099
self._args0 = arg(input="test", inputs=[], outputs=["output:0"],
111100
verbose=False, continue_on_error=False)

tests/test_internals.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
from __future__ import print_function
88
from __future__ import unicode_literals
99

10-
import os
11-
import unittest
1210
from collections import namedtuple
1311

1412
import graphviz as gv
@@ -20,6 +18,8 @@
2018
from tf2onnx import utils
2119
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2220
from tf2onnx.graph import GraphUtil
21+
22+
from backend_test_base import Tf2OnnxBackendTestBase
2323
from common import unittest_main
2424

2525

@@ -49,14 +49,9 @@ def onnx_pretty(g, args=None):
4949
return helper.printable_graph(graph_proto.graph)
5050

5151

52-
class Tf2OnnxInternalTests(unittest.TestCase):
52+
class Tf2OnnxInternalTests(Tf2OnnxBackendTestBase):
5353
def setUp(self):
54-
"""Setup test."""
55-
# suppress log info of tensorflow so that result of test can be seen much easier
56-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
57-
tf.logging.set_verbosity(tf.logging.WARN)
58-
59-
utils.INTERNAL_NAME = 1
54+
super().setUp()
6055
arg = namedtuple("Arg", "input inputs outputs verbose")
6156
self._args0 = arg(input="test", inputs=[], outputs=["output:0"], verbose=False)
6257
self._args1 = arg(input="test", inputs=["input:0"], outputs=["output:0"], verbose=False)

0 commit comments

Comments
 (0)