
Commit e15c179

add reverse support for dynamic axis
1 parent e944207 commit e15c179

3 files changed: +71 -41 lines changed

tests/test_backend.py

Lines changed: 19 additions & 0 deletions

@@ -20,6 +20,7 @@
 from common import * # pylint: disable=wildcard-import,unused-wildcard-import
 from tf2onnx import constants, utils
 from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
+from tensorflow.python.ops import init_ops

 # pylint: disable=missing-docstring,invalid-name,unused-argument

@@ -3010,6 +3011,24 @@ def test_hashtable_lookup(self):
         self._run_test_case([lookup_results.name], {_INPUT: query})
         os.remove(filnm)

+    @check_opset_min_version(11, "GRU")
+    def test_cudnngru(self):
+        seq_length = 3
+        batch_size = 5
+        input_size = 2
+        num_layers = 2
+        num_units = 2
+        num_dirs = 2
+        initializer = init_ops.constant_initializer(0.5)
+        x = np.random.randint(0, 100, [seq_length, batch_size, input_size]).astype(np.float32)
+        h = np.random.randint(0, 100, [num_layers * num_dirs, batch_size, num_units]).astype(np.float32).reshape(
+            [num_layers * num_dirs, batch_size, num_units])
+        cudnngru = tf.contrib.cudnn_rnn.CudnnGRU(num_layers, num_units, 'linear_input', 'bidirectional',
+                                                 kernel_initializer=initializer, bias_initializer=initializer)
+        cudnngru.build([seq_length, batch_size, input_size])
+        outputs = cudnngru.call(x, tuple([h]))
+        self.run_test_case({}, [], [outputs[0].name], rtol=1e-05, atol=1e-04)
+

 if __name__ == '__main__':
     unittest_main()
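
The new test_cudnngru case builds a 2-layer bidirectional CudnnGRU with constant initializers, runs it on a random [seq_length, batch_size, input_size] input, and compares the first output of the converted ONNX graph against TensorFlow. The sketch below is not part of the commit; it assumes TF 1.x with tf.contrib available and a CUDA-capable GPU (the CudnnRNN kernels are GPU-only) and just reproduces the TensorFlow side to obtain the reference output the test compares against.

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import init_ops

# Same configuration as test_cudnngru above.
seq_length, batch_size, input_size = 3, 5, 2
num_layers, num_units, num_dirs = 2, 2, 2

initializer = init_ops.constant_initializer(0.5)
x = np.random.randint(0, 100, [seq_length, batch_size, input_size]).astype(np.float32)
h = np.zeros([num_layers * num_dirs, batch_size, num_units], dtype=np.float32)

cudnngru = tf.contrib.cudnn_rnn.CudnnGRU(num_layers, num_units, 'linear_input', 'bidirectional',
                                         kernel_initializer=initializer, bias_initializer=initializer)
cudnngru.build([seq_length, batch_size, input_size])
outputs, _states = cudnngru.call(x, (h,))   # graph-mode tensors; the CudnnRNN op lands in the GraphDef

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ref = sess.run(outputs)
    print(ref.shape)  # (seq_length, batch_size, num_dirs * num_units) = (3, 5, 4)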

tf2onnx/onnx_opset/rnn.py

Lines changed: 23 additions & 24 deletions

@@ -175,12 +175,6 @@ def version_7(cls, ctx, node, **kwargs):
 class CudnnRNN:
     @classmethod
     def version_11(cls, ctx, node, **kwargs):
-        #print ("CudnnRNN captured")
-        #print (node.attr["direction"].s)
-        #ii = input(os.getpid())
-        #print ("input\n", node.input)
-        #print ("\noutput\n", node.output)
-        #print ("\nattr\n", node.attr)
         X = node.input[0]
         X_shape = ctx.get_shape(X)
         H = node.input[1]
@@ -190,6 +184,14 @@ def version_11(cls, ctx, node, **kwargs):
             node.attr["rnn_mode"].s == b"gru",
             "rnn mode other than gru are not supported yet"
         )
+        utils.make_sure(
+            node.attr["dropout"].f == 0,
+            "dropout not supported yet"
+        )
+        utils.make_sure(
+            node.attr["input_mode"].s == b"linear_input",
+            "input mode must be linear input"
+        )
         num_dirs = 1 if node.attr["direction"].s == b"unidirectional" else 2
         num_layers = int(H_shape[0]/num_dirs)
         num_units = hidden_size = H_shape[2]
@@ -217,26 +219,24 @@ def NM(nm):
         W_flattened = ctx.make_node('Slice', [P, zero_const.output[0], w_end_const.output[0]])
         R_flattened = ctx.make_node('Slice', [P, w_end_const.output[0], r_end_const.output[0]])
         B_flattened = ctx.make_node('Slice', [P, r_end_const.output[0], b_end_const.output[0]])
-        #W = utils.make_name('W')
-        #R = utils.make_name('R')
-        #B = utils.make_name('B')
-        W = ctx.make_node('Reshape', [W_flattened.output[0], w_shape_const.output[0]])
-        R = ctx.make_node('Reshape', [R_flattened.output[0], r_shape_const.output[0]])
-        B = ctx.make_node('Reshape', [B_flattened.output[0], b_shape_const.output[0]])
-        ctx.make_node('Split', [W.output[0]], outputs = WS)
-        ctx.make_node('Split', [R.output[0]], outputs = RS)
-        ctx.make_node('Split', [B.output[0]], outputs = BS)
+        W = utils.make_name('W')
+        R = utils.make_name('R')
+        B = utils.make_name('B')
+        ctx.make_node('Reshape', [W_flattened.output[0], w_shape_const.output[0]], outputs=[W])
+        ctx.make_node('Reshape', [R_flattened.output[0], r_shape_const.output[0]], outputs=[R])
+        ctx.make_node('Reshape', [B_flattened.output[0], b_shape_const.output[0]], outputs=[B])
+        ctx.make_node('Split', [W], outputs = WS)
+        ctx.make_node('Split', [R], outputs = RS)
+        ctx.make_node('Split', [B], outputs = BS)
         ctx.make_node('Split', [H], outputs = HS)
         XNF = XNB = X
-        gru_nodes = []
-        squeeze_nodes = []
         for i in range(num_layers):
-            suffix = '_' + str(i*2)
-            gru_nodes.append(ctx.make_node('GRU', [XNF, NM('W' + suffix), NM('R' + suffix), NM('B' + suffix), '', NM('H'+ suffix)],
-                                           outputs = [NM('Y' + suffix), NM('YH' + suffix)],
-                                           attr={'direction':'forward', 'hidden_size':num_units}))
+            suffix = '_' + str(i*num_dirs)
+            ctx.make_node('GRU', [XNF, NM('W' + suffix), NM('R' + suffix), NM('B' + suffix), '', NM('H'+ suffix)],
+                          outputs = [NM('Y' + suffix), NM('YH' + suffix)],
+                          attr={'direction':'forward', 'hidden_size':num_units})
             XNF = NM(X + suffix)
-            squeeze_nodes.append(ctx.make_node('Squeeze', [NM('Y' + suffix)], outputs = [XNF], attr={'axes': [1]}))
+            ctx.make_node('Squeeze', [NM('Y' + suffix)], outputs = [XNF], attr={'axes': [1]})
             if num_dirs == 2:
                 suffix = '_' + str(i*2+1)
                 ctx.make_node('GRU', [XNB, NM('W' + suffix), NM('R' + suffix), NM('B' + suffix), '', NM('H'+ suffix)],
@@ -249,5 +249,4 @@ def NM(nm):
            ctx.make_node('Concat', [XNF, XNB], outputs = [node.output[0]], attr={'axis': -1})
         else:
            identity_0 = ctx.make_node('Identity', [XNF], outputs = [node.output[0]])
-        concat_0 = ctx.make_node('Concat', YHS, outputs = [node.output[1]], attr={'axis': 0})
-        #print ("Done")
+        concat_0 = ctx.make_node('Concat', YHS, outputs = [node.output[1]], attr={'axis': 0})
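
A note on the loop above: the suffix for the forward-direction GRU of layer i changed from '_' + str(i*2) to '_' + str(i*num_dirs), so a unidirectional stack (num_dirs == 1) now addresses consecutive entries of the split weight, recurrent-weight, bias and state tensors rather than skipping every other name; the backward direction (only taken when num_dirs == 2) keeps i*2+1, which equals i*num_dirs+1 in that case. A small hypothetical helper, not part of the commit, that prints the naming order the loop produces:

def gru_suffixes(num_layers, num_dirs):
    """Order of the '_<n>' suffixes attached to the per-layer GRU weights and states."""
    order = []
    for i in range(num_layers):
        order.append('_' + str(i * num_dirs))     # forward GRU of layer i
        if num_dirs == 2:
            order.append('_' + str(i * 2 + 1))    # backward GRU of layer i (i*2+1 == i*num_dirs+1 here)
    return order

print(gru_suffixes(2, 2))  # ['_0', '_1', '_2', '_3']
print(gru_suffixes(2, 1))  # ['_0', '_1']  (the old i*2 indexing would have produced '_0', '_2')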

tf2onnx/onnx_opset/tensor.py

Lines changed: 29 additions & 17 deletions

@@ -1647,23 +1647,35 @@ def version_10(cls, ctx, node, **kwargs):
             inputs = [new_node.output[0]]

             # Add a Constant node (seq_len) for ReverseSequence.
-
-            # Index 1 for the shape should not return 0
-            # since the input must have rank >= 2.
-            rs_batch_size = ctx.get_shape(inputs[-1])[1]
-
-            # Make sure rs_batch_size and input_shape[axis] are not -1 each
-            utils.make_sure(input_shape[axis] is not -1 \
-                , "shape of axis {} is unknown".format(axis))
-            utils.make_sure(rs_batch_size is not -1 \
-                , "ReverseSequence batch size for axis {} is unknown".format(axis))
-
-            seq_list = [input_shape[axis]] * rs_batch_size
-            seq_array = np.asarray(seq_list, dtype=np.int64) # dtype should be int64
-
-            const_seq_name = utils.make_name(const_name_root)
-            new_node = ctx.make_const(name=const_seq_name, np_val=seq_array)
-            inputs.append(new_node.output[0])
+            if ctx.opset >= 11:
+                batch_shape = ctx.make_node("Shape", [inputs[-1]])
+                const_one = ctx.make_const(utils.make_name(node.name + "_const_one"), np.array([1], dtype=np.int64))
+                const_two = ctx.make_const(utils.make_name(node.name + "_const_two"), np.array([2], dtype=np.int64))
+                batch_size = ctx.make_node("Slice",
+                                           [batch_shape.output[0], const_one.output[0], const_two.output[0]])
+                input_shape = ctx.make_node("Shape", [node.input[0]])
+                const_axis = ctx.make_const(utils.make_name(node.name + "_const_axis"),
+                                            np.array([axis], dtype=np.int64))
+                const_axis_next = ctx.make_const(utils.make_name(node.name + "_const_axis_next"),
+                                                 np.array([axis + 1], dtype=np.int64))
+                input_axis = ctx.make_node("Slice",
+                                           [input_shape.output[0], const_axis.output[0], const_axis_next.output[0]])
+                seq_array = ctx.make_node("Expand", [input_axis.output[0], batch_size.output[0]])
+                inputs.append(seq_array.output[0])
+            else:
+                # Index 1 for the shape should not return 0
+                # since the input must have rank >= 2.
+                rs_batch_size = ctx.get_shape(inputs[-1])[1]
+                # Make sure rs_batch_size and input_shape[axis] are not -1 each
+                utils.make_sure(input_shape[axis] is not -1 \
+                    , "shape of axis {} is unknown".format(axis))
+                utils.make_sure(rs_batch_size is not -1 \
+                    , "ReverseSequence batch size for axis {} is unknown".format(axis))
+                seq_list = [input_shape[axis]] * rs_batch_size
+                seq_array = np.asarray(seq_list, dtype=np.int64) # dtype should be int64
+                const_seq_name = utils.make_name(const_name_root)
+                new_node = ctx.make_const(name=const_seq_name, np_val=seq_array)
+                inputs.append(new_node.output[0])

             # Add a ReverseSequence node.
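
This hunk is the "reverse support for dynamic axis" the commit message describes: from opset 11 on, the sequence-length input of ReverseSequence is no longer a constant precomputed from statically known shapes but is built at runtime out of Shape, Slice and Expand nodes, so the reversed axis and the batch dimension may be unknown (-1) at conversion time. Below is a rough numpy sketch, illustrative only, of the value that inserted subgraph evaluates to, using the same conventions as the code above (the batch is dimension 1 of the ReverseSequence input):

import numpy as np

def dynamic_seq_lens(rs_input, op_input, axis):
    """Runtime equivalent of the Shape -> Slice -> Expand subgraph emitted for opset >= 11."""
    batch = rs_input.shape[1]            # Shape(rs_input) sliced at [1:2]
    axis_len = op_input.shape[axis]      # Shape(op_input) sliced at [axis:axis+1]
    # Expand broadcasts the single axis length to one entry per batch element.
    return np.full([batch], axis_len, dtype=np.int64)

x = np.zeros((4, 3, 5), dtype=np.float32)    # hypothetical input; axis 0 is the one being reversed
print(dynamic_seq_lens(x, x, axis=0))        # [4 4 4] -> each batch entry reverses the full axis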
