
Commit 5feb8e8

Merge pull request #384 from pengwa/opset_9_scan
opset 9 scan support
2 parents: dba60d5 + 0918351

File tree

9 files changed: +72 / -40 lines


tests/common.py

Lines changed: 10 additions & 2 deletions
@@ -15,8 +15,9 @@
 
 __all__ = ["TestConfig", "get_test_config", "unittest_main",
            "check_tf_min_version", "skip_tf_versions",
-           "check_opset_min_version", "check_target", "skip_onnxruntime_backend", "skip_caffe2_backend",
-           "check_onnxruntime_incompatibility", "validate_const_node", "group_nodes_by_type"]
+           "check_opset_min_version", "check_target", "skip_caffe2_backend", "skip_onnxruntime_backend",
+           "skip_opset", "check_onnxruntime_incompatibility", "validate_const_node",
+           "group_nodes_by_type"]
 
 
 # pylint: disable=missing-docstring

@@ -155,6 +156,13 @@ def check_opset_min_version(min_required_version, message=""):
     return unittest.skipIf(config.opset < min_required_version, reason)
 
 
+def skip_opset(opset_v, message=""):
+    """ Skip if opset = opset_v """
+    config = get_test_config()
+    reason = _append_message("conversion requires opset != {}".format(opset_v), message)
+    return unittest.skipIf(config.opset == opset_v, reason)
+
+
 def check_target(required_target, message=""):
     """ Skip if required_target is NOT specified """
     config = get_test_config()
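skip_opset complements check_opset_min_version: it skips a test when the configured opset is exactly the given value. A hypothetical usage sketch (the test class and method below are illustrative, not part of this change); stacking both decorators restricts the test to opset 8, since ReverseSequence has no opset-9 conversion:

# illustrative test, not in this PR: runs only when the configured opset is exactly 8
from backend_test_base import Tf2OnnxBackendTestBase
from common import check_opset_min_version, skip_opset, unittest_main

class ScanOpTests(Tf2OnnxBackendTestBase):
    @check_opset_min_version(8, "Scan")   # skipped when opset < 8
    @skip_opset(9, "ReverseSequence")     # skipped when opset == 9
    def test_reverse_sequence_example(self):
        ...

if __name__ == "__main__":
    unittest_main()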

tests/test_backend.py

Lines changed: 5 additions & 6 deletions
@@ -1402,8 +1402,8 @@ def test_erf(self):
         _ = tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=0.01)
 
-    # @unittest.skipIf(OPSET < 8, "supported with opset 8 or better")
-    @unittest.skip("FIXME: the newest onnxruntime wheel hasn't been published to PYPI, so scan op is not supported")
+    @check_opset_min_version(8, "Scan")
+    @skip_opset(9, "ReverseSequence")
     def test_reverse_sequence_batch_major(self):
         x_val = np.array([[[1, 2, 3], [4, 5, 6], [0, 0, 0]],
                           [[1, 2, 3], [4, 5, 6], [7, 8, 9]],

@@ -1433,8 +1433,8 @@ def test_reverse_sequence_batch_major(self):
         _ = tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case([_OUTPUT], {_INPUT: x_val})
 
-    # @unittest.skipIf(OPSET < 8, "supported with opset 8 or better")
-    @unittest.skip("FIXME: the newest onnxruntime wheel hasn't been published to PYPI, so scan op is not supported")
+    @check_opset_min_version(8, "Scan")
+    @skip_opset(9, "ReverseSequence")
     def test_reverse_sequence_time_major(self):
         x_val = np.array([[[1, 2, 3], [1, 2, 3], [1, 2, 3]],
                           [[4, 5, 6], [4, 5, 6], [0, 0, 0]],

@@ -1465,8 +1465,7 @@ def test_reverse_sequence_time_major(self):
         _ = tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case([_OUTPUT], {_INPUT: x_val})
 
-    # @unittest.skipIf(OPSET < 8, "supported with opset 8 or better")
-    @unittest.skip("FIXME: the newest onnxruntime wheel hasn't been published to PYPI, so Select op is not supported")
+    @check_opset_min_version(8, "where")
     def test_where(self):
         x_val = np.array([1, 2, -3, 4, -5, -6, -7, 8, 9, 0], dtype=np.int32)
         true_result = np.array([111, 222, 333, 444, 555, 666, 777, 888, 999, 1000],

tests/test_custom_rnncell.py

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,7 @@
 from tensorflow.contrib import rnn
 from tensorflow.python.ops import init_ops
 from backend_test_base import Tf2OnnxBackendTestBase
-from common import check_tf_min_version, check_opset_min_version, unittest_main
+from common import check_tf_min_version, check_opset_min_version, unittest_main, skip_opset
 
 
 # pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test

@@ -284,6 +284,7 @@ def test_multi_rnn_lstm(self, state_is_tuple=True):
 
     @check_opset_min_version(8, "Scan")
     @check_tf_min_version("1.8")
+    @skip_opset(9, "ReverseSequence")
     def test_bidrectional_attention_wrapper_lstm_encoder(self):
         size = 30
         time_step = 3

tf2onnx/graph.py

Lines changed: 3 additions & 0 deletions
@@ -942,6 +942,9 @@ def replace_all_inputs(ops, old_input, new_input):
             return
 
         for node in ops:
+            if old_input in node.input and new_input in node.output:
+                raise RuntimeError("creating a circle in the graph is not allowed: " + node.name)
+
             for i, input_name in enumerate(node.input):
                 if input_name == old_input:
                     node.input[i] = new_input
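The added check guards against rewiring a node to consume its own output. A standalone sketch of the idea (the simplified Node stand-in below is illustrative, not tf2onnx's Node class):

from collections import namedtuple

Node = namedtuple("Node", ["name", "input", "output"])

def replace_all_inputs(ops, old_input, new_input):
    for node in ops:
        # if a node both consumes old_input and produces new_input, the rewire
        # would feed the node's own output back into itself and create a cycle
        if old_input in node.input and new_input in node.output:
            raise RuntimeError("creating a circle in the graph is not allowed: " + node.name)
        node.input[:] = [new_input if name == old_input else name for name in node.input]

add = Node("add_1", ["x:0", "y:0"], ["add_1:0"])
try:
    replace_all_inputs([add], "x:0", "add_1:0")
except RuntimeError as err:
    print(err)  # creating a circle in the graph is not allowed: add_1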

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 0 additions & 6 deletions
@@ -45,8 +45,6 @@ def __init__(self, graph, output_names, debug=False):
         self._handler_map = {}
         self._force_stop = {}
 
-        # make sure all proto of nodes or attribtues are update to date
-        self._g.update_proto()
         self._initialize_handlers()
         self.pre_optimize_action()
 

@@ -74,9 +72,6 @@ def pre_optimize_action(self):
                     if name == output_name:
                         child.input[i] = const_name
             self._g.make_const(const_name, new_data)
-
-            # need call this to make input update synced to protobuf val
-            self._g.update_proto()
             self._g.remove_node(reshape_op.name)
             self._g.topological_sort(self._g.get_nodes())
 

@@ -111,7 +106,6 @@ def post_optimize_action(self):
                                       name=op_name)
             else:
                 self._remove_useless_tranpose(op)
-        self._g.update_proto()
         self._g.topological_sort(self._g.get_nodes())
 
     def merge_duplicated_transposes(self):

tf2onnx/rewriter/bigru_rewriter.py

Lines changed: 0 additions & 1 deletion
@@ -116,7 +116,6 @@ def process_bigru(g, bi_grus):
             raise ValueError(
                 "Reverse is still used by GRU as input, cannot remove")
 
-    g.update_proto()
     return g.get_nodes()
 
 

tf2onnx/rewriter/bilstm_rewriter.py

Lines changed: 0 additions & 2 deletions
@@ -101,8 +101,6 @@ def process_bilstm(g, bi_lstms):
         else:
             raise ValueError("Reverse is still used by LSTM as input, cannot remove")
 
-
-    g.update_proto()
     return g.get_nodes()
 
 

tf2onnx/rewriter/custom_rnn_rewriter.py

Lines changed: 42 additions & 20 deletions
@@ -85,6 +85,10 @@ def _parse_rnn_loop(self, context):
         return True
 
     def need_rewrite(self, context):
+        if self.g.opset < 8:
+            log.debug("skip the custom_rnn_rewriter due to lower opset version %s", self.g.opset)
+            return False
+
         context.rnn_scope = self._get_rnn_scope_name(context.while_context_scope)
 
         if not self._parse_rnn_loop(context):

@@ -105,13 +109,19 @@ def rewrite(self, context):
 
         state_inputs_initial_values = []
         for state_input in scan_props.state_inputs_initial_values:
-            nodes = self._adapt_scan_sequence_input_or_output("input", state_input, False)
-            state_inputs_initial_values.append(nodes[-1].output[0])
+            if self.g.opset == 8:
+                nodes = self._adapt_scan_sequence_input_or_output("input", state_input, False)
+                state_inputs_initial_values.append(nodes[-1].output[0])
+            else: # since opset 9
+                state_inputs_initial_values.append(state_input)
 
         scan_inputs_initial_values = []
         for scan_input in scan_props.scan_inputs_initial_values:
-            nodes = self._adapt_scan_sequence_input_or_output("input", scan_input, False)
-            scan_inputs_initial_values.append(nodes[-1].output[0])
+            if self.g.opset == 8:
+                nodes = self._adapt_scan_sequence_input_or_output("input", scan_input, False)
+                scan_inputs_initial_values.append(nodes[-1].output[0])
+            else: # since opset 9
+                scan_inputs_initial_values.append(scan_input)
 
         cell_g_info = context.cell_graph
         scan_body_g = LoopRewriterBase.construct_graph_from_nodes(self.g, cell_g_info.nodes, cell_g_info.outputs)

@@ -155,17 +165,24 @@ def _create_scan_node(self, context, scan_props, init_values):
                 n = self.g.get_node_by_output(tensor_value_info.id)
                 self.g.remove_node(n.name)
             else:
-                loop_outputs_shapes.append(None)
+                loop_outputs_shapes.append([-1])
                 loop_outputs_dtypes.append(None)
 
-        # here we did not give the sequence_length, because
-        # current batch size is 1, not original batch size
-        # original seq_length will be used by the loop body of Scan op.
-        scan_node = self.g.make_node("Scan", [""] + init_values, op_name_scope="custom_rnn_scan",
-                                     attr={"num_scan_inputs": len(scan_props.scan_inputs)},
-                                     output_count=len(scan_props.state_outputs + scan_props.scan_outputs),
-                                     shapes=loop_outputs_shapes, dtypes=loop_outputs_dtypes,
-                                     skip_conversion=False)
+        if self.g.opset == 8:
+            # here we did not give the sequence_length, because
+            # current batch size is 1, not original batch size
+            # original seq_length will be used by the loop body of Scan op.
+            scan_node = self.g.make_node("Scan", [""] + init_values, op_name_scope="custom_rnn_scan",
+                                         attr={"num_scan_inputs": len(scan_props.scan_inputs)},
+                                         output_count=len(scan_props.state_outputs + scan_props.scan_outputs),
+                                         shapes=loop_outputs_shapes, dtypes=loop_outputs_dtypes,
+                                         skip_conversion=False)
+        else:
+            scan_node = self.g.make_node("Scan", init_values, op_name_scope="custom_rnn_scan",
+                                         attr={"num_scan_inputs": len(scan_props.scan_inputs)},
+                                         output_count=len(scan_props.state_outputs + scan_props.scan_outputs),
+                                         shapes=loop_outputs_shapes, dtypes=loop_outputs_dtypes,
+                                         skip_conversion=False)
 
         return scan_node
 

@@ -175,17 +192,22 @@ def _connect_scan_with_output(self, context, scan_node):
         index = 0
         for out_tensor_value_info in context.loop_properties.state_outputs_exits:
             if out_tensor_value_info.id:
-                nodes = self._adapt_scan_sequence_input_or_output("state_output_reshape",
-                                                                  scan_node.output[index], True)
-                self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, nodes[-1].output[0])
-
+                if self.g.opset == 8:
+                    nodes = self._adapt_scan_sequence_input_or_output("state_output_reshape",
+                                                                      scan_node.output[index], True)
+                    self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, nodes[-1].output[0])
+                else: # since opset 9
+                    self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, scan_node.output[index])
             index += 1
 
         for out_tensor_value_info in context.loop_properties.scan_outputs_exits:
             if out_tensor_value_info.id:
-                nodes = self._adapt_scan_sequence_input_or_output("scan_output_reshape",
-                                                                  scan_node.output[index], True)
-                self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, nodes[-1].output[0])
+                if self.g.opset == 8:
+                    nodes = self._adapt_scan_sequence_input_or_output("scan_output_reshape",
+                                                                      scan_node.output[index], True)
+                    self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, nodes[-1].output[0])
+                else: # since opset 9
+                    self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, scan_node.output[index])
             index += 1
 
 
tf2onnx/tfonnx.py

Lines changed: 10 additions & 2 deletions
@@ -1454,6 +1454,15 @@ def reverse_op8(ctx, node, name, args):
     node.input[0] = node.input[1]
     node.input[1] = tmp
 
+def reverse_op9(ctx, node, name, args):
+    # T output = ReverseSequence(T input, int32|int64 seq_lengths, @int seq_dim, @int batch_dim)
+    # we cannot easily construct reverse_sequence equivalence in opset 9, so we will not support it
+    # here. Actually using loops to do that is kind of meaningless since there will be performance
+    # issue there for sure.
+
+    raise NotImplementedError("ReverseSequence is not supported to convert in OPSET 9,"
+                              " if possible please try using OPSET 8 instead.")
+
 
 def shape_op(ctx, node, name, args):
     # out_type output = Shape(T input, @int32|int64 out_type), out_type by default int32

@@ -1824,6 +1833,7 @@ def where_op(ctx, node, name, args):
     "IsNan": (direct_op, ["IsNaN"]),
     "ResizeBilinear": (upsample_op9, ["Upsample", "linear"]),
     "ResizeNearestNeighbor": (upsample_op9, ["Upsample", "nearest"]),
+    "ReverseSequence": (reverse_op9, []),
     "Sign": (sign_op9, []),
     "Sinh": (direct_op, []),
     "Where": (where_op, []),

@@ -2451,8 +2461,6 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
     # onnx requires topological sorting
     topological_sort(g, continue_on_error)
 
-    g.update_proto()
-
     if verbose:
         print("tensorflow ops: {}".format(op_cnt))
         print("tensorflow attr: {}".format(attr_cnt))
