Skip to content

Commit 4d7cf30

Browse files
committed
Convert lookup op
1 parent 2af4841 commit 4d7cf30

File tree

5 files changed

+70
-0
lines changed

5 files changed

+70
-0
lines changed

tf2onnx/custom_opsets/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
""" custom tf2onnx mapping functions. """
44

55
from . import ms
6+
from . import onnx_ml

tf2onnx/custom_opsets/onnx_ml.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
""" tf2onnx mapping functions for onnx ml domain. """
4+
import numpy as np
5+
6+
from onnx import onnx_pb
7+
from onnx.onnx_pb import TensorProto
8+
from tf2onnx import constants, utils
9+
from tf2onnx.handler import tf_op
10+
11+
@tf_op("HashTableV2")
12+
class HashTable:
13+
@classmethod
14+
def version_11(cls, ctx, node, **kwargs):
15+
pass
16+
17+
@tf_op("LookupTableFindV2")
18+
class LookupTableFind:
19+
@classmethod
20+
def version_11(cls, ctx, node, **kwargs):
21+
table_node = node.inputs[0]
22+
file_path = table_node.get_attr_value("shared_name")[11:-6]
23+
cats_int64s = []
24+
cats_strings = []
25+
with open(file_path, 'r') as f:
26+
for i,s in enumerate(f.readlines()):
27+
cats_int64s.append(i)
28+
cats_strings.append(s.strip())
29+
node_inputs = node.input
30+
node_outputs = node.output
31+
ctx.remove_node(node.name)
32+
ctx.make_node("CategoryMapper", domain=constants.AI_ONNX_ML_DOMAIN, inputs=node_inputs[1:2], outputs=node_outputs, attr={'cats_int64s':cats_int64s,'cats_strings':cats_strings})
33+
customer_nodes = ctx.find_output_consumers(table_node.output[0])
34+
if len(customer_nodes) == 0:
35+
ctx.remove_node(table_node.name)

tf2onnx/onnx_opset/controlflow.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,3 +365,20 @@ def version_9(cls, ctx, node, **kwargs):
365365
node.output[0], name=utils.make_name("where_op_added"))
366366
ctx.copy_shape(node.output[0], transpose_node.output[0])
367367
ctx.copy_dtype(node.output[0], transpose_node.output[0])
368+
369+
@tf_op("IteratorV2")
370+
class Iterator:
371+
@classmethod
372+
def version_11(cls, ctx, node, **kwargs):
373+
ctx.remove_node(node.name)
374+
375+
@tf_op("IteratorGetNext")
376+
class IteratorGetNext:
377+
@classmethod
378+
def version_11(cls, ctx, node, **kwargs):
379+
output_names = node.output
380+
ctx.remove_node(node.name)
381+
output_types = list(node.get_attr('output_types').ints)
382+
output_shapes = list(node.get_attr('output_shapes').ints)
383+
ctx.add_graph_input(output_names[0], output_types[0], output_shapes[:output_shapes.index(0)])
384+
ctx.add_graph_input(output_names[1], output_types[1], output_shapes[output_shapes.index(0)+1:-1])

tf2onnx/tfonnx.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,12 @@ def tflist_to_onnx(node_list, shape_override):
9999
continue
100100
elif a in ignored_attr:
101101
continue
102+
elif a in ["key_dtype", "value_dtype", "Tin", "Tout"]:
103+
attr[a] = utils.map_tf_dtype(utils.get_tf_node_attr(node, a))
104+
elif a == "output_types":
105+
attr[a] = [utils.map_tf_dtype(v) for v in utils.get_tf_node_attr(node, a)]
106+
elif a == "output_shapes":
107+
attr[a] = utils.get_tf_output_shapes_attr(node)
102108
else:
103109
attr[a] = utils.get_tf_node_attr(node, a)
104110

tf2onnx/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,17 @@ def get_tf_shape_attr(node):
177177
pass
178178
return dims
179179

180+
def get_tf_output_shapes_attr(node):
181+
"""Get output shapes from tensorflow attr "output_shapes"."""
182+
dims = []
183+
try:
184+
shapes = get_tf_node_attr(node, "output_shapes")
185+
for shape in shapes:
186+
dims.extend([d.size for d in shape.dim])
187+
dims.append(0)
188+
except: # pylint: disable=bare-except
189+
pass
190+
return dims
180191

181192
def get_tf_tensor_shape(tensor):
182193
shape = []

0 commit comments

Comments
 (0)