Skip to content

Commit 4eda997

Browse files
reverse sequence
1 parent 757ed9a commit 4eda997

File tree

3 files changed

+89
-18
lines changed

3 files changed

+89
-18
lines changed

tests/run_pretrained_models.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,10 @@ def run_tensorflow(self, sess, inputs):
155155
self.tf_runtime = time.time() - start
156156
return result
157157

158-
def to_onnx(self, tf_graph, opset=None, shape_override=None, input_names=None):
158+
def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, input_names=None):
159159
"""Convert graph to tensorflow."""
160160
return process_tf_graph(tf_graph, continue_on_error=False, verbose=True, opset=opset,
161-
target=Test.target, shape_override=shape_override,
161+
extra_opset=extra_opset, target=Test.target, shape_override=shape_override,
162162
input_names=input_names, output_names=self.output_names)
163163

164164
def run_caffe2(self, name, model_proto, inputs):
@@ -207,7 +207,8 @@ def create_onnx_file(name, model_proto, inputs, outdir):
207207
utils.save_protobuf(model_path, model_proto)
208208
print("\tcreated", model_path)
209209

210-
def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=None, perf=None, fold_const=None):
210+
def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=None, extra_opset=None,
211+
perf=None, fold_const=None):
211212
"""Run complete test against backend."""
212213
print(name)
213214
self.perf = perf
@@ -267,8 +268,8 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
267268
model_proto = None
268269
try:
269270
# convert model to onnx
270-
onnx_graph = self.to_onnx(sess.graph, opset=opset, shape_override=shape_override,
271-
input_names=inputs.keys())
271+
onnx_graph = self.to_onnx(sess.graph, opset=opset, extra_opset=extra_opset,
272+
shape_override=shape_override, input_names=inputs.keys())
272273
model_proto = onnx_graph.make_model("converted from tf2onnx")
273274
new_model_proto = optimizer.optimize_graph(onnx_graph, debug=debug).make_model("optimized")
274275
if new_model_proto:
@@ -328,6 +329,8 @@ def get_args():
328329
choices=["caffe2", "onnxmsrtnext", "onnxruntime"], help="backend to use")
329330
parser.add_argument("--verbose", help="verbose output", action="store_true")
330331
parser.add_argument("--opset", type=int, default=None, help="opset to use")
332+
parser.add_argument("--extra_opset", default=None,
333+
help="extra opset with format like domain:version, e.g. com.microsoft:1")
331334
parser.add_argument("--debug", help="debug vlog", action="store_true")
332335
parser.add_argument("--list", help="list tests", action="store_true")
333336
parser.add_argument("--onnx-file", help="create onnx file in directory")
@@ -338,6 +341,11 @@ def get_args():
338341
args = parser.parse_args()
339342

340343
args.target = args.target.split(",")
344+
if args.extra_opset:
345+
tokens = args.extra_opset.split(':')
346+
if len(tokens) != 2:
347+
raise ValueError("invalid extra_opset argument")
348+
args.extra_opset = [utils.make_opsetid(tokens[0], int(tokens[1]))]
341349
return args
342350

343351

@@ -385,7 +393,8 @@ def main():
385393
count += 1
386394
try:
387395
ret = t.run_test(test, backend=args.backend, debug=args.debug, onnx_file=args.onnx_file,
388-
opset=args.opset, perf=args.perf, fold_const=args.fold_const)
396+
opset=args.opset, extra_opset=args.extra_opset, perf=args.perf,
397+
fold_const=args.fold_const)
389398
except Exception as ex:
390399
ret = None
391400
print(ex)

tests/test_backend.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,17 +1519,19 @@ def test_erf(self):
15191519
_ = tf.identity(x_, name=_TFOUTPUT)
15201520
self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=0.01)
15211521

1522-
@check_opset_min_version(8, "Scan")
1523-
@skip_opset(9, "ReverseSequence")
1524-
def test_reverse_sequence_batch_major(self):
1522+
def _test_reverse_sequence_batch_major(self, extra_opset=None):
1523+
process_args = {}
1524+
if extra_opset is not None:
1525+
process_args["extra_opset"] = [extra_opset]
1526+
15251527
x_val = np.array([[[1, 2, 3], [4, 5, 6], [0, 0, 0]],
15261528
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
15271529
[[1, 2, 3], [0, 0, 0], [0, 0, 0]]],
15281530
dtype=np.float32)
15291531
x = tf.placeholder(tf.float32, [None, 3, 3], name=_TFINPUT)
15301532
x_ = tf.reverse_sequence(x, seq_axis=1, batch_axis=0, seq_lengths=[2, 3, 1])
15311533
_ = tf.identity(x_, name=_TFOUTPUT)
1532-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1534+
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
15331535
tf.reset_default_graph()
15341536

15351537
x_val = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3],
@@ -1540,19 +1542,21 @@ def test_reverse_sequence_batch_major(self):
15401542
x = tf.placeholder(tf.float32, [None, 3], name=_TFINPUT)
15411543
x_ = tf.reverse_sequence(x, seq_axis=1, batch_axis=0, seq_lengths=[3] * 9)
15421544
_ = tf.identity(x_, name=_TFOUTPUT)
1543-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1545+
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
15441546
tf.reset_default_graph()
15451547

15461548
x_val_shape = [5, 5, 7, 8, 9]
15471549
x_val = np.random.randint(0, 100, x_val_shape).astype(np.float32)
15481550
x = tf.placeholder(tf.float32, [None, 5, 7, 8, 9], name=_TFINPUT)
15491551
x_ = tf.reverse_sequence(x, seq_axis=1, batch_axis=0, seq_lengths=[5, 5, 5, 5, 5])
15501552
_ = tf.identity(x_, name=_TFOUTPUT)
1551-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1553+
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
1554+
1555+
def _test_reverse_sequence_time_major(self, extra_opset=None):
1556+
process_args = {}
1557+
if extra_opset is not None:
1558+
process_args["extra_opset"] = [extra_opset]
15521559

1553-
@check_opset_min_version(8, "Scan")
1554-
@skip_opset(9, "ReverseSequence")
1555-
def test_reverse_sequence_time_major(self):
15561560
x_val = np.array([[[1, 2, 3], [1, 2, 3], [1, 2, 3]],
15571561
[[4, 5, 6], [4, 5, 6], [0, 0, 0]],
15581562
[[0, 0, 0], [7, 8, 9], [0, 0, 0]]
@@ -1561,7 +1565,7 @@ def test_reverse_sequence_time_major(self):
15611565
x = tf.placeholder(tf.float32, [3, None, 3], name=_TFINPUT)
15621566
x_ = tf.reverse_sequence(x, seq_axis=0, batch_axis=1, seq_lengths=[2, 3, 1])
15631567
_ = tf.identity(x_, name=_TFOUTPUT)
1564-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1568+
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
15651569
tf.reset_default_graph()
15661570

15671571
x_val = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3],
@@ -1572,15 +1576,35 @@ def test_reverse_sequence_time_major(self):
15721576
x = tf.placeholder(tf.float32, [9, None], name=_TFINPUT)
15731577
x_ = tf.reverse_sequence(x, seq_axis=0, batch_axis=1, seq_lengths=[9, 9, 9])
15741578
_ = tf.identity(x_, name=_TFOUTPUT)
1575-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1579+
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
15761580
tf.reset_default_graph()
15771581

15781582
x_val_shape = [5, 5, 7, 8, 9]
15791583
x_val = np.random.randint(0, 100, x_val_shape).astype(np.float32)
15801584
x = tf.placeholder(tf.float32, [5, None, 7, 8, 9], name=_TFINPUT)
15811585
x_ = tf.reverse_sequence(x, seq_axis=0, batch_axis=1, seq_lengths=[5, 5, 5, 5, 5])
15821586
_ = tf.identity(x_, name=_TFOUTPUT)
1583-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1587+
self._run_test_case([_OUTPUT], {_INPUT: x_val}, process_args=process_args)
1588+
1589+
@check_opset_min_version(8, "Scan")
1590+
@skip_opset(9, "ReverseSequence")
1591+
def test_reverse_sequence_batch_major(self):
1592+
self._test_reverse_sequence_batch_major()
1593+
1594+
@check_opset_min_version(8, "Scan")
1595+
@skip_opset(9, "ReverseSequence")
1596+
def test_reverse_sequence_time_major(self):
1597+
self._test_reverse_sequence_time_major()
1598+
1599+
@test_ms_domain()
1600+
@unittest.skipIf(True, "not support in pypi onnxruntime")
1601+
def test_ms_reverse_sequence_batch_major(self, extra_opset):
1602+
self._test_reverse_sequence_batch_major(extra_opset)
1603+
1604+
@test_ms_domain()
1605+
@unittest.skipIf(True, "not support in pypi onnxruntime")
1606+
def test_ms_reverse_sequence_time_major(self, extra_opset):
1607+
self._test_reverse_sequence_time_major(extra_opset)
15841608

15851609
@check_opset_min_version(8, "where")
15861610
def test_where(self):

tf2onnx/custom_opsets/ms.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,41 @@ def version_1(cls, ctx, node, **kwargs):
3636
utils.make_sure(dtype is not None, "Tidx of %s is None", node.name)
3737
ctx.remove_node(node.name)
3838
make_range(ctx, node.input[0], node.input[1], node.input[2], node.output[0], node.name, shape, dtype)
39+
40+
41+
@tf_op("ReverseSequence", domain=constants.MICROSOFT_DOMAIN)
42+
class ReverseSequence:
43+
@classmethod
44+
def version_1(cls, ctx, node, **kwargs):
45+
"""ReverseSequence"""
46+
# T output = ReverseSequence(T input, int32|int64 seq_lengths, @int seq_dim, @int batch_dim)
47+
# T output = ReverseSequence(T input, int32 seqence_lens, @int time_axis, @int batch_axis)
48+
seq_dim = node.get_attr("seq_dim")
49+
utils.make_sure(seq_dim is not None, "sequence dim must be given in {}".format(node.name))
50+
seq_dim = seq_dim.i
51+
batch_dim = node.get_attr("batch_dim")
52+
if batch_dim is not None:
53+
batch_dim = batch_dim.i
54+
else:
55+
batch_dim = 0
56+
57+
output_dtypes = node.output_dtypes
58+
output_shapes = node.output_shapes
59+
ctx.remove_node(node.name)
60+
node = ctx.make_node(
61+
"ReverseSequence",
62+
node.input,
63+
outputs=node.output,
64+
shapes=output_shapes,
65+
dtypes=output_dtypes,
66+
domain=constants.MICROSOFT_DOMAIN,
67+
attr={"time_axis": seq_dim, "batch_axis": batch_dim}
68+
)
69+
70+
seq_len_dtype = ctx.get_dtype(node.input[1])
71+
utils.make_sure(seq_len_dtype is not None, "dtype of {} is None".format(node.input[1]))
72+
target_dtype = TensorProto.INT32
73+
if seq_len_dtype != target_dtype:
74+
cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=target_dtype)
75+
ctx.copy_shape(cast_node.input[0], cast_node.output[0])
76+
ctx.set_dtype(cast_node.output[0], target_dtype)

0 commit comments

Comments
 (0)