Skip to content

Commit 5243934

Browse files
authored
Merge pull request #442 from lucienwang1009/reverse_sequence
reverse sequence
2 parents 4ba449b + 08370e6 commit 5243934

File tree

4 files changed

+108
-20
lines changed

4 files changed

+108
-20
lines changed

tests/common.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from parameterized import parameterized
1414
from tf2onnx import constants, utils
1515

16-
__all__ = ["TestConfig", "get_test_config", "unittest_main",
17-
"check_tf_min_version", "skip_tf_versions",
16+
__all__ = ["TestConfig", "get_test_config", "unittest_main", "check_onnxruntime_backend",
17+
"check_tf_min_version", "skip_tf_versions", "check_onnxruntime_min_version",
1818
"check_opset_min_version", "check_target", "skip_caffe2_backend", "skip_onnxruntime_backend",
1919
"skip_opset", "check_onnxruntime_incompatibility", "validate_const_node",
2020
"group_nodes_by_type", "test_ms_domain", "check_node_domain"]
@@ -177,6 +177,21 @@ def skip_onnxruntime_backend(message=""):
177177
return unittest.skipIf(config.is_onnxruntime_backend, reason)
178178

179179

180+
def check_onnxruntime_backend(message=""):
181+
""" Skip if backend is NOT onnxruntime """
182+
config = get_test_config()
183+
reason = _append_message("only supported by onnxruntime", message)
184+
return unittest.skipIf(not config.is_onnxruntime_backend, reason)
185+
186+
187+
def check_onnxruntime_min_version(min_required_version, message=""):
188+
""" Skip if onnxruntime version < min_required_version """
189+
config = get_test_config()
190+
reason = _append_message("conversion requires onnxruntime >= {}".format(min_required_version), message)
191+
return unittest.skipIf(config.is_onnxruntime_backend and
192+
config.backend_version < LooseVersion(min_required_version), reason)
193+
194+
180195
def skip_caffe2_backend(message=""):
181196
""" Skip if backend is caffe2 """
182197
config = get_test_config()
@@ -253,6 +268,7 @@ def check_gru_count(graph, expected_count):
253268
def test_ms_domain(versions=None):
254269
""" Parameterize test case to apply ms opset(s) as extra_opset. """
255270

271+
@check_onnxruntime_backend()
256272
def _custom_name_func(testcase_func, param_num, param):
257273
del param_num
258274
arg = param.args[0]

tests/run_pretrained_models.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,10 @@ def run_tensorflow(self, sess, inputs):
155155
self.tf_runtime = time.time() - start
156156
return result
157157

158-
def to_onnx(self, tf_graph, opset=None, shape_override=None, input_names=None):
158+
def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, input_names=None):
159159
"""Convert graph to tensorflow."""
160160
return process_tf_graph(tf_graph, continue_on_error=False, verbose=True, opset=opset,
161-
target=Test.target, shape_override=shape_override,
161+
extra_opset=extra_opset, target=Test.target, shape_override=shape_override,
162162
input_names=input_names, output_names=self.output_names)
163163

164164
def run_caffe2(self, name, model_proto, inputs):
@@ -207,7 +207,8 @@ def create_onnx_file(name, model_proto, inputs, outdir):
207207
utils.save_protobuf(model_path, model_proto)
208208
print("\tcreated", model_path)
209209

210-
def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=None, perf=None, fold_const=None):
210+
def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=None, extra_opset=None,
211+
perf=None, fold_const=None):
211212
"""Run complete test against backend."""
212213
print(name)
213214
self.perf = perf
@@ -267,8 +268,8 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
267268
model_proto = None
268269
try:
269270
# convert model to onnx
270-
onnx_graph = self.to_onnx(sess.graph, opset=opset, shape_override=shape_override,
271-
input_names=inputs.keys())
271+
onnx_graph = self.to_onnx(sess.graph, opset=opset, extra_opset=extra_opset,
272+
shape_override=shape_override, input_names=inputs.keys())
272273
model_proto = onnx_graph.make_model("converted from tf2onnx")
273274
new_model_proto = optimizer.optimize_graph(onnx_graph, debug=debug).make_model("optimized")
274275
if new_model_proto:
@@ -328,6 +329,8 @@ def get_args():
328329
choices=["caffe2", "onnxmsrtnext", "onnxruntime"], help="backend to use")
329330
parser.add_argument("--verbose", help="verbose output", action="store_true")
330331
parser.add_argument("--opset", type=int, default=None, help="opset to use")
332+
parser.add_argument("--extra_opset", default=None,
333+
help="extra opset with format like domain:version, e.g. com.microsoft:1")
331334
parser.add_argument("--debug", help="debug vlog", action="store_true")
332335
parser.add_argument("--list", help="list tests", action="store_true")
333336
parser.add_argument("--onnx-file", help="create onnx file in directory")
@@ -338,6 +341,11 @@ def get_args():
338341
args = parser.parse_args()
339342

340343
args.target = args.target.split(",")
344+
if args.extra_opset:
345+
tokens = args.extra_opset.split(':')
346+
if len(tokens) != 2:
347+
raise ValueError("invalid extra_opset argument")
348+
args.extra_opset = [utils.make_opsetid(tokens[0], int(tokens[1]))]
341349
return args
342350

343351

@@ -385,7 +393,8 @@ def main():
385393
count += 1
386394
try:
387395
ret = t.run_test(test, backend=args.backend, debug=args.debug, onnx_file=args.onnx_file,
388-
opset=args.opset, perf=args.perf, fold_const=args.fold_const)
396+
opset=args.opset, extra_opset=args.extra_opset, perf=args.perf,
397+
fold_const=args.fold_const)
389398
except Exception as ex:
390399
ret = None
391400
print(ex)

tests/test_backend.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,17 +1539,19 @@ def test_erf(self):
15391539
_ = tf.identity(x_, name=_TFOUTPUT)
15401540
self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=0.01)
15411541

1542-
@check_opset_min_version(8, "Scan")
1543-
@skip_opset(9, "ReverseSequence")
1544-
def test_reverse_sequence_batch_major(self):
1542+
def _test_reverse_sequence_batch_major(self, extra_opset=None):
1543+
process_args = {}
1544+
if extra_opset is not None:
1545+
process_args["extra_opset"] = [extra_opset]
1546+
15451547
x_val = np.array([[[1, 2, 3], [4, 5, 6], [0, 0, 0]],
15461548
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
15471549
[[1, 2, 3], [0, 0, 0], [0, 0, 0]]],
15481550
dtype=np.float32)
15491551
x = tf.placeholder(tf.float32, [None, 3, 3], name=_TFINPUT)
15501552
x_ = tf.reverse_sequence(x, seq_axis=1, batch_axis=0, seq_lengths=[2, 3, 1])
15511553
_ = tf.identity(x_, name=_TFOUTPUT)
1552-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1554+
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
15531555
tf.reset_default_graph()
15541556

15551557
x_val = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3],
@@ -1560,19 +1562,21 @@ def test_reverse_sequence_batch_major(self):
15601562
x = tf.placeholder(tf.float32, [None, 3], name=_TFINPUT)
15611563
x_ = tf.reverse_sequence(x, seq_axis=1, batch_axis=0, seq_lengths=[3] * 9)
15621564
_ = tf.identity(x_, name=_TFOUTPUT)
1563-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1565+
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
15641566
tf.reset_default_graph()
15651567

15661568
x_val_shape = [5, 5, 7, 8, 9]
15671569
x_val = np.random.randint(0, 100, x_val_shape).astype(np.float32)
15681570
x = tf.placeholder(tf.float32, [None, 5, 7, 8, 9], name=_TFINPUT)
15691571
x_ = tf.reverse_sequence(x, seq_axis=1, batch_axis=0, seq_lengths=[5, 5, 5, 5, 5])
15701572
_ = tf.identity(x_, name=_TFOUTPUT)
1571-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1573+
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
1574+
1575+
def _test_reverse_sequence_time_major(self, extra_opset=None):
1576+
process_args = {}
1577+
if extra_opset is not None:
1578+
process_args["extra_opset"] = [extra_opset]
15721579

1573-
@check_opset_min_version(8, "Scan")
1574-
@skip_opset(9, "ReverseSequence")
1575-
def test_reverse_sequence_time_major(self):
15761580
x_val = np.array([[[1, 2, 3], [1, 2, 3], [1, 2, 3]],
15771581
[[4, 5, 6], [4, 5, 6], [0, 0, 0]],
15781582
[[0, 0, 0], [7, 8, 9], [0, 0, 0]]
@@ -1581,7 +1585,7 @@ def test_reverse_sequence_time_major(self):
15811585
x = tf.placeholder(tf.float32, [3, None, 3], name=_TFINPUT)
15821586
x_ = tf.reverse_sequence(x, seq_axis=0, batch_axis=1, seq_lengths=[2, 3, 1])
15831587
_ = tf.identity(x_, name=_TFOUTPUT)
1584-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1588+
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
15851589
tf.reset_default_graph()
15861590

15871591
x_val = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3],
@@ -1592,15 +1596,36 @@ def test_reverse_sequence_time_major(self):
15921596
x = tf.placeholder(tf.float32, [9, None], name=_TFINPUT)
15931597
x_ = tf.reverse_sequence(x, seq_axis=0, batch_axis=1, seq_lengths=[9, 9, 9])
15941598
_ = tf.identity(x_, name=_TFOUTPUT)
1595-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1599+
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
15961600
tf.reset_default_graph()
15971601

15981602
x_val_shape = [5, 5, 7, 8, 9]
15991603
x_val = np.random.randint(0, 100, x_val_shape).astype(np.float32)
16001604
x = tf.placeholder(tf.float32, [5, None, 7, 8, 9], name=_TFINPUT)
16011605
x_ = tf.reverse_sequence(x, seq_axis=0, batch_axis=1, seq_lengths=[5, 5, 5, 5, 5])
16021606
_ = tf.identity(x_, name=_TFOUTPUT)
1603-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1607+
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
1608+
1609+
@check_opset_min_version(8, "Scan")
1610+
@skip_opset(9, "ReverseSequence")
1611+
def test_reverse_sequence_batch_major(self):
1612+
self._test_reverse_sequence_batch_major()
1613+
1614+
@check_opset_min_version(8, "Scan")
1615+
@skip_opset(9, "ReverseSequence")
1616+
def test_reverse_sequence_time_major(self):
1617+
self._test_reverse_sequence_time_major()
1618+
1619+
# only support onnxruntime with version larger than 0.4.0
1620+
@test_ms_domain()
1621+
@check_onnxruntime_min_version("0.4.0")
1622+
def test_ms_reverse_sequence_batch_major(self, extra_opset):
1623+
self._test_reverse_sequence_batch_major(extra_opset)
1624+
1625+
@test_ms_domain()
1626+
@check_onnxruntime_min_version("0.4.0")
1627+
def test_ms_reverse_sequence_time_major(self, extra_opset):
1628+
self._test_reverse_sequence_time_major(extra_opset)
16041629

16051630
@check_opset_min_version(8, "where")
16061631
def test_where(self):

tf2onnx/custom_opsets/ms.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,41 @@ def version_1(cls, ctx, node, **kwargs):
3636
utils.make_sure(dtype is not None, "Tidx of %s is None", node.name)
3737
ctx.remove_node(node.name)
3838
make_range(ctx, node.input[0], node.input[1], node.input[2], node.output[0], node.name, shape, dtype)
39+
40+
41+
@tf_op("ReverseSequence", domain=constants.MICROSOFT_DOMAIN)
42+
class ReverseSequence:
43+
@classmethod
44+
def version_1(cls, ctx, node, **kwargs):
45+
"""ReverseSequence"""
46+
# T output = ReverseSequence(T input, int32|int64 seq_lengths, @int seq_dim, @int batch_dim)
47+
# T output = ReverseSequence(T input, int32 seqence_lens, @int time_axis, @int batch_axis)
48+
seq_dim = node.get_attr("seq_dim")
49+
utils.make_sure(seq_dim is not None, "sequence dim must be given in {}".format(node.name))
50+
seq_dim = seq_dim.i
51+
batch_dim = node.get_attr("batch_dim")
52+
if batch_dim is not None:
53+
batch_dim = batch_dim.i
54+
else:
55+
batch_dim = 0
56+
57+
output_dtypes = node.output_dtypes
58+
output_shapes = node.output_shapes
59+
ctx.remove_node(node.name)
60+
node = ctx.make_node(
61+
"ReverseSequence",
62+
node.input,
63+
outputs=node.output,
64+
shapes=output_shapes,
65+
dtypes=output_dtypes,
66+
domain=constants.MICROSOFT_DOMAIN,
67+
attr={"time_axis": seq_dim, "batch_axis": batch_dim}
68+
)
69+
70+
seq_len_dtype = ctx.get_dtype(node.input[1])
71+
utils.make_sure(seq_len_dtype is not None, "dtype of {} is None".format(node.input[1]))
72+
target_dtype = TensorProto.INT32
73+
if seq_len_dtype != target_dtype:
74+
cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=target_dtype)
75+
ctx.copy_shape(cast_node.input[0], cast_node.output[0])
76+
ctx.set_dtype(cast_node.output[0], target_dtype)

0 commit comments

Comments
 (0)