
Commit c5e560d
Merge pull request #819 from RandySheriffH/rashuai/CudnnGRU
2 parents: e467d58 + 92374ac
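
This merge adds initial CudnnGRU support: tf2onnx/onnx_opset/rnn.py gains a CudnnRNN handler that rewrites TensorFlow's fused GRU-mode CudnnRNN op into per-layer ONNX GRU nodes, slicing cuDNN's flat parameter blob into W/R/B tensors; tests/test_cudnn.py adds a GPU-only unit test for a two-layer bidirectional CudnnGRU; and tf2onnx/onnx_opset/tensor.py lets the seq_len input of ReverseSequence be built dynamically from runtime shapes at opset 11 and above.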

3 files changed (+165 −18 lines)


tests/test_cudnn.py (new file: 51 additions, 0 deletions)

```python
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

"""Unit Tests for cudnn."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np
import tensorflow as tf

from tensorflow.python.ops import init_ops
from backend_test_base import Tf2OnnxBackendTestBase
from common import check_tf_max_version, skip_tf_cpu, check_opset_min_version, unittest_main


class CudnnTests(Tf2OnnxBackendTestBase):
    """ test cudnn cases """
    @check_tf_max_version("1.15.0", "not supported in tf-2.0")
    @skip_tf_cpu("only tf_gpu can run CudnnGPU")
    @check_opset_min_version(10, "CudnnGRU")
    def test_cudnngru(self):
        """ test contrib cudnn gru """
        seq_length = 3
        batch_size = 5
        input_size = 2
        num_layers = 2
        num_units = 2
        num_dirs = 2
        x_val = np.random.randint(0, 100, [seq_length, batch_size, input_size]).astype(np.float32)
        h_val = np.random.randint(0, 100, [num_layers * num_dirs, batch_size, num_units]).astype(np.float32).reshape(
            [num_layers * num_dirs, batch_size, num_units])

        def func(x, h):
            initializer = init_ops.constant_initializer(0.5)
            cudnngru = tf.contrib.cudnn_rnn.CudnnGRU(num_layers, num_units, 'linear_input', 'bidirectional',
                                                     kernel_initializer=initializer, bias_initializer=initializer)
            cudnngru.build([seq_length, batch_size, input_size])
            outputs = cudnngru.call(x, tuple([h]))
            _ = tf.identity(outputs[0], name='output')

        feed_dict = {"input_1:0": x_val, "input_2:0": h_val}
        input_names_with_port = ["input_1:0", "input_2:0"]
        output_names_with_port = ["output:0"]
        self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-05, atol=1e-04)


if __name__ == '__main__':
    unittest_main()
```
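
For orientation, here is a small sketch (not part of the commit) of the tensor shapes the test feeds and expects, assuming cuDNN's time-major layout and the usual concatenation of the two directions in the output:

```python
# Sketch only: shapes implied by the test's constants.
seq_length, batch_size, input_size = 3, 5, 2
num_layers, num_units, num_dirs = 2, 2, 2

x_shape = (seq_length, batch_size, input_size)            # "input_1:0", time-major input
h_shape = (num_layers * num_dirs, batch_size, num_units)  # "input_2:0", initial hidden state
y_shape = (seq_length, batch_size, num_dirs * num_units)  # "output:0", both directions concatenated

print(x_shape, h_shape, y_shape)  # (3, 5, 2) (4, 5, 2) (3, 5, 4)
```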

tf2onnx/onnx_opset/rnn.py (85 additions, 1 deletion)

```diff
@@ -10,7 +10,6 @@
 from __future__ import unicode_literals
 
 import logging
-
 import numpy as np
 from tf2onnx import utils
 from tf2onnx.handler import tf_op
@@ -169,3 +168,88 @@ def replace_output(old_output, new_output):
     @classmethod
     def version_7(cls, ctx, node, **kwargs):
         cls.version_1(ctx, node, **kwargs)
+
+
+@tf_op("CudnnRNN")
+class CudnnRNN:
+    @classmethod
+    def version_10(cls, ctx, node, **kwargs):
+        x = node.input[0]
+        x_shape = ctx.get_shape(x)
+        h = node.input[1]
+        h_shape = ctx.get_shape(h)
+        p = node.input[3]
+        utils.make_sure(
+            node.attr["rnn_mode"].s == b"gru",
+            "rnn mode other than gru are not supported yet"
+        )
+        utils.make_sure(
+            node.attr["dropout"].f == 0,
+            "dropout not supported yet"
+        )
+        utils.make_sure(
+            node.attr["input_mode"].s == b"linear_input",
+            "input mode must be linear input"
+        )
+        num_dirs = 1 if node.attr["direction"].s == b"unidirectional" else 2
+        num_layers = int(h_shape[0] / num_dirs)
+        num_units = hidden_size = h_shape[2]
+        input_size = x_shape[2]
+        w_shape = [num_layers * num_dirs, 3 * hidden_size, input_size]
+        w_shape_const = ctx.make_const(utils.make_name("w_shape"), np.array(w_shape, dtype=np.int64))
+        r_shape = [num_layers * num_dirs, 3 * hidden_size, hidden_size]
+        r_shape_const = ctx.make_const(utils.make_name("r_shape"), np.array(r_shape, dtype=np.int64))
+        b_shape = [num_layers * num_dirs, 6 * hidden_size]
+        b_shape_const = ctx.make_const(utils.make_name("b_shape"), np.array(b_shape, dtype=np.int64))
+        zero_const = ctx.make_const(utils.make_name("zero"), np.array([0], dtype=np.int64))
+        w_end = np.prod(w_shape)
+        w_end_const = ctx.make_const(utils.make_name("w_end"), np.array([w_end], dtype=np.int64))
+        r_end = w_end + np.prod(r_shape)
+        r_end_const = ctx.make_const(utils.make_name("r_end"), np.array([r_end], dtype=np.int64))
+        b_end = r_end + np.prod(b_shape)
+        b_end_const = ctx.make_const(utils.make_name("b_end"), np.array([b_end], dtype=np.int64))
+
+        def name(nm):
+            return node.name + "_" + nm
+
+        ws = [name('W_' + str(i)) for i in range(num_layers * num_dirs)]
+        rs = [name('R_' + str(i)) for i in range(num_layers * num_dirs)]
+        bs = [name('B_' + str(i)) for i in range(num_layers * num_dirs)]
+        hs = [name('H_' + str(i)) for i in range(num_layers * num_dirs)]
+        yhs = [name('YH_' + str(i)) for i in range(num_layers * num_dirs)]
+        w_flattened = ctx.make_node('Slice', [p, zero_const.output[0], w_end_const.output[0]])
+        r_flattened = ctx.make_node('Slice', [p, w_end_const.output[0], r_end_const.output[0]])
+        b_flattened = ctx.make_node('Slice', [p, r_end_const.output[0], b_end_const.output[0]])
+        w = utils.make_name('W')
+        r = utils.make_name('R')
+        b = utils.make_name('B')
+        ctx.make_node('Reshape', [w_flattened.output[0], w_shape_const.output[0]], outputs=[w])
+        ctx.make_node('Reshape', [r_flattened.output[0], r_shape_const.output[0]], outputs=[r])
+        ctx.make_node('Reshape', [b_flattened.output[0], b_shape_const.output[0]], outputs=[b])
+        ctx.make_node('Split', [w], outputs=ws)
+        ctx.make_node('Split', [r], outputs=rs)
+        ctx.make_node('Split', [b], outputs=bs)
+        ctx.make_node('Split', [h], outputs=hs)
+        xnf = xnb = x
+        for i in range(num_layers):
+            suffix = '_' + str(i * num_dirs)
+            ctx.make_node('GRU',
+                          [xnf, name('W' + suffix), name('R' + suffix), name('B' + suffix), '', name('H' + suffix)],
+                          outputs=[name('Y' + suffix), name('YH' + suffix)],
+                          attr={'direction': 'forward', 'hidden_size': num_units})
+            xnf = name(x + suffix)
+            ctx.make_node('Squeeze', [name('Y' + suffix)], outputs=[xnf], attr={'axes': [1]})
+            if num_dirs == 2:
+                suffix = '_' + str(i * 2 + 1)
+                ctx.make_node('GRU',
+                              [xnb, name('W' + suffix), name('R' + suffix), name('B' + suffix), '', name('H' + suffix)],
+                              outputs=[name('Y' + suffix), name('YH' + suffix)],
+                              attr={'direction': 'reverse', 'hidden_size': num_units})
+                xnb = name(x + suffix)
+                ctx.make_node('Squeeze', [name('Y' + suffix)], outputs=[xnb], attr={'axes': [1]})
+        ctx.remove_node(node.name)
+        if num_dirs == 2:
+            ctx.make_node('Concat', [xnf, xnb], outputs=[node.output[0]], attr={'axis': -1})
+        else:
+            ctx.make_node('Identity', [xnf], outputs=[node.output[0]])
+        ctx.make_node('Concat', yhs, outputs=[node.output[1]], attr={'axis': 0})
```
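
The new handler recovers per-layer GRU weights by slicing cuDNN's single flattened parameter buffer at offsets derived from the layer count, direction count, and sizes, then reshaping and splitting each tensor along axis 0 into one single-direction slice per GRU node. A quick sanity check of that offset arithmetic, using the dimensions from the new test (a sketch, not part of the commit):

```python
import numpy as np

# Dimensions from tests/test_cudnn.py: 2 bidirectional layers, 2 units, input size 2.
num_layers, num_dirs, hidden_size, input_size = 2, 2, 2, 2

# Shapes mirror the handler: ONNX GRU packs 3 gates into W/R and 6 bias vectors into B.
w_shape = [num_layers * num_dirs, 3 * hidden_size, input_size]
r_shape = [num_layers * num_dirs, 3 * hidden_size, hidden_size]
b_shape = [num_layers * num_dirs, 6 * hidden_size]

w_end = int(np.prod(w_shape))          # 48  -> params[0:48]  are reshaped into W
r_end = w_end + int(np.prod(r_shape))  # 96  -> params[48:96] are reshaped into R
b_end = r_end + int(np.prod(b_shape))  # 144 -> params[96:144] are reshaped into B
print(w_end, r_end, b_end)             # 48 96 144
```

Each reshaped tensor is then Split along axis 0 into num_layers * num_dirs slices of leading dimension 1, matching the [num_directions, ...] layout ONNX GRU expects.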

tf2onnx/onnx_opset/tensor.py (29 additions, 17 deletions)

```diff
@@ -1653,23 +1653,35 @@ def version_10(cls, ctx, node, **kwargs):
         inputs = [new_node.output[0]]
 
         # Add a Constant node (seq_len) for ReverseSequence.
-
-        # Index 1 for the shape should not return 0
-        # since the input must have rank >= 2.
-        rs_batch_size = ctx.get_shape(inputs[-1])[1]
-
-        # Make sure rs_batch_size and input_shape[axis] are not -1 each
-        utils.make_sure(input_shape[axis] is not -1 \
-            , "shape of axis {} is unknown".format(axis))
-        utils.make_sure(rs_batch_size is not -1 \
-            , "ReverseSequence batch size for axis {} is unknown".format(axis))
-
-        seq_list = [input_shape[axis]] * rs_batch_size
-        seq_array = np.asarray(seq_list, dtype=np.int64)  # dtype should be int64
-
-        const_seq_name = utils.make_name(const_name_root)
-        new_node = ctx.make_const(name=const_seq_name, np_val=seq_array)
-        inputs.append(new_node.output[0])
+        if ctx.opset >= 11:
+            batch_shape = ctx.make_node("Shape", [inputs[-1]])
+            const_one = ctx.make_const(utils.make_name(node.name + "_const_one"), np.array([1], dtype=np.int64))
+            const_two = ctx.make_const(utils.make_name(node.name + "_const_two"), np.array([2], dtype=np.int64))
+            batch_size = ctx.make_node("Slice",
+                                       [batch_shape.output[0], const_one.output[0], const_two.output[0]])
+            input_shape = ctx.make_node("Shape", [node.input[0]])
+            const_axis = ctx.make_const(utils.make_name(node.name + "_const_axis"),
+                                        np.array([axis], dtype=np.int64))
+            const_axis_next = ctx.make_const(utils.make_name(node.name + "_const_axis_next"),
+                                             np.array([axis + 1], dtype=np.int64))
+            input_axis = ctx.make_node("Slice",
+                                       [input_shape.output[0], const_axis.output[0], const_axis_next.output[0]])
+            seq_array = ctx.make_node("Expand", [input_axis.output[0], batch_size.output[0]])
+            inputs.append(seq_array.output[0])
+        else:
+            # Index 1 for the shape should not return 0
+            # since the input must have rank >= 2.
+            rs_batch_size = ctx.get_shape(inputs[-1])[1]
+            # Make sure rs_batch_size and input_shape[axis] are not -1 each
+            utils.make_sure(input_shape[axis] is not -1 \
+                , "shape of axis {} is unknown".format(axis))
+            utils.make_sure(rs_batch_size is not -1 \
+                , "ReverseSequence batch size for axis {} is unknown".format(axis))
+            seq_list = [input_shape[axis]] * rs_batch_size
+            seq_array = np.asarray(seq_list, dtype=np.int64)  # dtype should be int64
+            const_seq_name = utils.make_name(const_name_root)
+            new_node = ctx.make_const(name=const_seq_name, np_val=seq_array)
+            inputs.append(new_node.output[0])
 
         # Add a ReverseSequence node.
```
