Skip to content

Commit 6cc58e5

Browse files
committed
fix pylint
1 parent 9cdfb72 commit 6cc58e5

File tree

2 files changed

+42
-40
lines changed

2 files changed

+42
-40
lines changed

tests/test_cudnn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,17 @@
1414
from tensorflow.python.ops import init_ops
1515
from tensorflow.python.ops import variable_scope
1616
from backend_test_base import Tf2OnnxBackendTestBase
17-
from common import *
17+
from common import skip_tf2, skip_tf_cpu, check_opset_min_version, unittest_main
1818
from tf2onnx.tf_loader import is_tf2
1919

2020

2121
class CudnnTests(Tf2OnnxBackendTestBase):
22+
# test cudnn cases
2223
@skip_tf2()
2324
@skip_tf_cpu("only tf_gpu can run CudnnGPU")
2425
@check_opset_min_version(11, "CudnnGRU")
2526
def test_cudnngru(self):
27+
# test contrib cudnn gru
2628
seq_length = 3
2729
batch_size = 5
2830
input_size = 2

tf2onnx/onnx_opset/rnn.py

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,11 @@ def version_7(cls, ctx, node, **kwargs):
174174
class CudnnRNN:
175175
@classmethod
176176
def version_11(cls, ctx, node, **kwargs):
177-
X = node.input[0]
178-
X_shape = ctx.get_shape(X)
179-
H = node.input[1]
180-
H_shape = ctx.get_shape(H)
181-
P = node.input[3]
177+
x = node.input[0]
178+
x_shape = ctx.get_shape(x)
179+
h = node.input[1]
180+
h_shape = ctx.get_shape(h)
181+
p = node.input[3]
182182
utils.make_sure(
183183
node.attr["rnn_mode"].s == b"gru",
184184
"rnn mode other than gru are not supported yet"
@@ -192,9 +192,9 @@ def version_11(cls, ctx, node, **kwargs):
192192
"input mode must be linear input"
193193
)
194194
num_dirs = 1 if node.attr["direction"].s == b"unidirectional" else 2
195-
num_layers = int(H_shape[0]/num_dirs)
196-
num_units = hidden_size = H_shape[2]
197-
input_size = X_shape[2]
195+
num_layers = int(h_shape[0]/num_dirs)
196+
num_units = hidden_size = h_shape[2]
197+
input_size = x_shape[2]
198198
w_shape = [num_layers*num_dirs, 3*hidden_size, input_size]
199199
w_shape_const = ctx.make_const(utils.make_name("w_shape"), np.array(w_shape, dtype=np.int64))
200200
r_shape = [num_layers*num_dirs, 3*hidden_size, hidden_size]
@@ -208,44 +208,44 @@ def version_11(cls, ctx, node, **kwargs):
208208
r_end_const = ctx.make_const(utils.make_name("r_end"), np.array([r_end], dtype=np.int64))
209209
b_end = r_end + np.prod(b_shape)
210210
b_end_const = ctx.make_const(utils.make_name("b_end"), np.array([b_end], dtype=np.int64))
211-
def NM(nm):
211+
def Name(nm):
212212
return node.name + "_" + nm
213-
WS = [NM('W_' + str(i)) for i in range(num_layers*num_dirs)]
214-
RS = [NM('R_' + str(i)) for i in range(num_layers*num_dirs)]
215-
BS = [NM('B_' + str(i)) for i in range(num_layers*num_dirs)]
216-
HS = [NM('H_' + str(i)) for i in range(num_layers*num_dirs)]
217-
YHS = [NM('YH_' + str(i)) for i in range(num_layers*num_dirs)]
218-
W_flattened = ctx.make_node('Slice', [P, zero_const.output[0], w_end_const.output[0]])
219-
R_flattened = ctx.make_node('Slice', [P, w_end_const.output[0], r_end_const.output[0]])
220-
B_flattened = ctx.make_node('Slice', [P, r_end_const.output[0], b_end_const.output[0]])
221-
W = utils.make_name('W')
222-
R = utils.make_name('R')
223-
B = utils.make_name('B')
224-
ctx.make_node('Reshape', [W_flattened.output[0], w_shape_const.output[0]], outputs=[W])
225-
ctx.make_node('Reshape', [R_flattened.output[0], r_shape_const.output[0]], outputs=[R])
226-
ctx.make_node('Reshape', [B_flattened.output[0], b_shape_const.output[0]], outputs=[B])
227-
ctx.make_node('Split', [W], outputs=WS)
228-
ctx.make_node('Split', [R], outputs=RS)
229-
ctx.make_node('Split', [B], outputs=BS)
230-
ctx.make_node('Split', [H], outputs=HS)
231-
XNF = XNB = X
213+
ws = [Name('W_' + str(i)) for i in range(num_layers*num_dirs)]
214+
rs = [Name('R_' + str(i)) for i in range(num_layers*num_dirs)]
215+
bs = [Name('B_' + str(i)) for i in range(num_layers*num_dirs)]
216+
hs = [Name('H_' + str(i)) for i in range(num_layers*num_dirs)]
217+
yhs = [Name('YH_' + str(i)) for i in range(num_layers*num_dirs)]
218+
w_flattened = ctx.make_node('Slice', [p, zero_const.output[0], w_end_const.output[0]])
219+
r_flattened = ctx.make_node('Slice', [p, w_end_const.output[0], r_end_const.output[0]])
220+
b_flattened = ctx.make_node('Slice', [p, r_end_const.output[0], b_end_const.output[0]])
221+
w = utils.make_name('W')
222+
r = utils.make_name('R')
223+
b = utils.make_name('B')
224+
ctx.make_node('Reshape', [w_flattened.output[0], w_shape_const.output[0]], outputs=[w])
225+
ctx.make_node('Reshape', [r_flattened.output[0], r_shape_const.output[0]], outputs=[r])
226+
ctx.make_node('Reshape', [b_flattened.output[0], b_shape_const.output[0]], outputs=[b])
227+
ctx.make_node('Split', [w], outputs=ws)
228+
ctx.make_node('Split', [r], outputs=rs)
229+
ctx.make_node('Split', [b], outputs=bs)
230+
ctx.make_node('Split', [h], outputs=hs)
231+
xnf = xnb = x
232232
for i in range(num_layers):
233233
suffix = '_' + str(i*num_dirs)
234-
ctx.make_node('GRU', [XNF, NM('W' + suffix), NM('R' + suffix), NM('B' + suffix), '', NM('H'+ suffix)],
235-
outputs=[NM('Y' + suffix), NM('YH' + suffix)],
234+
ctx.make_node('GRU', [xnf, Name('W' + suffix), Name('R' + suffix), Name('B' + suffix), '', Name('H'+ suffix)],
235+
outputs=[Name('Y' + suffix), Name('YH' + suffix)],
236236
attr={'direction': 'forward', 'hidden_size': num_units})
237-
XNF = NM(X + suffix)
238-
ctx.make_node('Squeeze', [NM('Y' + suffix)], outputs=[XNF], attr={'axes': [1]})
237+
xnf = Name(x + suffix)
238+
ctx.make_node('Squeeze', [Name('Y' + suffix)], outputs=[xnf], attr={'axes': [1]})
239239
if num_dirs == 2:
240240
suffix = '_' + str(i*2+1)
241-
ctx.make_node('GRU', [XNB, NM('W' + suffix), NM('R' + suffix), NM('B' + suffix), '', NM('H'+ suffix)],
242-
outputs=[NM('Y' + suffix), NM('YH' + suffix)],
241+
ctx.make_node('GRU', [xnb, Name('W' + suffix), Name('R' + suffix), Name('B' + suffix), '', Name('H'+ suffix)],
242+
outputs=[Name('Y' + suffix), Name('YH' + suffix)],
243243
attr={'direction': 'reverse', 'hidden_size': num_units})
244-
XNB = NM(X + suffix)
245-
ctx.make_node('Squeeze', [NM('Y' + suffix)], outputs=[XNB], attr={'axes': [1]})
244+
xnb = Name(x + suffix)
245+
ctx.make_node('Squeeze', [Name('Y' + suffix)], outputs=[xnb], attr={'axes': [1]})
246246
ctx.remove_node(node.name)
247247
if num_dirs == 2:
248-
ctx.make_node('Concat', [XNF, XNB], outputs=[node.output[0]], attr={'axis': -1})
248+
ctx.make_node('Concat', [xnf, xnb], outputs=[node.output[0]], attr={'axis': -1})
249249
else:
250-
ctx.make_node('Identity', [XNF], outputs=[node.output[0]])
251-
ctx.make_node('Concat', YHS, outputs=[node.output[1]], attr={'axis': 0})
250+
ctx.make_node('Identity', [xnf], outputs=[node.output[0]])
251+
ctx.make_node('Concat', yhs, outputs=[node.output[1]], attr={'axis': 0})

0 commit comments

Comments
 (0)