Skip to content

Commit 7ac4fb8

Browse files
Merge pull request #1149 from onnx/tom/HashtableLookupOnTf2
Enabled test_hashtable_lookup for tf2
2 parents c511411 + 855f079 commit 7ac4fb8

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

tests/backend_test_base.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@
3030
from tf2onnx.graph import ExternalTensorStorage
3131

3232

33+
if is_tf2():
34+
tf_set_random_seed = tf.compat.v1.set_random_seed
35+
tf_tables_initializer = tf.compat.v1.tables_initializer
36+
else:
37+
tf_set_random_seed = tf.set_random_seed
38+
tf_tables_initializer = tf.tables_initializer
39+
40+
3341
class Tf2OnnxBackendTestBase(unittest.TestCase):
3442
def setUp(self):
3543
self.config = get_test_config()
@@ -133,14 +141,13 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
133141
# use graph to execute the tensorflow func
134142
#
135143
with tf_session() as sess:
136-
tf.set_random_seed(1)
144+
tf_set_random_seed(1)
137145
input_list = []
138146
for k, v in clean_feed_dict.items():
139147
input_list.append(tf_placeholder(name=k, shape=v.shape, dtype=tf.as_dtype(v.dtype)))
140148
func(*input_list)
141149
variables_lib.global_variables_initializer().run()
142-
if not is_tf2():
143-
tf.tables_initializer().run()
150+
tf_tables_initializer().run()
144151
output_dict = []
145152
for out_name in output_names_with_port:
146153
output_dict.append(sess.graph.get_tensor_by_name(out_name))

tests/test_backend.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3416,7 +3416,6 @@ def func(value, filters, output_shape):
34163416
rtol=1e-6)
34173417

34183418
@check_opset_min_version(8, "CategoryMapper")
3419-
@skip_tf2()
34203419
def test_hashtable_lookup(self):
34213420
filnm = "vocab.tmp"
34223421
words = ["apple", "pear", "banana", "cherry", "grape"]
@@ -3429,7 +3428,7 @@ def func(query_holder):
34293428
lookup_results = hash_table.lookup(query_holder)
34303429
ret = tf.add(lookup_results, 0, name=_TFOUTPUT)
34313430
return ret
3432-
self._run_test_case(func, [_OUTPUT], {_INPUT: query}, constant_fold=False)
3431+
self._run_test_case(func, [_OUTPUT], {_INPUT: query}, constant_fold=False, as_session=True)
34333432
os.remove(filnm)
34343433

34353434
@check_opset_min_version(11)

0 commit comments

Comments
 (0)