Skip to content

Commit fc75c26

Browse files
committed
Fix pylint errors and unit tests
- fixed pylint errors - fixed some broken for-loop logic
1 parent 723e3a4 commit fc75c26

File tree

3 files changed

+25
-19
lines changed

3 files changed

+25
-19
lines changed

tests/test_stacked_lstm.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55

66
from tensorflow.python.ops import init_ops
77
from backend_test_base import Tf2OnnxBackendTestBase
8-
from common import unittest_main, check_lstm_count, skip_tf2
8+
from common import unittest_main, check_lstm_count
99

1010
from tf2onnx.tf_loader import is_tf2
1111

12+
13+
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,cell-var-from-loop
14+
# pylint: disable=invalid-name
15+
1216
if is_tf2():
1317
LSTMCell = tf.compat.v1.nn.rnn_cell.LSTMCell
1418
MultiRNNCell = tf.compat.v1.nn.rnn_cell.MultiRNNCell
@@ -19,6 +23,7 @@
1923
MultiRNNCell = tf.contrib.rnn.MultiRNNCell
2024
dynamic_rnn = tf.nn.dynamic_rnn
2125

26+
2227
class LSTMLayeredTests(Tf2OnnxBackendTestBase):
2328
def test_layered_lstm(self):
2429
units = 5
@@ -52,5 +57,6 @@ def lstm_cell():
5257
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
5358
graph_validator=lambda g: check_lstm_count(g, 2))
5459

60+
5561
if __name__ == '__main__':
56-
unittest_main()
62+
unittest_main()

tf2onnx/rewriter/lstm_rewriter.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,14 @@ def find_cell(self, context):
4040
cell_match = self._match_cell(context, cell_type)
4141
if cell_match and len(cell_match) >= 1:
4242
self.num_lstm_layers = len(cell_match)
43-
logger.debug("number of lstm layers: " + str(self.num_lstm_layers))
43+
logger.debug("number of LSTM layers: %s", self.num_lstm_layers)
4444
for i in range(self.num_lstm_layers):
4545
self.state_variable_handlers.append({
4646
"ct" + str(i): (self._ct_variable_finder, self._connect_lstm_yc_to_graph, i),
4747
"ht" + str(i): (self._ht_variable_finder, self._connect_lstm_yh_to_graph, i)
4848
})
4949
self.state_variable_handlers.append({
50-
"ct_ht" + str(i): (self._ct_ht_shared_variable_finder, self._connect_lstm_ych_to_graph, i)
50+
"ct_ht" + str(i): (self._ct_ht_shared_variable_finder, self._connect_lstm_ych_to_graph, i)
5151
})
5252
logger.debug("parsing unit is %s, num layers is %d", cell_type, self.num_lstm_layers)
5353
if cell_match:
@@ -287,9 +287,9 @@ def process_var_init_nodes(self, context):
287287
def process_var_init_nodes_per_layer(self, context, i):
288288
init_h_id = None
289289
init_c_id = None
290-
if ("ct_ht" + str(i)) in context.state_variables:
290+
if "ct_ht" + str(i) in context.state_variables:
291291
init_h_id, init_c_id = self._process_non_tuple_ch_init_nodes(context, i)
292-
elif ("ct" + str(i)) in context.state_variables and ("ht" + str(i)) in context.state_variables:
292+
elif "ct" + str(i) in context.state_variables and ("ht" + str(i)) in context.state_variables:
293293
init_h_id, init_c_id = self._process_tuple_ch_init_nodes(context, i)
294294
else:
295295
raise ValueError("no initializers, unexpected")
@@ -363,11 +363,10 @@ def create_single_rnn_node(self, context, i):
363363
def create_rnn_node(self, context):
364364
rnn_nodes = list()
365365
outputs = context.loop_properties.scan_outputs_exits
366-
logger.debug("number of rnn node outputs:" + str(len(outputs)))
367-
for i in range(len(outputs)):
368-
logger.debug("output " + str(i) + " with id=" + outputs[i].id)
366+
logger.debug("number of rnn node outputs: %s", len(outputs))
367+
369368
for i in range(self.num_lstm_layers):
370-
logger.debug("creating rnn node for layer: " + str(i))
369+
logger.debug("creating rnn node for layer: %s", i)
371370
rnn_nodes.append(self.create_single_rnn_node(context, i))
372371
output_id = rnn_nodes[i].output[0]
373372
rnn_output_shape = self.g.get_shape(output_id)
@@ -376,7 +375,7 @@ def create_rnn_node(self, context):
376375
shapes=[squeeze_output_shape],
377376
dtypes=[self.g.get_dtype(output_id)])
378377
if i + 1 < self.num_lstm_layers:
379-
logger.debug("setting input for layer: " + str(i + 1))
378+
logger.debug("setting input for layer: %s", i + 1)
380379
context.onnx_input_ids[i + 1]["X"] = squeeze_node.output[0]
381380
return rnn_nodes
382381

@@ -420,4 +419,4 @@ def _connect_lstm_ych_to_graph(self, context, i):
420419
shapes=[squeeze_output_shape],
421420
dtypes=[self.g.get_dtype(concat.output[0])])
422421

423-
self.g.replace_all_inputs(self.g.get_nodes(), exit_output.id, squeeze_node.output[0])
422+
self.g.replace_all_inputs(self.g.get_nodes(), exit_output.id, squeeze_node.output[0])

tf2onnx/rewriter/lstm_rewriter_base.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
logger = logging.getLogger(__name__)
2222

2323

24-
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,broad-except,protected-access
24+
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,broad-except,protected-access,W0223
2525

2626
class LSTMContext(UnitRnnContext):
2727
def __init__(self):
@@ -45,6 +45,7 @@ class LSTMRewriterBase(UnitRnnRewriterBase):
4545
2 find needed info from tensorflow graph
4646
3 process found info according to ONNX requirement
4747
"""
48+
4849
def __init__(self, g):
4950
super(LSTMRewriterBase, self).__init__(g)
5051
# {var_name: (finder, connector)}
@@ -100,7 +101,8 @@ def parse_unit_rnn(self, context):
100101
context.input_size.append(None)
101102
context.hidden_size.append(None)
102103
context.attributes.append({})
103-
context.onnx_input_ids[i]["sequence_lens"] = seq_len_node.output[0] if seq_len_node else utils.ONNX_EMPTY_INPUT
104+
context.onnx_input_ids[i]["sequence_lens"] = \
105+
seq_len_node.output[0] if seq_len_node else utils.ONNX_EMPTY_INPUT
104106

105107
context.onnx_input_ids[0]["X"] = inputs[0]
106108
if not self.parse_attributes(context):
@@ -126,10 +128,9 @@ def _match_cell(self, context, unittype):
126128
)
127129

128130
match_results = list(matcher.match_ops(body_graph_ops))
129-
logger.debug("number of match results: " + str(len(match_results)))
130-
if len(match_results) < 1:
131-
return None
132-
return match_results
131+
logger.debug("number of match results: %s", len(match_results))
132+
if len(match_results) > 0:
133+
return match_results
133134
return None
134135

135136
def get_state_variables(self, context):
@@ -186,4 +187,4 @@ def connect_unit_rnn_output_to_graph(self, context):
186187
squeeze_node = self.g.make_node("Squeeze", [output_id], attr={"axes": [1]},
187188
shapes=[squeeze_output_shape],
188189
dtypes=[self.g.get_dtype(output_id)])
189-
self.g.replace_all_inputs(self.g.get_nodes(), gather_output_id, squeeze_node.output[0])
190+
self.g.replace_all_inputs(self.g.get_nodes(), gather_output_id, squeeze_node.output[0])

0 commit comments

Comments
 (0)