Skip to content

Commit ceb6eb7

Browse files
authored
Merge pull request #483 from nbcsm/rs10
opset10 - ReverseSequence
2 parents dd3a410 + fda2265 commit ceb6eb7

File tree

5 files changed

+136
-175
lines changed

5 files changed

+136
-175
lines changed

tests/test_backend.py

Lines changed: 23 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,6 @@ def test_equal(self):
724724
_ = tf.identity(mi, name=_TFOUTPUT)
725725
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
726726

727-
728727
def test_sequeeze_no_axis_specified(self):
729728
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 2, 1))
730729
x = tf.placeholder(tf.float32, [2, 2, 1], name=_TFINPUT)
@@ -982,14 +981,14 @@ def test_slice1(self):
982981
self._run_test_case([_OUTPUT], {_INPUT: x_val})
983982

984983
def test_split(self):
985-
x_val = np.linspace(1.0, 5 * 30.0, 5 * 30).astype(np.float32).reshape(5, 30)
984+
x_val = np.linspace(1.0, 5 * 30.0, 5 * 30).astype(np.float32).reshape((5, 30))
986985
x0 = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
987986
x_, _, _ = tf.split(x0, [4, 15, 11], 1)
988987
_ = tf.identity(x_, name=_TFOUTPUT)
989988
self._run_test_case([_OUTPUT], {_INPUT: x_val})
990989

991990
def test_split_with_more_outputs(self):
992-
x_val = np.linspace(1.0, 5 * 30.0, 5 * 30).astype(np.float32).reshape(5, 30)
991+
x_val = np.linspace(1.0, 5 * 30.0, 5 * 30).astype(np.float32).reshape((5, 30))
993992
x0 = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
994993
_, _, _ = tf.split(x0, [4, 15, 11], 1, name="split_test")
995994
self._run_test_case(["split_test:0", "split_test:1", "split_test:2"], {_INPUT: x_val})
@@ -1418,36 +1417,36 @@ def test_addn(self):
14181417

14191418
@skip_caffe2_backend("multiple dims not supported")
14201419
def test_strided_slice1(self):
1421-
x_val = np.arange(3 * 2 * 3).astype("float32").reshape(3, 2, 3)
1420+
x_val = np.arange(3 * 2 * 3).astype("float32").reshape((3, 2, 3))
14221421
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
14231422
x_ = tf.strided_slice(x, [1, 0, 0], [2, 1, 3], [1, 1, 1])
14241423
_ = tf.identity(x_, name=_TFOUTPUT)
14251424
self._run_test_case([_OUTPUT], {_INPUT: x_val})
14261425

14271426
def test_strided_slice2(self):
1428-
x_val = np.arange(3 * 2 * 3).astype("float32").reshape(3, 2, 3)
1427+
x_val = np.arange(3 * 2 * 3).astype("float32").reshape((3, 2, 3))
14291428
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
14301429
x_ = tf.strided_slice(x, [1, 0, 0], [2, 2, 3], [1, 1, 1])
14311430
_ = tf.identity(x_, name=_TFOUTPUT)
14321431
self._run_test_case([_OUTPUT], {_INPUT: x_val})
14331432

14341433
def test_strided_slice3(self):
1435-
x_val = np.arange(3 * 2 * 3).astype("float32").reshape(3, 2, 3)
1434+
x_val = np.arange(3 * 2 * 3).astype("float32").reshape((3, 2, 3))
14361435
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
14371436
x_ = x[1:]
14381437
_ = tf.identity(x_, name=_TFOUTPUT)
14391438
self._run_test_case([_OUTPUT], {_INPUT: x_val})
14401439

14411440
def test_strided_slice4(self):
1442-
x_val = np.arange(3 * 2 * 3).astype("float32").reshape(3, 2, 3)
1441+
x_val = np.arange(3 * 2 * 3).astype("float32").reshape((3, 2, 3))
14431442
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
14441443
x_ = x[:2]
14451444
_ = tf.identity(x_, name=_TFOUTPUT)
14461445
self._run_test_case([_OUTPUT], {_INPUT: x_val})
14471446

14481447
@skip_caffe2_backend("multiple dims not supported")
14491448
def test_strided_slice5(self):
1450-
x_val = np.arange(3 * 2 * 3).astype("float32").reshape(3, 2, 3)
1449+
x_val = np.arange(3 * 2 * 3).astype("float32").reshape((3, 2, 3))
14511450
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
14521451
x_ = x[:2, 0:1, 1:]
14531452
_ = tf.identity(x_, name=_TFOUTPUT)
@@ -1457,15 +1456,15 @@ def test_strided_slice5(self):
14571456
def test_strided_slice6(self):
14581457
# example from here:
14591458
# https://www.tensorflow.org/versions/r1.0/api_docs/cc/class/tensorflow/ops/strided-slice
1460-
x_val = np.arange(5 * 6).astype("float32").reshape(5, 6)
1459+
x_val = np.arange(5 * 6).astype("float32").reshape((5, 6))
14611460
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
14621461
x_ = x[2, :]
14631462
_ = tf.identity(x_, name=_TFOUTPUT)
14641463
self._run_test_case([_OUTPUT], {_INPUT: x_val})
14651464

14661465
@skip_caffe2_backend("multiple dims not supported")
14671466
def test_strided_slice7(self):
1468-
x_val = np.arange(5 * 6).astype("float32").reshape(5, 6)
1467+
x_val = np.arange(5 * 6).astype("float32").reshape((5, 6))
14691468

14701469
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
14711470
x_ = tf.strided_slice(x, [0, 1], [3, 4], [1, 1], begin_mask=2)
@@ -1684,19 +1683,17 @@ def test_erf(self):
16841683
_ = tf.identity(x_, name=_TFOUTPUT)
16851684
self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=0.01)
16861685

1687-
def _test_reverse_sequence_batch_major(self, extra_opset=None):
1688-
process_args = {}
1689-
if extra_opset is not None:
1690-
process_args["extra_opset"] = [extra_opset]
1691-
1686+
@check_opset_min_version(8, "Scan")
1687+
@skip_opset(9, "ReverseSequence")
1688+
def test_reverse_sequence_batch_major(self):
16921689
x_val = np.array([[[1, 2, 3], [4, 5, 6], [0, 0, 0]],
16931690
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
16941691
[[1, 2, 3], [0, 0, 0], [0, 0, 0]]],
16951692
dtype=np.float32)
16961693
x = tf.placeholder(tf.float32, [None, 3, 3], name=_TFINPUT)
16971694
x_ = tf.reverse_sequence(x, seq_axis=1, batch_axis=0, seq_lengths=[2, 3, 1])
16981695
_ = tf.identity(x_, name=_TFOUTPUT)
1699-
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
1696+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
17001697
tf.reset_default_graph()
17011698

17021699
x_val = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3],
@@ -1707,70 +1704,45 @@ def _test_reverse_sequence_batch_major(self, extra_opset=None):
17071704
x = tf.placeholder(tf.float32, [None, 3], name=_TFINPUT)
17081705
x_ = tf.reverse_sequence(x, seq_axis=1, batch_axis=0, seq_lengths=[3] * 9)
17091706
_ = tf.identity(x_, name=_TFOUTPUT)
1710-
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
1707+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
17111708
tf.reset_default_graph()
17121709

17131710
x_val_shape = [5, 5, 7, 8, 9]
17141711
x_val = np.random.randint(0, 100, x_val_shape).astype(np.float32)
17151712
x = tf.placeholder(tf.float32, [None, 5, 7, 8, 9], name=_TFINPUT)
17161713
x_ = tf.reverse_sequence(x, seq_axis=1, batch_axis=0, seq_lengths=[5, 5, 5, 5, 5])
17171714
_ = tf.identity(x_, name=_TFOUTPUT)
1718-
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
1719-
1720-
def _test_reverse_sequence_time_major(self, extra_opset=None):
1721-
process_args = {}
1722-
if extra_opset is not None:
1723-
process_args["extra_opset"] = [extra_opset]
1715+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
17241716

1717+
@check_opset_min_version(8, "Scan")
1718+
@skip_opset(9, "ReverseSequence")
1719+
def test_reverse_sequence_time_major(self):
17251720
x_val = np.array([[[1, 2, 3], [1, 2, 3], [1, 2, 3]],
17261721
[[4, 5, 6], [4, 5, 6], [0, 0, 0]],
1727-
[[0, 0, 0], [7, 8, 9], [0, 0, 0]]
1728-
],
1722+
[[0, 0, 0], [7, 8, 9], [0, 0, 0]]],
17291723
dtype=np.float32)
17301724
x = tf.placeholder(tf.float32, [3, None, 3], name=_TFINPUT)
17311725
x_ = tf.reverse_sequence(x, seq_axis=0, batch_axis=1, seq_lengths=[2, 3, 1])
17321726
_ = tf.identity(x_, name=_TFOUTPUT)
1733-
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
1727+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
17341728
tf.reset_default_graph()
17351729

17361730
x_val = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3],
17371731
[4, 5, 6], [4, 5, 6], [1, 1, 1],
1738-
[0, 0, 0], [7, 8, 9], [0, 0, 0]
1739-
],
1732+
[0, 0, 0], [7, 8, 9], [0, 0, 0]],
17401733
dtype=np.float32)
17411734
x = tf.placeholder(tf.float32, [9, None], name=_TFINPUT)
17421735
x_ = tf.reverse_sequence(x, seq_axis=0, batch_axis=1, seq_lengths=[9, 9, 9])
17431736
_ = tf.identity(x_, name=_TFOUTPUT)
1744-
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
1737+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
17451738
tf.reset_default_graph()
17461739

17471740
x_val_shape = [5, 5, 7, 8, 9]
17481741
x_val = np.random.randint(0, 100, x_val_shape).astype(np.float32)
17491742
x = tf.placeholder(tf.float32, [5, None, 7, 8, 9], name=_TFINPUT)
17501743
x_ = tf.reverse_sequence(x, seq_axis=0, batch_axis=1, seq_lengths=[5, 5, 5, 5, 5])
17511744
_ = tf.identity(x_, name=_TFOUTPUT)
1752-
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
1753-
1754-
@check_opset_min_version(8, "Scan")
1755-
@skip_opset(9, "ReverseSequence")
1756-
def test_reverse_sequence_batch_major(self):
1757-
self._test_reverse_sequence_batch_major()
1758-
1759-
@check_opset_min_version(8, "Scan")
1760-
@skip_opset(9, "ReverseSequence")
1761-
def test_reverse_sequence_time_major(self):
1762-
self._test_reverse_sequence_time_major()
1763-
1764-
# only support onnxruntime with version larger than 0.4.0
1765-
@test_ms_domain()
1766-
@check_onnxruntime_min_version("0.4.0")
1767-
def test_ms_reverse_sequence_batch_major(self, extra_opset):
1768-
self._test_reverse_sequence_batch_major(extra_opset)
1769-
1770-
@test_ms_domain()
1771-
@check_onnxruntime_min_version("0.4.0")
1772-
def test_ms_reverse_sequence_time_major(self, extra_opset):
1773-
self._test_reverse_sequence_time_major(extra_opset)
1745+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
17741746

17751747
@check_opset_min_version(8, "where")
17761748
def test_where(self):

tests/test_custom_rnncell.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -416,11 +416,11 @@ def call(self, inputs, state):
416416
input_dim = inputs.get_shape()[-1]
417417
assert input_dim is not None, "input dimension must be defined"
418418
# W = tf.get_variable(name="W", shape=[input_dim, 3 * self._num_units], dtype=tf.float32)
419-
W = np.arange(30.0, dtype=np.float32).reshape(2, 15)
419+
W = np.arange(30.0, dtype=np.float32).reshape((2, 15))
420420
# U = tf.get_variable(name='U', shape=[self._num_units, 3 * self._num_units], dtype=tf.float32)
421-
U = np.arange(75.0, dtype=np.float32).reshape(5, 15)
421+
U = np.arange(75.0, dtype=np.float32).reshape((5, 15))
422422
# b = tf.get_variable(name='b', shape=[1, 3 * self._num_units], dtype=tf.float32)
423-
b = np.arange(15.0, dtype=np.float32).reshape(1, 15)
423+
b = np.arange(15.0, dtype=np.float32).reshape((1, 15))
424424

425425
xw = tf.split(tf.matmul(inputs, W) + b, 3, 1)
426426
hu = tf.split(tf.matmul(state, U), 3, 1)

tf2onnx/custom_opsets/ms.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -36,41 +36,3 @@ def version_1(cls, ctx, node, **kwargs):
3636
utils.make_sure(dtype is not None, "Tidx of %s is None", node.name)
3737
ctx.remove_node(node.name)
3838
make_range(ctx, node.input[0], node.input[1], node.input[2], node.output[0], node.name, shape, dtype)
39-
40-
41-
@tf_op("ReverseSequence", domain=constants.MICROSOFT_DOMAIN)
42-
class ReverseSequence:
43-
@classmethod
44-
def version_1(cls, ctx, node, **kwargs):
45-
"""ReverseSequence"""
46-
# T output = ReverseSequence(T input, int32|int64 seq_lengths, @int seq_dim, @int batch_dim)
47-
# T output = ReverseSequence(T input, int32 seqence_lens, @int time_axis, @int batch_axis)
48-
seq_dim = node.get_attr("seq_dim")
49-
utils.make_sure(seq_dim is not None, "sequence dim must be given in {}".format(node.name))
50-
seq_dim = seq_dim.i
51-
batch_dim = node.get_attr("batch_dim")
52-
if batch_dim is not None:
53-
batch_dim = batch_dim.i
54-
else:
55-
batch_dim = 0
56-
57-
output_dtypes = node.output_dtypes
58-
output_shapes = node.output_shapes
59-
ctx.remove_node(node.name)
60-
node = ctx.make_node(
61-
"ReverseSequence",
62-
node.input,
63-
outputs=node.output,
64-
shapes=output_shapes,
65-
dtypes=output_dtypes,
66-
domain=constants.MICROSOFT_DOMAIN,
67-
attr={"time_axis": seq_dim, "batch_axis": batch_dim}
68-
)
69-
70-
seq_len_dtype = ctx.get_dtype(node.input[1])
71-
utils.make_sure(seq_len_dtype is not None, "dtype of {} is None".format(node.input[1]))
72-
target_dtype = TensorProto.INT32
73-
if seq_len_dtype != target_dtype:
74-
cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=target_dtype)
75-
ctx.copy_shape(cast_node.input[0], cast_node.output[0])
76-
ctx.set_dtype(cast_node.output[0], target_dtype)

tf2onnx/onnx_opset/controlflow.py

Lines changed: 1 addition & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,14 @@
1313

1414
import numpy as np
1515

16-
from onnx import onnx_pb
1716
from onnx.onnx_pb import TensorProto
1817
from tf2onnx import utils
19-
from tf2onnx.onnx_opset.nn import spatial_map
2018
from tf2onnx.handler import tf_op
2119
from tf2onnx.utils import make_sure
2220

23-
2421
logger = logging.getLogger(__name__)
2522

23+
2624
# pylint: disable=unused-argument,missing-docstring
2725

2826
def get_inputs_for_current_iteration(g, input_id, iter_index):
@@ -241,78 +239,6 @@ def version_7(cls, ctx, node, **kwargs):
241239
pass
242240

243241

244-
@tf_op("ReverseSequence")
245-
class ReverseSequence:
246-
@classmethod
247-
def version_8(cls, ctx, node, **kwargs):
248-
# T output = ReverseSequence(T input, int32|int64 seq_lengths, @int seq_dim, @int batch_dim)
249-
# T output = Scan(int64 sequence_lens, variadic initial_state_and_scan_inputs, @graph body,
250-
# @ints directions,@int num_scan_inputs)
251-
seq_dim = node.get_attr("seq_dim")
252-
batch_dim = node.get_attr("batch_dim")
253-
batch_major = seq_dim.i == 1 and (batch_dim or batch_dim.i == 0)
254-
time_major = batch_dim.i == 1 and (seq_dim or seq_dim.i == 0)
255-
perm_val = None
256-
257-
if not batch_major and not time_major:
258-
error_msg = "unsupported attributes, seq_dim:{}, batch_dim:{}".format(seq_dim, batch_dim)
259-
raise ValueError(error_msg)
260-
261-
if time_major:
262-
old_shape = ctx.get_shape(node.input[0])
263-
old_dtype = ctx.get_dtype(node.input[0])
264-
perm_val = [1, 0]
265-
rank = len(old_shape)
266-
utils.make_sure(rank >= 2, "rank of reverse_sequence input {} is at least 2".format(node.input[0]))
267-
perm_val += list(range(2, rank))
268-
trans_node = ctx.insert_new_node_on_input(node, "Transpose", node.input[0], perm=perm_val)
269-
new_shape = spatial_map(old_shape, perm_val)
270-
ctx.set_shape(trans_node.output[0], new_shape)
271-
ctx.set_dtype(trans_node.output[0], old_dtype)
272-
273-
# handle batch_major input
274-
node.type = "Scan"
275-
node.set_attr("num_scan_inputs", 1)
276-
input_dtype = ctx.get_dtype(node.input[0])
277-
input_shape = ctx.get_shape(node.input[0])
278-
279-
g = ctx.create_new_graph_with_same_config()
280-
g.parent_graph = ctx
281-
g.add_graph_input('X', input_dtype, input_shape[2:])
282-
g.make_node('Identity', ['X'], outputs=['Y'])
283-
g.add_graph_output('Y', input_dtype, input_shape[2:])
284-
285-
node.set_body_graph_as_attr("body", g)
286-
node.set_attr("directions", [1]) # reverse the scan input
287-
288-
seq_len_dtype = ctx.get_dtype(node.input[1])
289-
if seq_len_dtype != onnx_pb.TensorProto.INT64:
290-
cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[1])
291-
cast_node.set_attr("to", onnx_pb.TensorProto.INT64)
292-
ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.INT64)
293-
ctx.copy_shape(node.input[1], cast_node.output[0])
294-
295-
if time_major:
296-
# get back to time_major
297-
op_name = utils.make_name(node.name)
298-
trans_back_node = ctx.insert_new_node_on_output("Transpose", node.output[0],
299-
name=op_name, perm=perm_val)
300-
ctx.copy_dtype(node.output[0], trans_back_node.output[0])
301-
302-
tmp = node.input[0]
303-
node.input[0] = node.input[1]
304-
node.input[1] = tmp
305-
306-
@classmethod
307-
def version_9(cls, ctx, node, **kwargs):
308-
# T output = ReverseSequence(T input, int32|int64 seq_lengths, @int seq_dim, @int batch_dim)
309-
# we cannot easily construct reverse_sequence equivalence in opset 9, so we will not support it
310-
# here. Actually using loops to do that is kind of meaningless since there will be performance
311-
# issue there for sure.
312-
raise NotImplementedError("ReverseSequence is not supported to convert in OPSET 9,"
313-
" if possible please try using OPSET 8 instead.")
314-
315-
316242
@tf_op("Range")
317243
class Range:
318244
@classmethod

0 commit comments

Comments
 (0)