Skip to content

Commit fda2265

Browse files
committed
migrate ms ReverseSequence for onnx opset 10
1 parent 53a39ad commit fda2265

File tree

4 files changed

+124
-162
lines changed

4 files changed

+124
-162
lines changed

tests/test_backend.py

Lines changed: 14 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1683,19 +1683,17 @@ def test_erf(self):
16831683
_ = tf.identity(x_, name=_TFOUTPUT)
16841684
self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=0.01)
16851685

1686-
def _test_reverse_sequence_batch_major(self, extra_opset=None):
1687-
process_args = {}
1688-
if extra_opset is not None:
1689-
process_args["extra_opset"] = [extra_opset]
1690-
1686+
@check_opset_min_version(8, "Scan")
1687+
@skip_opset(9, "ReverseSequence")
1688+
def test_reverse_sequence_batch_major(self):
16911689
x_val = np.array([[[1, 2, 3], [4, 5, 6], [0, 0, 0]],
16921690
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
16931691
[[1, 2, 3], [0, 0, 0], [0, 0, 0]]],
16941692
dtype=np.float32)
16951693
x = tf.placeholder(tf.float32, [None, 3, 3], name=_TFINPUT)
16961694
x_ = tf.reverse_sequence(x, seq_axis=1, batch_axis=0, seq_lengths=[2, 3, 1])
16971695
_ = tf.identity(x_, name=_TFOUTPUT)
1698-
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
1696+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
16991697
tf.reset_default_graph()
17001698

17011699
x_val = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3],
@@ -1706,70 +1704,45 @@ def _test_reverse_sequence_batch_major(self, extra_opset=None):
17061704
x = tf.placeholder(tf.float32, [None, 3], name=_TFINPUT)
17071705
x_ = tf.reverse_sequence(x, seq_axis=1, batch_axis=0, seq_lengths=[3] * 9)
17081706
_ = tf.identity(x_, name=_TFOUTPUT)
1709-
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
1707+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
17101708
tf.reset_default_graph()
17111709

17121710
x_val_shape = [5, 5, 7, 8, 9]
17131711
x_val = np.random.randint(0, 100, x_val_shape).astype(np.float32)
17141712
x = tf.placeholder(tf.float32, [None, 5, 7, 8, 9], name=_TFINPUT)
17151713
x_ = tf.reverse_sequence(x, seq_axis=1, batch_axis=0, seq_lengths=[5, 5, 5, 5, 5])
17161714
_ = tf.identity(x_, name=_TFOUTPUT)
1717-
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
1718-
1719-
def _test_reverse_sequence_time_major(self, extra_opset=None):
1720-
process_args = {}
1721-
if extra_opset is not None:
1722-
process_args["extra_opset"] = [extra_opset]
1715+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
17231716

1717+
@check_opset_min_version(8, "Scan")
1718+
@skip_opset(9, "ReverseSequence")
1719+
def test_reverse_sequence_time_major(self):
17241720
x_val = np.array([[[1, 2, 3], [1, 2, 3], [1, 2, 3]],
17251721
[[4, 5, 6], [4, 5, 6], [0, 0, 0]],
1726-
[[0, 0, 0], [7, 8, 9], [0, 0, 0]]
1727-
],
1722+
[[0, 0, 0], [7, 8, 9], [0, 0, 0]]],
17281723
dtype=np.float32)
17291724
x = tf.placeholder(tf.float32, [3, None, 3], name=_TFINPUT)
17301725
x_ = tf.reverse_sequence(x, seq_axis=0, batch_axis=1, seq_lengths=[2, 3, 1])
17311726
_ = tf.identity(x_, name=_TFOUTPUT)
1732-
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
1727+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
17331728
tf.reset_default_graph()
17341729

17351730
x_val = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3],
17361731
[4, 5, 6], [4, 5, 6], [1, 1, 1],
1737-
[0, 0, 0], [7, 8, 9], [0, 0, 0]
1738-
],
1732+
[0, 0, 0], [7, 8, 9], [0, 0, 0]],
17391733
dtype=np.float32)
17401734
x = tf.placeholder(tf.float32, [9, None], name=_TFINPUT)
17411735
x_ = tf.reverse_sequence(x, seq_axis=0, batch_axis=1, seq_lengths=[9, 9, 9])
17421736
_ = tf.identity(x_, name=_TFOUTPUT)
1743-
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
1737+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
17441738
tf.reset_default_graph()
17451739

17461740
x_val_shape = [5, 5, 7, 8, 9]
17471741
x_val = np.random.randint(0, 100, x_val_shape).astype(np.float32)
17481742
x = tf.placeholder(tf.float32, [5, None, 7, 8, 9], name=_TFINPUT)
17491743
x_ = tf.reverse_sequence(x, seq_axis=0, batch_axis=1, seq_lengths=[5, 5, 5, 5, 5])
17501744
_ = tf.identity(x_, name=_TFOUTPUT)
1751-
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
1752-
1753-
@check_opset_min_version(8, "Scan")
1754-
@skip_opset(9, "ReverseSequence")
1755-
def test_reverse_sequence_batch_major(self):
1756-
self._test_reverse_sequence_batch_major()
1757-
1758-
@check_opset_min_version(8, "Scan")
1759-
@skip_opset(9, "ReverseSequence")
1760-
def test_reverse_sequence_time_major(self):
1761-
self._test_reverse_sequence_time_major()
1762-
1763-
# only support onnxruntime with version larger than 0.4.0
1764-
@test_ms_domain()
1765-
@check_onnxruntime_min_version("0.4.0")
1766-
def test_ms_reverse_sequence_batch_major(self, extra_opset):
1767-
self._test_reverse_sequence_batch_major(extra_opset)
1768-
1769-
@test_ms_domain()
1770-
@check_onnxruntime_min_version("0.4.0")
1771-
def test_ms_reverse_sequence_time_major(self, extra_opset):
1772-
self._test_reverse_sequence_time_major(extra_opset)
1745+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
17731746

17741747
@check_opset_min_version(8, "where")
17751748
def test_where(self):

tf2onnx/custom_opsets/ms.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -36,41 +36,3 @@ def version_1(cls, ctx, node, **kwargs):
3636
utils.make_sure(dtype is not None, "Tidx of %s is None", node.name)
3737
ctx.remove_node(node.name)
3838
make_range(ctx, node.input[0], node.input[1], node.input[2], node.output[0], node.name, shape, dtype)
39-
40-
41-
@tf_op("ReverseSequence", domain=constants.MICROSOFT_DOMAIN)
42-
class ReverseSequence:
43-
@classmethod
44-
def version_1(cls, ctx, node, **kwargs):
45-
"""ReverseSequence"""
46-
# T output = ReverseSequence(T input, int32|int64 seq_lengths, @int seq_dim, @int batch_dim)
47-
# T output = ReverseSequence(T input, int32 seqence_lens, @int time_axis, @int batch_axis)
48-
seq_dim = node.get_attr("seq_dim")
49-
utils.make_sure(seq_dim is not None, "sequence dim must be given in {}".format(node.name))
50-
seq_dim = seq_dim.i
51-
batch_dim = node.get_attr("batch_dim")
52-
if batch_dim is not None:
53-
batch_dim = batch_dim.i
54-
else:
55-
batch_dim = 0
56-
57-
output_dtypes = node.output_dtypes
58-
output_shapes = node.output_shapes
59-
ctx.remove_node(node.name)
60-
node = ctx.make_node(
61-
"ReverseSequence",
62-
node.input,
63-
outputs=node.output,
64-
shapes=output_shapes,
65-
dtypes=output_dtypes,
66-
domain=constants.MICROSOFT_DOMAIN,
67-
attr={"time_axis": seq_dim, "batch_axis": batch_dim}
68-
)
69-
70-
seq_len_dtype = ctx.get_dtype(node.input[1])
71-
utils.make_sure(seq_len_dtype is not None, "dtype of {} is None".format(node.input[1]))
72-
target_dtype = TensorProto.INT32
73-
if seq_len_dtype != target_dtype:
74-
cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=target_dtype)
75-
ctx.copy_shape(cast_node.input[0], cast_node.output[0])
76-
ctx.set_dtype(cast_node.output[0], target_dtype)

tf2onnx/onnx_opset/controlflow.py

Lines changed: 1 addition & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,14 @@
1313

1414
import numpy as np
1515

16-
from onnx import onnx_pb
1716
from onnx.onnx_pb import TensorProto
1817
from tf2onnx import utils
19-
from tf2onnx.onnx_opset.nn import spatial_map
2018
from tf2onnx.handler import tf_op
2119
from tf2onnx.utils import make_sure
2220

23-
2421
logger = logging.getLogger(__name__)
2522

23+
2624
# pylint: disable=unused-argument,missing-docstring
2725

2826
def get_inputs_for_current_iteration(g, input_id, iter_index):
@@ -241,78 +239,6 @@ def version_7(cls, ctx, node, **kwargs):
241239
pass
242240

243241

244-
@tf_op("ReverseSequence")
245-
class ReverseSequence:
246-
@classmethod
247-
def version_8(cls, ctx, node, **kwargs):
248-
# T output = ReverseSequence(T input, int32|int64 seq_lengths, @int seq_dim, @int batch_dim)
249-
# T output = Scan(int64 sequence_lens, variadic initial_state_and_scan_inputs, @graph body,
250-
# @ints directions,@int num_scan_inputs)
251-
seq_dim = node.get_attr("seq_dim")
252-
batch_dim = node.get_attr("batch_dim")
253-
batch_major = seq_dim.i == 1 and (batch_dim or batch_dim.i == 0)
254-
time_major = batch_dim.i == 1 and (seq_dim or seq_dim.i == 0)
255-
perm_val = None
256-
257-
if not batch_major and not time_major:
258-
error_msg = "unsupported attributes, seq_dim:{}, batch_dim:{}".format(seq_dim, batch_dim)
259-
raise ValueError(error_msg)
260-
261-
if time_major:
262-
old_shape = ctx.get_shape(node.input[0])
263-
old_dtype = ctx.get_dtype(node.input[0])
264-
perm_val = [1, 0]
265-
rank = len(old_shape)
266-
utils.make_sure(rank >= 2, "rank of reverse_sequence input {} is at least 2".format(node.input[0]))
267-
perm_val += list(range(2, rank))
268-
trans_node = ctx.insert_new_node_on_input(node, "Transpose", node.input[0], perm=perm_val)
269-
new_shape = spatial_map(old_shape, perm_val)
270-
ctx.set_shape(trans_node.output[0], new_shape)
271-
ctx.set_dtype(trans_node.output[0], old_dtype)
272-
273-
# handle batch_major input
274-
node.type = "Scan"
275-
node.set_attr("num_scan_inputs", 1)
276-
input_dtype = ctx.get_dtype(node.input[0])
277-
input_shape = ctx.get_shape(node.input[0])
278-
279-
g = ctx.create_new_graph_with_same_config()
280-
g.parent_graph = ctx
281-
g.add_graph_input('X', input_dtype, input_shape[2:])
282-
g.make_node('Identity', ['X'], outputs=['Y'])
283-
g.add_graph_output('Y', input_dtype, input_shape[2:])
284-
285-
node.set_body_graph_as_attr("body", g)
286-
node.set_attr("directions", [1]) # reverse the scan input
287-
288-
seq_len_dtype = ctx.get_dtype(node.input[1])
289-
if seq_len_dtype != onnx_pb.TensorProto.INT64:
290-
cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[1])
291-
cast_node.set_attr("to", onnx_pb.TensorProto.INT64)
292-
ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.INT64)
293-
ctx.copy_shape(node.input[1], cast_node.output[0])
294-
295-
if time_major:
296-
# get back to time_major
297-
op_name = utils.make_name(node.name)
298-
trans_back_node = ctx.insert_new_node_on_output("Transpose", node.output[0],
299-
name=op_name, perm=perm_val)
300-
ctx.copy_dtype(node.output[0], trans_back_node.output[0])
301-
302-
tmp = node.input[0]
303-
node.input[0] = node.input[1]
304-
node.input[1] = tmp
305-
306-
@classmethod
307-
def version_9(cls, ctx, node, **kwargs):
308-
# T output = ReverseSequence(T input, int32|int64 seq_lengths, @int seq_dim, @int batch_dim)
309-
# we cannot easily construct reverse_sequence equivalence in opset 9, so we will not support it
310-
# here. Actually using loops to do that is kind of meaningless since there will be performance
311-
# issue there for sure.
312-
raise NotImplementedError("ReverseSequence is not supported to convert in OPSET 9,"
313-
" if possible please try using OPSET 8 instead.")
314-
315-
316242
@tf_op("Range")
317243
class Range:
318244
@classmethod

0 commit comments

Comments
 (0)