 import numpy as np

 from onnx import onnx_pb
-from onnx.onnx_pb import TensorProto
-from tf2onnx import constants, utils
+from tf2onnx import constants
 from tf2onnx.handler import tf_op

 @tf_op("HashTableV2")
 class HashTable:
     @classmethod
     def version_11(cls, ctx, node, **kwargs):
+        """ HashTable will be removed """
         pass

 @tf_op("LookupTableFindV2")
 class LookupTableFind:
     @classmethod
     def version_11(cls, ctx, node, **kwargs):
+        """ convert lookup to category mapper """
         table_node = node.inputs[0]
         file_path = table_node.get_attr_value("shared_name")[11:-6]
         cats_int64s = []
         cats_strings = []
         with open(file_path, 'r') as f:
-            for i,s in enumerate(f.readlines()):
+            for i, s in enumerate(f.readlines()):
                 cats_int64s.append(i)
                 cats_strings.append(s.strip())
         node_inputs = node.input
         node_outputs = node.output
         ctx.remove_node(node.name)
-        ctx.make_node("CategoryMapper", domain=constants.AI_ONNX_ML_DOMAIN, inputs=node_inputs[1:2], outputs=node_outputs, attr={'cats_int64s':cats_int64s,'cats_strings':cats_strings})
+        ctx.make_node("CategoryMapper", domain=constants.AI_ONNX_ML_DOMAIN,
+                      inputs=node_inputs[1:2], outputs=node_outputs,
+                      attr={'cats_int64s': cats_int64s, 'cats_strings': cats_strings})
         customer_nodes = ctx.find_output_consumers(table_node.output[0])
         if len(customer_nodes) == 0:
             ctx.remove_node(table_node.name)