Commit 8895b31

fix pylint
1 parent 6cc58e5 commit 8895b31

2 files changed: +26 -24 lines changed

tests/test_cudnn.py

Lines changed: 2 additions & 4 deletions
@@ -12,19 +12,17 @@
 import tensorflow as tf
 
 from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import variable_scope
 from backend_test_base import Tf2OnnxBackendTestBase
 from common import skip_tf2, skip_tf_cpu, check_opset_min_version, unittest_main
-from tf2onnx.tf_loader import is_tf2
 
 
 class CudnnTests(Tf2OnnxBackendTestBase):
-    # test cudnn cases
+    """ test cudnn cases """
     @skip_tf2()
     @skip_tf_cpu("only tf_gpu can run CudnnGPU")
     @check_opset_min_version(11, "CudnnGRU")
     def test_cudnngru(self):
-        # test contrib cudnn gru
+        """ test contrib cudnn gru """
         seq_length = 3
         batch_size = 5
         input_size = 2
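
For context: the dropped imports were dead code that pylint flags as unused-import (W0611), and the comment-to-docstring changes satisfy pylint's missing-docstring check, since a leading comment never reaches the interpreter. A minimal illustrative sketch (not from the repo) of the difference:

class WithComment:
    # test cudnn cases -- a plain comment does not populate __doc__
    pass

class WithDocstring:
    """ test cudnn cases """

print(WithComment.__doc__)    # None; pylint reports missing-docstring
print(WithDocstring.__doc__)  # ' test cudnn cases '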

tf2onnx/onnx_opset/rnn.py

Lines changed: 24 additions & 20 deletions
@@ -192,14 +192,14 @@ def version_11(cls, ctx, node, **kwargs):
             "input mode must be linear input"
         )
         num_dirs = 1 if node.attr["direction"].s == b"unidirectional" else 2
-        num_layers = int(h_shape[0]/num_dirs)
+        num_layers = int(h_shape[0] / num_dirs)
         num_units = hidden_size = h_shape[2]
         input_size = x_shape[2]
-        w_shape = [num_layers*num_dirs, 3*hidden_size, input_size]
+        w_shape = [num_layers * num_dirs, 3 * hidden_size, input_size]
         w_shape_const = ctx.make_const(utils.make_name("w_shape"), np.array(w_shape, dtype=np.int64))
-        r_shape = [num_layers*num_dirs, 3*hidden_size, hidden_size]
+        r_shape = [num_layers * num_dirs, 3 * hidden_size, hidden_size]
         r_shape_const = ctx.make_const(utils.make_name("r_shape"), np.array(r_shape, dtype=np.int64))
-        b_shape = [num_layers*num_dirs, 6*hidden_size]
+        b_shape = [num_layers * num_dirs, 6 * hidden_size]
         b_shape_const = ctx.make_const(utils.make_name("b_shape"), np.array(b_shape, dtype=np.int64))
         zero_const = ctx.make_const(utils.make_name("zero"), np.array([0], dtype=np.int64))
         w_end = np.prod(w_shape)
@@ -208,13 +208,15 @@ def version_11(cls, ctx, node, **kwargs):
         r_end_const = ctx.make_const(utils.make_name("r_end"), np.array([r_end], dtype=np.int64))
         b_end = r_end + np.prod(b_shape)
         b_end_const = ctx.make_const(utils.make_name("b_end"), np.array([b_end], dtype=np.int64))
-        def Name(nm):
+
+        def name(nm):
             return node.name + "_" + nm
-        ws = [Name('W_' + str(i)) for i in range(num_layers*num_dirs)]
-        rs = [Name('R_' + str(i)) for i in range(num_layers*num_dirs)]
-        bs = [Name('B_' + str(i)) for i in range(num_layers*num_dirs)]
-        hs = [Name('H_' + str(i)) for i in range(num_layers*num_dirs)]
-        yhs = [Name('YH_' + str(i)) for i in range(num_layers*num_dirs)]
+
+        ws = [name('W_' + str(i)) for i in range(num_layers * num_dirs)]
+        rs = [name('R_' + str(i)) for i in range(num_layers * num_dirs)]
+        bs = [name('B_' + str(i)) for i in range(num_layers * num_dirs)]
+        hs = [name('H_' + str(i)) for i in range(num_layers * num_dirs)]
+        yhs = [name('YH_' + str(i)) for i in range(num_layers * num_dirs)]
 w_flattened = ctx.make_node('Slice', [p, zero_const.output[0], w_end_const.output[0]])
 r_flattened = ctx.make_node('Slice', [p, w_end_const.output[0], r_end_const.output[0]])
 b_flattened = ctx.make_node('Slice', [p, r_end_const.output[0], b_end_const.output[0]])
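
For readers skimming the converter: the slices above carve the single packed CudnnGRU parameter tensor p into weight (W), recurrence (R), and bias (B) segments by cumulative offset, and the Name -> name rename satisfies pylint's snake_case convention for functions (invalid-name, C0103). A standalone numpy sketch of the same offset arithmetic, with hypothetical sizes (the real values come from the graph shapes, and r_end is inferred from the surrounding code):

import numpy as np

num_layers, num_dirs, hidden_size, input_size = 1, 2, 4, 3  # made-up sizes

w_shape = [num_layers * num_dirs, 3 * hidden_size, input_size]
r_shape = [num_layers * num_dirs, 3 * hidden_size, hidden_size]
b_shape = [num_layers * num_dirs, 6 * hidden_size]

w_end = np.prod(w_shape)          # end of the W segment in the flat buffer
r_end = w_end + np.prod(r_shape)  # end of the R segment
b_end = r_end + np.prod(b_shape)  # end of the B segment

p = np.zeros(b_end, dtype=np.float32)  # stands in for the packed parameters
w = p[0:w_end].reshape(w_shape)        # what Slice followed by Reshape computes
r = p[w_end:r_end].reshape(r_shape)
b = p[r_end:b_end].reshape(b_shape)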
@@ -230,19 +232,21 @@ def Name(nm):
         ctx.make_node('Split', [h], outputs=hs)
         xnf = xnb = x
         for i in range(num_layers):
-            suffix = '_' + str(i*num_dirs)
-            ctx.make_node('GRU', [xnf, Name('W' + suffix), Name('R' + suffix), Name('B' + suffix), '', Name('H'+ suffix)],
-                          outputs=[Name('Y' + suffix), Name('YH' + suffix)],
+            suffix = '_' + str(i * num_dirs)
+            ctx.make_node('GRU',
+                          [xnf, name('W' + suffix), name('R' + suffix), name('B' + suffix), '', name('H' + suffix)],
+                          outputs=[name('Y' + suffix), name('YH' + suffix)],
                           attr={'direction': 'forward', 'hidden_size': num_units})
-            xnf = Name(x + suffix)
-            ctx.make_node('Squeeze', [Name('Y' + suffix)], outputs=[xnf], attr={'axes': [1]})
+            xnf = name(x + suffix)
+            ctx.make_node('Squeeze', [name('Y' + suffix)], outputs=[xnf], attr={'axes': [1]})
             if num_dirs == 2:
-                suffix = '_' + str(i*2+1)
-                ctx.make_node('GRU', [xnb, Name('W' + suffix), Name('R' + suffix), Name('B' + suffix), '', Name('H'+ suffix)],
-                              outputs=[Name('Y' + suffix), Name('YH' + suffix)],
+                suffix = '_' + str(i * 2 + 1)
+                ctx.make_node('GRU',
+                              [xnb, name('W' + suffix), name('R' + suffix), name('B' + suffix), '', name('H' + suffix)],
+                              outputs=[name('Y' + suffix), name('YH' + suffix)],
                               attr={'direction': 'reverse', 'hidden_size': num_units})
-                xnb = Name(x + suffix)
-                ctx.make_node('Squeeze', [Name('Y' + suffix)], outputs=[xnb], attr={'axes': [1]})
+                xnb = name(x + suffix)
+                ctx.make_node('Squeeze', [name('Y' + suffix)], outputs=[xnb], attr={'axes': [1]})
         ctx.remove_node(node.name)
         if num_dirs == 2:
             ctx.make_node('Concat', [xnf, xnb], outputs=[node.output[0]], attr={'axis': -1})
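
One subtlety worth spelling out in the rewrapped loop: for layer i, the forward GRU consumes the split slices with suffix i * num_dirs, while the reverse GRU (bidirectional case) consumes i * 2 + 1. A tiny sketch, assuming a hypothetical 2-layer bidirectional stack, that prints which slices each ONNX GRU node reads:

num_layers, num_dirs = 2, 2  # hypothetical bidirectional stack

for i in range(num_layers):
    fwd = '_' + str(i * num_dirs)   # forward direction of layer i
    print('layer', i, 'forward uses', 'W' + fwd, 'R' + fwd, 'B' + fwd, 'H' + fwd)
    if num_dirs == 2:
        rev = '_' + str(i * 2 + 1)  # reverse direction of layer i
        print('layer', i, 'reverse uses', 'W' + rev, 'R' + rev, 'B' + rev, 'H' + rev)

# layer 0 reads slices 0 (forward) and 1 (reverse); layer 1 reads 2 and 3.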
