Skip to content

Commit c11a5d1

Browse files
committed
opset 9 scan support
1 parent be6b986 commit c11a5d1

File tree

5 files changed

+65
-29
lines changed

5 files changed

+65
-29
lines changed

tests/common.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515

1616
__all__ = ["TestConfig", "get_test_config", "unittest_main",
1717
"check_tf_min_version", "skip_tf_versions",
18-
"check_opset_min_version", "check_target", "skip_onnxruntime_backend", "skip_caffe2_backend",
19-
"check_onnxruntime_incompatibility", "validate_const_node", "group_nodes_by_type"]
18+
"check_opset_min_version", "check_target", "skip_caffe2_backend", "skip_onnxruntime_backend",
19+
"skip_specific_opset_version", "check_onnxruntime_incompatibility", "validate_const_node",
20+
"group_nodes_by_type"]
2021

2122

2223
# pylint: disable=missing-docstring
@@ -155,6 +156,13 @@ def check_opset_min_version(min_required_version, message=""):
155156
return unittest.skipIf(config.opset < min_required_version, reason)
156157

157158

159+
def skip_specific_opset_version(opset_v, message=""):
160+
""" Skip if opset = opset_v """
161+
config = get_test_config()
162+
reason = _append_message("conversion requires opset != {}".format(opset_v), message)
163+
return unittest.skipIf(config.opset == opset_v, reason)
164+
165+
158166
def check_target(required_target, message=""):
159167
""" Skip if required_target is NOT specified """
160168
config = get_test_config()

tests/test_backend.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,8 +1398,8 @@ def test_erf(self):
13981398
_ = tf.identity(x_, name=_TFOUTPUT)
13991399
self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=0.01)
14001400

1401-
# @unittest.skipIf(OPSET < 8, "supported with opset 8 or better")
1402-
@unittest.skip("FIXME: the newest onnxruntime wheel hasn't been published to PYPI, so scan op is not supported")
1401+
@check_opset_min_version(8, "Scan")
1402+
@skip_specific_opset_version(9, "ReverseSequence not supported")
14031403
def test_reverse_sequence_batch_major(self):
14041404
x_val = np.array([[[1, 2, 3], [4, 5, 6], [0, 0, 0]],
14051405
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
@@ -1429,8 +1429,8 @@ def test_reverse_sequence_batch_major(self):
14291429
_ = tf.identity(x_, name=_TFOUTPUT)
14301430
self._run_test_case([_OUTPUT], {_INPUT: x_val})
14311431

1432-
# @unittest.skipIf(OPSET < 8, "supported with opset 8 or better")
1433-
@unittest.skip("FIXME: the newest onnxruntime wheel hasn't been published to PYPI, so scan op is not supported")
1432+
@check_opset_min_version(8, "Scan")
1433+
@skip_specific_opset_version(9, "ReverseSequence not supported")
14341434
def test_reverse_sequence_time_major(self):
14351435
x_val = np.array([[[1, 2, 3], [1, 2, 3], [1, 2, 3]],
14361436
[[4, 5, 6], [4, 5, 6], [0, 0, 0]],
@@ -1461,8 +1461,7 @@ def test_reverse_sequence_time_major(self):
14611461
_ = tf.identity(x_, name=_TFOUTPUT)
14621462
self._run_test_case([_OUTPUT], {_INPUT: x_val})
14631463

1464-
# @unittest.skipIf(OPSET < 8, "supported with opset 8 or better")
1465-
@unittest.skip("FIXME: the newest onnxruntime wheel hasn't been published to PYPI, so Select op is not supported")
1464+
@check_opset_min_version(8, "where")
14661465
def test_where(self):
14671466
x_val = np.array([1, 2, -3, 4, -5, -6, -7, 8, 9, 0], dtype=np.int32)
14681467
true_result = np.array([111, 222, 333, 444, 555, 666, 777, 888, 999, 1000],

tests/test_custom_rnncell.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from tensorflow.contrib import rnn
1414
from tensorflow.python.ops import init_ops
1515
from backend_test_base import Tf2OnnxBackendTestBase
16-
from common import check_tf_min_version, check_opset_min_version, unittest_main
16+
from common import check_tf_min_version, check_opset_min_version, unittest_main, skip_specific_opset_version
1717

1818

1919
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -284,6 +284,7 @@ def test_multi_rnn_lstm(self, state_is_tuple=True):
284284

285285
@check_opset_min_version(8, "Scan")
286286
@check_tf_min_version("1.8")
287+
@skip_specific_opset_version(9, "ReverseSequence cannot be efficient mapped in OPSET 9")
287288
def test_bidrectional_attention_wrapper_lstm_encoder(self):
288289
size = 30
289290
time_step = 3

tf2onnx/rewriter/custom_rnn_rewriter.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,19 @@ def rewrite(self, context):
105105

106106
state_inputs_initial_values = []
107107
for state_input in scan_props.state_inputs_initial_values:
108-
nodes = self._adapt_scan_sequence_input_or_output("input", state_input, False)
109-
state_inputs_initial_values.append(nodes[-1].output[0])
108+
if self.g.opset == 8:
109+
nodes = self._adapt_scan_sequence_input_or_output("input", state_input, False)
110+
state_inputs_initial_values.append(nodes[-1].output[0])
111+
else:
112+
state_inputs_initial_values.append(state_input)
110113

111114
scan_inputs_initial_values = []
112115
for scan_input in scan_props.scan_inputs_initial_values:
113-
nodes = self._adapt_scan_sequence_input_or_output("input", scan_input, False)
114-
scan_inputs_initial_values.append(nodes[-1].output[0])
116+
if self.g.opset == 8:
117+
nodes = self._adapt_scan_sequence_input_or_output("input", scan_input, False)
118+
scan_inputs_initial_values.append(nodes[-1].output[0])
119+
else:
120+
scan_inputs_initial_values.append(scan_input)
115121

116122
cell_g_info = context.cell_graph
117123
scan_body_g = LoopRewriterBase.construct_graph_from_nodes(self.g, cell_g_info.nodes, cell_g_info.outputs)
@@ -155,17 +161,24 @@ def _create_scan_node(self, context, scan_props, init_values):
155161
n = self.g.get_node_by_output(tensor_value_info.id)
156162
self.g.remove_node(n.name)
157163
else:
158-
loop_outputs_shapes.append(None)
164+
loop_outputs_shapes.append([-1])
159165
loop_outputs_dtypes.append(None)
160166

161-
# here we did not give the sequence_length, because
162-
# current batch size is 1, not original batch size
163-
# original seq_length will be used by the loop body of Scan op.
164-
scan_node = self.g.make_node("Scan", [""] + init_values, op_name_scope="custom_rnn_scan",
165-
attr={"num_scan_inputs": len(scan_props.scan_inputs)},
166-
output_count=len(scan_props.state_outputs + scan_props.scan_outputs),
167-
shapes=loop_outputs_shapes, dtypes=loop_outputs_dtypes,
168-
skip_conversion=False)
167+
if self.g.opset == 8:
168+
# here we did not give the sequence_length, because
169+
# current batch size is 1, not original batch size
170+
# original seq_length will be used by the loop body of Scan op.
171+
scan_node = self.g.make_node("Scan", [""] + init_values, op_name_scope="custom_rnn_scan",
172+
attr={"num_scan_inputs": len(scan_props.scan_inputs)},
173+
output_count=len(scan_props.state_outputs + scan_props.scan_outputs),
174+
shapes=loop_outputs_shapes, dtypes=loop_outputs_dtypes,
175+
skip_conversion=False)
176+
else:
177+
scan_node = self.g.make_node("Scan", init_values, op_name_scope="custom_rnn_scan",
178+
attr={"num_scan_inputs": len(scan_props.scan_inputs)},
179+
output_count=len(scan_props.state_outputs + scan_props.scan_outputs),
180+
shapes=loop_outputs_shapes, dtypes=loop_outputs_dtypes,
181+
skip_conversion=False)
169182

170183
return scan_node
171184

@@ -175,17 +188,22 @@ def _connect_scan_with_output(self, context, scan_node):
175188
index = 0
176189
for out_tensor_value_info in context.loop_properties.state_outputs_exits:
177190
if out_tensor_value_info.id:
178-
nodes = self._adapt_scan_sequence_input_or_output("state_output_reshape",
179-
scan_node.output[index], True)
180-
self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, nodes[-1].output[0])
181-
191+
if self.g.opset == 8:
192+
nodes = self._adapt_scan_sequence_input_or_output("state_output_reshape",
193+
scan_node.output[index], True)
194+
self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, nodes[-1].output[0])
195+
else:
196+
self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, scan_node.output[index])
182197
index += 1
183198

184199
for out_tensor_value_info in context.loop_properties.scan_outputs_exits:
185200
if out_tensor_value_info.id:
186-
nodes = self._adapt_scan_sequence_input_or_output("scan_output_reshape",
187-
scan_node.output[index], True)
188-
self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, nodes[-1].output[0])
201+
if self.g.opset == 8:
202+
nodes = self._adapt_scan_sequence_input_or_output("scan_output_reshape",
203+
scan_node.output[index], True)
204+
self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, nodes[-1].output[0])
205+
else:
206+
self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, scan_node.output[index])
189207
index += 1
190208

191209

tf2onnx/tfonnx.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,15 @@ def reverse_op8(ctx, node, name, args):
14541454
node.input[0] = node.input[1]
14551455
node.input[1] = tmp
14561456

1457+
def reverse_op9(ctx, node, name, args):
1458+
# T output = ReverseSequence(T input, int32|int64 seq_lengths, @int seq_dim, @int batch_dim)
1459+
# we cannot easily construct reverse_sequence equivalence in opset 9, so we will not support it
1460+
# here. Actually using loops to do that is kind of meaningless since there will be performance
1461+
# issue there for sure.
1462+
1463+
raise RuntimeError("ReverseSequence is not supported to convert in OPSET9,"
1464+
" if possible please try use OPSET 8 instead.")
1465+
14571466

14581467
def shape_op(ctx, node, name, args):
14591468
# out_type output = Shape(T input, @int32|int64 out_type), out_type by default int32
@@ -1817,6 +1826,7 @@ def where_op(ctx, node, name, args):
18171826
"Less": (logical_compare_op, []),
18181827
"ResizeBilinear": (upsample_op9, ["Upsample", "linear"]),
18191828
"ResizeNearestNeighbor": (upsample_op9, ["Upsample", "nearest"]),
1829+
"ReverseSequence": (reverse_op9, []),
18201830
"Sign": (sign_op9, []),
18211831
"Sinh": (direct_op, []),
18221832
"Where": (where_op, []),

0 commit comments

Comments
 (0)