Skip to content

Commit 19dc2e2

Browse files
committed
fix pylint
1 parent 4d7cf30 commit 19dc2e2

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

tf2onnx/custom_opsets/onnx_ml.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,35 @@
44
import numpy as np
55

66
from onnx import onnx_pb
7-
from onnx.onnx_pb import TensorProto
8-
from tf2onnx import constants, utils
7+
from tf2onnx import constants
98
from tf2onnx.handler import tf_op
109

1110
@tf_op("HashTableV2")
1211
class HashTable:
1312
@classmethod
1413
def version_11(cls, ctx, node, **kwargs):
14+
""" HashTable will be removed """
1515
pass
1616

1717
@tf_op("LookupTableFindV2")
1818
class LookupTableFind:
1919
@classmethod
2020
def version_11(cls, ctx, node, **kwargs):
21+
""" convert lookup to category mapper """
2122
table_node = node.inputs[0]
2223
file_path = table_node.get_attr_value("shared_name")[11:-6]
2324
cats_int64s = []
2425
cats_strings = []
2526
with open(file_path, 'r') as f:
26-
for i,s in enumerate(f.readlines()):
27+
for i, s in enumerate(f.readlines()):
2728
cats_int64s.append(i)
2829
cats_strings.append(s.strip())
2930
node_inputs = node.input
3031
node_outputs = node.output
3132
ctx.remove_node(node.name)
32-
ctx.make_node("CategoryMapper", domain=constants.AI_ONNX_ML_DOMAIN, inputs=node_inputs[1:2], outputs=node_outputs, attr={'cats_int64s':cats_int64s,'cats_strings':cats_strings})
33+
ctx.make_node("CategoryMapper", domain=constants.AI_ONNX_ML_DOMAIN,
34+
inputs=node_inputs[1: 2], outputs=node_outputs,
35+
attr={'cats_int64s': cats_int64s, 'cats_strings': cats_strings})
3336
customer_nodes = ctx.find_output_consumers(table_node.output[0])
3437
if len(customer_nodes) == 0:
3538
ctx.remove_node(table_node.name)

tf2onnx/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def get_tf_shape_attr(node):
179179

180180
def get_tf_output_shapes_attr(node):
181181
"""Get output shapes from tensorflow attr "output_shapes"."""
182-
dims = []
182+
dims = []
183183
try:
184184
shapes = get_tf_node_attr(node, "output_shapes")
185185
for shape in shapes:

0 commit comments

Comments
 (0)