Skip to content

Commit ff6c7b2

Browse files
committed
fix pylint
1 parent d5637b7 commit ff6c7b2

File tree

2 files changed: 23 additions (+23), 21 deletions (−21)

tests/test_backend.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
import tensorflow as tf
1717

1818
from tensorflow.python.ops import lookup_ops
19+
from tensorflow.python.ops import init_ops
1920
from backend_test_base import Tf2OnnxBackendTestBase
2021
# pylint reports unused-wildcard-import which is false positive, __all__ is defined in common
2122
from common import * # pylint: disable=wildcard-import,unused-wildcard-import
2223
from tf2onnx import constants, utils
2324
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2425
from tf2onnx.tf_loader import is_tf2
25-
from tensorflow.python.ops import init_ops
2626

2727
# pylint: disable=missing-docstring,invalid-name,unused-argument,function-redefined,cell-var-from-loop
2828

@@ -2919,24 +2919,6 @@ def func(query_holder):
29192919
self._run_test_case(func, [_OUTPUT], {_INPUT: query}, constant_fold=False)
29202920
os.remove(filnm)
29212921

2922-
@check_opset_min_version(11, "GRU")
2923-
def test_cudnngru(self):
2924-
seq_length = 3
2925-
batch_size = 5
2926-
input_size = 2
2927-
num_layers = 2
2928-
num_units = 2
2929-
num_dirs = 2
2930-
initializer = init_ops.constant_initializer(0.5)
2931-
x = np.random.randint(0, 100, [seq_length, batch_size, input_size]).astype(np.float32)
2932-
h = np.random.randint(0, 100, [num_layers * num_dirs, batch_size, num_units]).astype(np.float32).reshape(
2933-
[num_layers * num_dirs, batch_size, num_units])
2934-
cudnngru = tf.contrib.cudnn_rnn.CudnnGRU(num_layers, num_units, 'linear_input', 'bidirectional',
2935-
kernel_initializer=initializer, bias_initializer=initializer)
2936-
cudnngru.build([seq_length, batch_size, input_size])
2937-
outputs = cudnngru.call(x, tuple([h]))
2938-
self.run_test_case({}, [], [outputs[0].name], rtol=1e-05, atol=1e-04)
2939-
29402922
@check_opset_min_version(11)
29412923
def test_matrix_diag_part(self):
29422924
input_vals = [
@@ -2951,6 +2933,26 @@ def func(input_holder):
29512933
for input_val in input_vals:
29522934
self._run_test_case(func, [_OUTPUT], {_INPUT: input_val})
29532935

2936+
@check_opset_min_version(11, "GRU")
2937+
def test_cudnngru(self):
2938+
def func():
2939+
seq_length = 3
2940+
batch_size = 5
2941+
input_size = 2
2942+
num_layers = 2
2943+
num_units = 2
2944+
num_dirs = 2
2945+
initializer = init_ops.constant_initializer(0.5)
2946+
x = np.random.randint(0, 100, [seq_length, batch_size, input_size]).astype(np.float32)
2947+
h = np.random.randint(0, 100, [num_layers * num_dirs, batch_size, num_units]).astype(np.float32).reshape(
2948+
[num_layers * num_dirs, batch_size, num_units])
2949+
cudnngru = tf.contrib.cudnn_rnn.CudnnGRU(num_layers, num_units, 'linear_input', 'bidirectional',
2950+
kernel_initializer=initializer, bias_initializer=initializer)
2951+
cudnngru.build([seq_length, batch_size, input_size])
2952+
outputs = cudnngru.call(x, tuple([h]))
2953+
_ = tf.identity(outputs[0], name=_TFOUTPUT)
2954+
self.run_test_case(func, {}, [], [_OUTPUT], rtol=1e-05, atol=1e-04)
2955+
29542956

29552957
if __name__ == '__main__':
29562958
unittest_main()

tf2onnx/onnx_opset/rnn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,14 +233,14 @@ def NM(nm):
233233
suffix = '_' + str(i*num_dirs)
234234
ctx.make_node('GRU', [XNF, NM('W' + suffix), NM('R' + suffix), NM('B' + suffix), '', NM('H'+ suffix)],
235235
outputs=[NM('Y' + suffix), NM('YH' + suffix)],
236-
attr={'direction':'forward', 'hidden_size':num_units})
236+
attr={'direction': 'forward', 'hidden_size': num_units})
237237
XNF = NM(X + suffix)
238238
ctx.make_node('Squeeze', [NM('Y' + suffix)], outputs=[XNF], attr={'axes': [1]})
239239
if num_dirs == 2:
240240
suffix = '_' + str(i*2+1)
241241
ctx.make_node('GRU', [XNB, NM('W' + suffix), NM('R' + suffix), NM('B' + suffix), '', NM('H'+ suffix)],
242242
outputs=[NM('Y' + suffix), NM('YH' + suffix)],
243-
attr={'direction':'reverse', 'hidden_size':num_units})
243+
attr={'direction': 'reverse', 'hidden_size': num_units})
244244
XNB = NM(X + suffix)
245245
ctx.make_node('Squeeze', [NM('Y' + suffix)], outputs=[XNB], attr={'axes': [1]})
246246
ctx.remove_node(node.name)

Comments (0) — this commit has no comments.