Skip to content

Commit da45397

Browse files
committed
fix comment
1 parent 65dd55d commit da45397

File tree

3 files changed

+6
-7
lines changed

3 files changed

+6
-7
lines changed

tf2onnx/custom_opsets/onnx_ml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@
1010
@tf_op("HashTableV2")
1111
class HashTable:
1212
@classmethod
13-
def version_1(cls, ctx, node, **kwargs):
13+
def version_8(cls, ctx, node, **kwargs):
1414
""" HashTable will be removed """
1515
pass
1616

1717

1818
@tf_op("LookupTableFindV2")
1919
class LookupTableFind:
2020
@classmethod
21-
def version_1(cls, ctx, node, **kwargs):
21+
def version_8(cls, ctx, node, **kwargs):
2222
""" convert lookup to category mapper """
2323
table_node = node.inputs[0]
2424
file_path = table_node.get_attr_value("shared_name")[11:-6]

tf2onnx/onnx_opset/controlflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,13 +369,13 @@ def version_9(cls, ctx, node, **kwargs):
369369
@tf_op("IteratorV2")
370370
class Iterator:
371371
@classmethod
372-
def version_1(cls, ctx, node, **kwargs):
372+
def version_8(cls, ctx, node, **kwargs):
373373
ctx.remove_node(node.name)
374374

375375
@tf_op("IteratorGetNext")
376376
class IteratorGetNext:
377377
@classmethod
378-
def version_1(cls, ctx, node, **kwargs):
378+
def version_8(cls, ctx, node, **kwargs):
379379
output_names = node.output
380380
ctx.remove_node(node.name)
381381
output_types = list(node.get_attr('output_types').ints)

tf2onnx/tfonnx.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def tflist_to_onnx(node_list, shape_override):
8080
if dtype:
8181
if not isinstance(dtype, list):
8282
dtypes[node.name] = utils.map_tf_dtype(dtype)
83-
elif a in ["output_type", "output_dtype", "out_type", "Tidx", "out_idx"]:
83+
elif a in ["output_type", "output_dtype", "out_type", "Tidx",
84+
"out_idx", "key_dtype", "value_dtype", "Tin", "Tout"]:
8485
# Tidx is used by Range
8586
# out_idx is used by ListDiff
8687
attr[a] = utils.map_tf_dtype(utils.get_tf_node_attr(node, a))
@@ -99,8 +100,6 @@ def tflist_to_onnx(node_list, shape_override):
99100
continue
100101
elif a in ignored_attr:
101102
continue
102-
elif a in ["key_dtype", "value_dtype", "Tin", "Tout"]:
103-
attr[a] = utils.map_tf_dtype(utils.get_tf_node_attr(node, a))
104103
elif a == "output_types":
105104
attr[a] = [utils.map_tf_dtype(v) for v in utils.get_tf_node_attr(node, a)]
106105
elif a == "output_shapes":

0 commit comments

Comments
 (0)