Skip to content

Commit 8dde75c

Browse files
authored
Merge pull request #786 from onnx/rashuai/gnmt2
Convert GNMT
2 parents 2af4841 + 95980c4 commit 8dde75c

File tree

6 files changed

+82
-1
lines changed

6 files changed

+82
-1
lines changed

tests/backend_test_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
8787
graph_def = None
8888
if convert_var_to_const:
8989
with tf.Session() as sess:
90+
tf.tables_initializer().run()
9091
variables_lib.global_variables_initializer().run()
9192
output_name_without_port = [n.split(':')[0] for n in output_names_with_port]
9293
graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def,
@@ -96,6 +97,7 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
9697
tf.import_graph_def(graph_def, name='')
9798

9899
with tf.Session() as sess:
100+
tf.tables_initializer().run()
99101
variables_lib.global_variables_initializer().run()
100102
output_dict = []
101103
for out_name in output_names_with_port:

tests/test_backend.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
import unittest
1111
from itertools import product
1212

13+
import os
1314
import numpy as np
1415
import tensorflow as tf
1516

17+
from tensorflow.python.ops import lookup_ops
1618
from backend_test_base import Tf2OnnxBackendTestBase
1719
# pylint reports unused-wildcard-import which is false positive, __all__ is defined in common
1820
from common import * # pylint: disable=wildcard-import,unused-wildcard-import
@@ -2974,6 +2976,20 @@ def test_Conv2DBackpropInput_valid(self):
29742976
name=_TFOUTPUT)
29752977
self._run_test_case([_OUTPUT], {_INPUT: input_sizes_val, _INPUT1: filters_val, _INPUT2: out_backprop_val})
29762978

2979+
@check_opset_min_version(8, "CategoryMapper")
2980+
def test_hashtable_lookup(self):
2981+
filnm = "vocab.tmp"
2982+
words = ["apple", "pear", "banana", "cherry", "grape"]
2983+
query = np.array(['cherry'], dtype=object)
2984+
with open(filnm, "w") as f:
2985+
for word in words:
2986+
f.write(word + "\n")
2987+
query_holder = tf.placeholder(tf.string, shape=[len(query)], name=_TFINPUT)
2988+
hash_table = lookup_ops.index_table_from_file(filnm)
2989+
lookup_results = hash_table.lookup(query_holder)
2990+
self._run_test_case([lookup_results.name], {_INPUT: query})
2991+
os.remove(filnm)
2992+
29772993

29782994
if __name__ == '__main__':
29792995
unittest_main()

tf2onnx/custom_opsets/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
""" custom tf2onnx mapping functions. """
44

55
from . import ms
6+
from . import onnx_ml

tf2onnx/custom_opsets/onnx_ml.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
""" tf2onnx mapping functions for onnx ml domain. """
4+
from tf2onnx import constants
5+
from tf2onnx.handler import tf_op
6+
7+
8+
# pylint: disable=unused-argument,missing-docstring,unnecessary-pass
9+
10+
@tf_op("HashTableV2")
11+
class HashTable:
12+
@classmethod
13+
def version_8(cls, ctx, node, **kwargs):
14+
""" HashTable will be removed """
15+
pass
16+
17+
18+
@tf_op("LookupTableFindV2")
19+
class LookupTableFind:
20+
@classmethod
21+
def version_8(cls, ctx, node, **kwargs):
22+
""" convert lookup to category mapper """
23+
table_node = node.inputs[0]
24+
file_path = table_node.get_attr_value("shared_name")[11:-6]
25+
cats_int64s = []
26+
cats_strings = []
27+
with open(file_path, 'r') as f:
28+
for i, s in enumerate(f.readlines()):
29+
cats_int64s.append(i)
30+
cats_strings.append(s.strip())
31+
node_name = node.name
32+
node_inputs = node.input
33+
node_outputs = node.output
34+
ctx.remove_node(node.name)
35+
new_node = ctx.make_node("CategoryMapper", domain=constants.AI_ONNX_ML_DOMAIN,
36+
name=node_name, inputs=node_inputs[1: 2], outputs=node_outputs,
37+
attr={'cats_int64s': cats_int64s, 'cats_strings': cats_strings})
38+
ctx.set_shape(new_node.name + ":0", [-1])
39+
customer_nodes = ctx.find_output_consumers(table_node.output[0])
40+
if len(customer_nodes) == 0:
41+
ctx.remove_node(table_node.name)

tf2onnx/onnx_opset/controlflow.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,3 +365,22 @@ def version_9(cls, ctx, node, **kwargs):
365365
node.output[0], name=utils.make_name("where_op_added"))
366366
ctx.copy_shape(node.output[0], transpose_node.output[0])
367367
ctx.copy_dtype(node.output[0], transpose_node.output[0])
368+
369+
@tf_op("IteratorV2")
370+
class Iterator:
371+
@classmethod
372+
def version_8(cls, ctx, node, **kwargs):
373+
ctx.remove_node(node.name)
374+
375+
@tf_op("IteratorGetNext")
376+
class IteratorGetNext:
377+
@classmethod
378+
def version_8(cls, ctx, node, **kwargs):
379+
output_names = node.output
380+
type_0 = ctx.get_dtype(output_names[0])
381+
type_1 = ctx.get_dtype(output_names[1])
382+
shape_0 = ctx.get_shape(output_names[0])
383+
shape_1 = ctx.get_shape(output_names[1])
384+
ctx.remove_node(node.name)
385+
ctx.add_graph_input(output_names[0], type_0, shape_0)
386+
ctx.add_graph_input(output_names[1], type_1, shape_1)

tf2onnx/tfonnx.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def tflist_to_onnx(node_list, shape_override):
4646
ignored_attr = ["unknown_rank", "_class", "Tshape", "use_cudnn_on_gpu", "Index", "Tpaddings",
4747
"TI", "Tparams", "Tindices", "Tlen", "Tdim", "dynamic_size", "Tmultiples",
4848
"Tblock_shape", "Tcrops", "index_type", "Taxis", "U", "maxval",
49-
"Tout", "Tlabels", "Tindex", "element_shape", "Targmax", "T_threshold"]
49+
"Tout", "Tlabels", "Tindex", "element_shape", "Targmax", "T_threshold",
50+
"output_types", "output_shapes", "key_dtype", "value_dtype", "Tin", "Tout"]
51+
5052
# some stats
5153
op_cnt = collections.Counter()
5254
attr_cnt = collections.Counter()

0 commit comments

Comments
 (0)