Skip to content

Commit 78c8fc1

Browse files
committed
add test case
1 parent c2a0766 commit 78c8fc1

File tree

4 files changed

+28
-9
lines changed

4 files changed

+28
-9
lines changed

tests/backend_test_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
8787
graph_def = None
8888
if convert_var_to_const:
8989
with tf.Session() as sess:
90+
tf.tables_initializer().run()
9091
variables_lib.global_variables_initializer().run()
9192
output_name_without_port = [n.split(':')[0] for n in output_names_with_port]
9293
graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def,
@@ -96,6 +97,7 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
9697
tf.import_graph_def(graph_def, name='')
9798

9899
with tf.Session() as sess:
100+
tf.tables_initializer().run()
99101
variables_lib.global_variables_initializer().run()
100102
output_dict = []
101103
for out_name in output_names_with_port:

tests/test_backend.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
import unittest
1111
from itertools import product
1212

13+
import os
1314
import numpy as np
1415
import tensorflow as tf
1516

17+
from tensorflow.python.ops import lookup_ops
1618
from backend_test_base import Tf2OnnxBackendTestBase
1719
# pylint reports unused-wildcard-import which is false positive, __all__ is defined in common
1820
from common import * # pylint: disable=wildcard-import,unused-wildcard-import
@@ -109,7 +111,7 @@ def _run_test_case(self, output_names_with_port, feed_dict, **kwargs):
109111
kwargs["convert_var_to_const"] = False
110112
kwargs["constant_fold"] = False
111113
return self.run_test_case(feed_dict, [], output_names_with_port, **kwargs)
112-
114+
'''
113115
def _test_expand_dims_known_rank(self, idx):
114116
tf.reset_default_graph()
115117
x_val = make_xval([3, 4])
@@ -2973,7 +2975,20 @@ def test_Conv2DBackpropInput_valid(self):
29732975
_ = tf.nn.conv2d_backprop_input(input_sizes, filters, out_backprop, strides=[1, 1, 1, 1], padding='VALID',
29742976
name=_TFOUTPUT)
29752977
self._run_test_case([_OUTPUT], {_INPUT: input_sizes_val, _INPUT1: filters_val, _INPUT2: out_backprop_val})
2976-
2978+
'''
2979+
@check_opset_min_version(1, "CategoryMapper")
2980+
def test_hashtable_lookup(self):
2981+
filnm = "vocab.tmp"
2982+
words = ["apple","pear","banana","cherry","grape"]
2983+
query = np.array(['cherry'],dtype=object)
2984+
with open(filnm, "w") as f:
2985+
for word in words:
2986+
f.write(word + "\n")
2987+
query_holder = tf.placeholder(tf.string, shape=[len(query)], name=_TFINPUT)
2988+
hash_table = lookup_ops.index_table_from_file(filnm)
2989+
lookup_results = hash_table.lookup(query_holder)
2990+
self._run_test_case([lookup_results.name], {_INPUT: query})
2991+
os.remove(filnm)
29772992

29782993
if __name__ == '__main__':
29792994
unittest_main()

tf2onnx/custom_opsets/onnx_ml.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@
1010
@tf_op("HashTableV2")
1111
class HashTable:
1212
@classmethod
13-
def version_11(cls, ctx, node, **kwargs):
13+
def version_1(cls, ctx, node, **kwargs):
1414
""" HashTable will be removed """
1515
pass
1616

1717

1818
@tf_op("LookupTableFindV2")
1919
class LookupTableFind:
2020
@classmethod
21-
def version_11(cls, ctx, node, **kwargs):
21+
def version_1(cls, ctx, node, **kwargs):
2222
""" convert lookup to category mapper """
2323
table_node = node.inputs[0]
2424
file_path = table_node.get_attr_value("shared_name")[11:-6]
@@ -28,12 +28,14 @@ def version_11(cls, ctx, node, **kwargs):
2828
for i, s in enumerate(f.readlines()):
2929
cats_int64s.append(i)
3030
cats_strings.append(s.strip())
31+
node_name = node.name
3132
node_inputs = node.input
3233
node_outputs = node.output
3334
ctx.remove_node(node.name)
34-
ctx.make_node("CategoryMapper", domain=constants.AI_ONNX_ML_DOMAIN,
35-
inputs=node_inputs[1: 2], outputs=node_outputs,
36-
attr={'cats_int64s': cats_int64s, 'cats_strings': cats_strings})
35+
new_node = ctx.make_node("CategoryMapper", domain=constants.AI_ONNX_ML_DOMAIN,
36+
name=node_name, inputs=node_inputs[1: 2], outputs=node_outputs,
37+
attr={'cats_int64s': cats_int64s, 'cats_strings': cats_strings})
38+
ctx.set_shape(new_node.name + ":0", [-1])
3739
customer_nodes = ctx.find_output_consumers(table_node.output[0])
3840
if len(customer_nodes) == 0:
3941
ctx.remove_node(table_node.name)

tf2onnx/onnx_opset/controlflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,13 +369,13 @@ def version_9(cls, ctx, node, **kwargs):
369369
@tf_op("IteratorV2")
370370
class Iterator:
371371
@classmethod
372-
def version_11(cls, ctx, node, **kwargs):
372+
def version_1(cls, ctx, node, **kwargs):
373373
ctx.remove_node(node.name)
374374

375375
@tf_op("IteratorGetNext")
376376
class IteratorGetNext:
377377
@classmethod
378-
def version_11(cls, ctx, node, **kwargs):
378+
def version_1(cls, ctx, node, **kwargs):
379379
output_names = node.output
380380
ctx.remove_node(node.name)
381381
output_types = list(node.get_attr('output_types').ints)

0 commit comments

Comments
 (0)