Skip to content

Commit aa6f345

Browse files
committed
fix review comments
1 parent da765aa commit aa6f345

File tree

5 files changed

+16
-12
lines changed

5 files changed

+16
-12
lines changed

tests/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
__all__ = ["TestConfig", "get_test_config", "unittest_main",
1717
"check_tf_min_version", "skip_tf_versions",
1818
"check_opset_min_version", "check_target", "skip_caffe2_backend", "skip_onnxruntime_backend",
19-
"skip_specific_opset_version", "check_onnxruntime_incompatibility", "validate_const_node",
19+
"skip_opset", "check_onnxruntime_incompatibility", "validate_const_node",
2020
"group_nodes_by_type"]
2121

2222

@@ -156,7 +156,7 @@ def check_opset_min_version(min_required_version, message=""):
156156
return unittest.skipIf(config.opset < min_required_version, reason)
157157

158158

159-
def skip_specific_opset_version(opset_v, message=""):
159+
def skip_opset(opset_v, message=""):
160160
""" Skip if opset = opset_v """
161161
config = get_test_config()
162162
reason = _append_message("conversion requires opset != {}".format(opset_v), message)

tests/test_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1399,7 +1399,7 @@ def test_erf(self):
13991399
self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=0.01)
14001400

14011401
@check_opset_min_version(8, "Scan")
1402-
@skip_specific_opset_version(9, "ReverseSequence not supported")
1402+
@skip_opset(9, "ReverseSequence")
14031403
def test_reverse_sequence_batch_major(self):
14041404
x_val = np.array([[[1, 2, 3], [4, 5, 6], [0, 0, 0]],
14051405
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
@@ -1430,7 +1430,7 @@ def test_reverse_sequence_batch_major(self):
14301430
self._run_test_case([_OUTPUT], {_INPUT: x_val})
14311431

14321432
@check_opset_min_version(8, "Scan")
1433-
@skip_specific_opset_version(9, "ReverseSequence not supported")
1433+
@skip_opset(9, "ReverseSequence")
14341434
def test_reverse_sequence_time_major(self):
14351435
x_val = np.array([[[1, 2, 3], [1, 2, 3], [1, 2, 3]],
14361436
[[4, 5, 6], [4, 5, 6], [0, 0, 0]],

tests/test_custom_rnncell.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from tensorflow.contrib import rnn
1414
from tensorflow.python.ops import init_ops
1515
from backend_test_base import Tf2OnnxBackendTestBase
16-
from common import check_tf_min_version, check_opset_min_version, unittest_main, skip_specific_opset_version
16+
from common import check_tf_min_version, check_opset_min_version, unittest_main, skip_opset
1717

1818

1919
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -284,7 +284,7 @@ def test_multi_rnn_lstm(self, state_is_tuple=True):
284284

285285
@check_opset_min_version(8, "Scan")
286286
@check_tf_min_version("1.8")
287-
@skip_specific_opset_version(9, "ReverseSequence cannot be efficient mapped in OPSET 9")
287+
@skip_opset(9, "ReverseSequence")
288288
def test_bidrectional_attention_wrapper_lstm_encoder(self):
289289
size = 30
290290
time_step = 3

tf2onnx/rewriter/custom_rnn_rewriter.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ def _parse_rnn_loop(self, context):
8585
return True
8686

8787
def need_rewrite(self, context):
88+
if self.g.opset < 8:
89+
log.debug("skip the custom_rnn_rewriter due to lower opset version %s", self.g.opset)
90+
return False
91+
8892
context.rnn_scope = self._get_rnn_scope_name(context.while_context_scope)
8993

9094
if not self._parse_rnn_loop(context):
@@ -108,15 +112,15 @@ def rewrite(self, context):
108112
if self.g.opset == 8:
109113
nodes = self._adapt_scan_sequence_input_or_output("input", state_input, False)
110114
state_inputs_initial_values.append(nodes[-1].output[0])
111-
else:
115+
else: # since opset 9
112116
state_inputs_initial_values.append(state_input)
113117

114118
scan_inputs_initial_values = []
115119
for scan_input in scan_props.scan_inputs_initial_values:
116120
if self.g.opset == 8:
117121
nodes = self._adapt_scan_sequence_input_or_output("input", scan_input, False)
118122
scan_inputs_initial_values.append(nodes[-1].output[0])
119-
else:
123+
else: # since opset 9
120124
scan_inputs_initial_values.append(scan_input)
121125

122126
cell_g_info = context.cell_graph
@@ -192,7 +196,7 @@ def _connect_scan_with_output(self, context, scan_node):
192196
nodes = self._adapt_scan_sequence_input_or_output("state_output_reshape",
193197
scan_node.output[index], True)
194198
self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, nodes[-1].output[0])
195-
else:
199+
else: # since opset 9
196200
self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, scan_node.output[index])
197201
index += 1
198202

@@ -202,7 +206,7 @@ def _connect_scan_with_output(self, context, scan_node):
202206
nodes = self._adapt_scan_sequence_input_or_output("scan_output_reshape",
203207
scan_node.output[index], True)
204208
self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, nodes[-1].output[0])
205-
else:
209+
else: # since opset 9
206210
self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, scan_node.output[index])
207211
index += 1
208212

tf2onnx/tfonnx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,8 +1460,8 @@ def reverse_op9(ctx, node, name, args):
14601460
# here. Actually using loops to do that is kind of meaningless since there will be performance
14611461
# issue there for sure.
14621462

1463-
raise RuntimeError("ReverseSequence is not supported to convert in OPSET9,"
1464-
" if possible please try use OPSET 8 instead.")
1463+
raise NotImplementedError("ReverseSequence is not supported to convert in OPSET 9,"
1464+
" if possible please try using OPSET 8 instead.")
14651465

14661466

14671467
def shape_op(ctx, node, name, args):

0 commit comments

Comments
 (0)