Skip to content

Commit b54e43a

Browse files
Fix issue with tf1 loops with tensor array read last pattern (#1577)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 8a61c99 commit b54e43a

File tree

5 files changed

+64
-23
lines changed

5 files changed

+64
-23
lines changed

ci_build/azure_pipelines/keras2onnx_unit_test.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ jobs:
4040
INSTALL_ORT: pip install onnxruntime==1.8.0
4141

4242
############ Pure Keras Unit Tests ############
43-
# Keras-Py36-tf1.15.0: # Failing, will enable soon.
44-
# python.version: '3.6'
45-
# ONNX_PATH: onnx==1.5.0
46-
# KERAS: keras==2.2.5
47-
# TENSORFLOW_PATH: tensorflow==1.15.0
48-
# INSTALL_ORT: pip install onnxruntime==1.8.0
43+
Keras-Py36-tf1.15.0:
44+
python.version: '3.6'
45+
ONNX_PATH: onnx==1.5.0
46+
KERAS: keras==2.2.5
47+
TENSORFLOW_PATH: tensorflow==1.15.0
48+
INSTALL_ORT: pip install onnxruntime==1.8.0
4949

5050
Keras-Py37-tf1.15.0:
5151
python.version: '3.7'

tf2onnx/rewriter/loop_rewriter.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def create_context(self):
3030

3131
def run(self):
3232
logger.debug("enter loop rewriter")
33-
return self.run_internal()
33+
return self.run_internal(allow_ta_read_last=True)
3434

3535
def need_rewrite(self, context):
3636
return True
@@ -93,6 +93,10 @@ def rewrite(self, context):
9393
logger.error("failed to create loop node during rewrite")
9494
return REWRITER_RESULT.FAIL
9595

96+
for unneeded_scan_variable in loop_props.unneeded_scan_variables.values():
97+
self.g.replace_all_inputs(unneeded_scan_variable.exit_output.id,
98+
unneeded_scan_variable.equivalent_state_variable.exit_output.id)
99+
96100
logger.debug("rewrite successfully")
97101
return REWRITER_RESULT.OK
98102

@@ -152,7 +156,9 @@ def _create_loop_node(self, context, loop_props, init_cond_output, branches=None
152156
n = self.g.get_node_by_output(tensor_value_info.id)
153157
self.g.remove_node(n.name)
154158
else:
155-
loop_outputs.append(utils.make_name("unused_loop_output_"))
159+
output_name = utils.make_name("unused_loop_output_")
160+
tensor_value_info.id = output_name
161+
loop_outputs.append(output_name)
156162
loop_output_shapes.append([-1])
157163
loop_output_dtypes.append(None)
158164

tf2onnx/rewriter/loop_rewriter_base.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from tf2onnx import utils
1111
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
1212
from tf2onnx.utils import is_tf_loopcond_op, is_tf_tensor_array_op
13-
from tf2onnx.utils import is_tf_tensor_array_gather_op, is_tf_tensor_array_write_op
13+
from tf2onnx.utils import is_tf_tensor_array_gather_op, is_tf_tensor_array_write_op, is_tf_tensor_array_read_op
1414
from tf2onnx.rewriter.rnn_utils import REWRITER_RESULT
1515
from tf2onnx.utils import TensorValueInfo
1616

@@ -47,6 +47,7 @@ def __init__(self):
4747
# used as initial input for more than one Enter nodes.
4848
self.state_variables = OrderedDict()
4949
self.scan_variables = OrderedDict()
50+
self.unneeded_scan_variables = OrderedDict()
5051

5152
self.tensor_array_inputs = [] # list of type InputTensorArray
5253

@@ -55,10 +56,14 @@ def add_variable(self, var):
5556
"variable %s already exists as scan variable.", var.enter_name)
5657
utils.make_sure(var.enter_name not in self.state_variables,
5758
"variable %s already exists as state variable.", var.enter_name)
58-
if not var.is_tensor_array:
59-
self.state_variables[var.enter_name] = var
60-
else:
59+
if var.tensor_array_type == TensorArrayVariableType.READ_LAST:
60+
# If the variable just returns the last value of the constructed tensor array, it doesn't need to be
61+
# a scan output
62+
self.unneeded_scan_variables[var.enter_name] = var
63+
elif var.tensor_array_type == TensorArrayVariableType.GATHER_ALL:
6164
self.scan_variables[var.enter_name] = var
65+
else:
66+
self.state_variables[var.enter_name] = var
6267

6368
def get_variables(self, checker):
6469
if not checker:
@@ -69,6 +74,7 @@ def get_variables(self, checker):
6974
def all_variables(self):
7075
items = self.state_variables.copy()
7176
items.update(self.scan_variables)
77+
items.update(self.unneeded_scan_variables)
7278
return items
7379

7480
# state inputs and outputs are in pairs, even though some outputs are not depending on corresponding input,
@@ -111,6 +117,16 @@ def scan_inputs(self):
111117
def scan_inputs_initial_values(self):
112118
return [i.data_input_id for i in self.tensor_array_inputs]
113119

120+
def has_variable_with_ta_type(self, tensor_array_type):
121+
for variable in self.all_variables.values():
122+
if variable.tensor_array_type == tensor_array_type:
123+
return True
124+
return False
125+
126+
class TensorArrayVariableType:
127+
GATHER_ALL = "GATHER_ALL"
128+
READ_LAST = "READ_LAST"
129+
114130
class LoopVariable(object):
115131
"""In TensorFlow loop, all loop variables are listed both in iteration body graph's inputs, and outputs.
116132
Loop (state variable 1, state variable 2) {
@@ -131,7 +147,7 @@ class LoopVariable(object):
131147
(e.g. switch_true_identity_output.id).
132148
"""
133149
def __init__(self, enter_name, enter_input_id, next_iteration_input_id,
134-
switch_true_identity_output_id, exit_output_id, is_tensor_array, ta_index_id, g):
150+
switch_true_identity_output_id, exit_output_id, tensor_array_type, ta_index_id, g):
135151
self.enter_name = enter_name
136152
self.enter_input_id = enter_input_id
137153

@@ -150,7 +166,7 @@ def __init__(self, enter_name, enter_input_id, next_iteration_input_id,
150166
self.exit_output = TensorValueInfo(exit_output_id, g)
151167

152168
# only applicable for tensor array variable
153-
self.is_tensor_array = is_tensor_array
169+
self.tensor_array_type = tensor_array_type
154170
# todo: need check ta's index variable is a scalar starting from 1, and increase by 1 each iteration.
155171
# then we can be sure this is equivalent to scan output behavior.
156172
self.ta_index_id = ta_index_id
@@ -189,7 +205,7 @@ def need_rewrite(self, context):
189205
def rewrite(self, context):
190206
return REWRITER_RESULT.FAIL
191207

192-
def run_internal(self):
208+
def run_internal(self, allow_ta_read_last=False):
193209
loopcond_ops = []
194210
for op in self.g.get_nodes():
195211
if is_tf_loopcond_op(op):
@@ -201,7 +217,11 @@ def run_internal(self):
201217
context = self.create_context()
202218
context.loop_cond = op
203219

204-
self._check_in_read_only_mode(context)
220+
self._check_in_read_only_mode(context) # parses loop variables
221+
222+
loop_properties = context.loop_properties
223+
if not allow_ta_read_last and loop_properties.has_variable_with_ta_type(TensorArrayVariableType.READ_LAST):
224+
continue
205225

206226
if self.need_rewrite(context):
207227
# cut off connection between cell/cond graphs and useless nodes like Merge, NextIteration.
@@ -241,6 +261,12 @@ def _parse_loop_variables(self, context):
241261
loop_var = self._get_loop_var_from_switch(s)
242262
context.loop_properties.add_variable(loop_var)
243263

264+
for unneeded_scan_variable in context.loop_properties.unneeded_scan_variables.values():
265+
for state_variable in context.loop_properties.state_variables.values():
266+
if unneeded_scan_variable.next_iteration_input.id == state_variable.next_iteration_input.id:
267+
unneeded_scan_variable.equivalent_state_variable = state_variable
268+
break
269+
244270
def _parse_input_ta(self, context):
245271
graph_inputs = [v.switch_true_identity_output.id for v in context.loop_properties.all_variables.values()
246272
if v.switch_true_identity_output.id]
@@ -313,7 +339,7 @@ def _cut_off_connection_for_cell(self, context):
313339
n = self.g.get_node_by_output(val.switch_true_identity_output.id)
314340
self.g.remove_node(n.name)
315341

316-
if val.is_tensor_array:
342+
if val.tensor_array_type == TensorArrayVariableType.GATHER_ALL:
317343
# connect NextIteration to an invalid node, to cut off an ending node of the cell.
318344
ta_write_nodes = [n for n in self.g.get_nodes() if is_tf_tensor_array_write_op(n)]
319345
self.g.replace_all_inputs(val.next_iteration_input.id, INVALID_INPUT_ID, ops=ta_write_nodes)
@@ -382,10 +408,9 @@ def _get_loop_var_from_switch(self, switch_node):
382408
else:
383409
raise ValueError("unexpected number of switch false consumers")
384410

385-
is_ta = False
411+
ta_type = None
386412
ta_index_id = None
387413
if is_tf_tensor_array_op(self.g.get_node_by_output(target_node_input_id)):
388-
is_ta = True
389414

390415
ta_write_node = self.g.get_node_by_output(last_iteration_output_id)
391416
utils.make_sure(is_tf_tensor_array_write_op(ta_write_node), "ta nextiteration is not following ta write op")
@@ -396,13 +421,19 @@ def _get_loop_var_from_switch(self, switch_node):
396421
# ta.write(), then ta.stack(), because this is the most frequent usage pattern.
397422
if exit_output_id:
398423
exit_consumers = self.g.find_output_consumers(exit_output_id)
399-
ta_gather_node = [n for n in exit_consumers if is_tf_tensor_array_gather_op(n)][0]
424+
ta_access_node = [n for n in exit_consumers if is_tf_tensor_array_gather_op(n) or \
425+
is_tf_tensor_array_read_op(n)][0]
426+
427+
if is_tf_tensor_array_read_op(ta_access_node):
428+
ta_type = TensorArrayVariableType.READ_LAST
429+
else:
430+
ta_type = TensorArrayVariableType.GATHER_ALL
400431

401432
# update exit output id, treat the gather output as ta's output
402-
exit_output_id = ta_gather_node.output[0]
433+
exit_output_id = ta_access_node.output[0]
403434

404435
loop_var = LoopVariable(enter_node.name, target_node_input_id, last_iteration_output_id,
405-
switch_true_identity_output, exit_output_id, is_ta, ta_index_id, self.g)
436+
switch_true_identity_output, exit_output_id, ta_type, ta_index_id, self.g)
406437

407438
return loop_var
408439

tf2onnx/rewriter/rnn_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ def parse_rnn_loop(graph, loop_properties, rnn_scope, while_context_scope):
309309
1. iteration counter does not exist in tf1.4 or earlier versions
310310
2. if dynamic_rnn's first input is not consumed, output ta does not exist.
311311
"""
312+
from tf2onnx.rewriter.loop_rewriter_base import TensorArrayVariableType # pylint: disable=import-outside-toplevel
312313
time_name = rnn_scope + "time"
313314
ta_array_name_prefix = rnn_scope + "dynamic_rnn/output_"
314315
iteration_counter_name = while_context_scope + "iteration_counter"
@@ -319,7 +320,7 @@ def parse_rnn_loop(graph, loop_properties, rnn_scope, while_context_scope):
319320
iteration_var = None
320321
for val in loop_properties.all_variables.values():
321322
enter_input_node = graph.get_node_by_output(val.enter_input_id)
322-
if val.is_tensor_array:
323+
if val.tensor_array_type == TensorArrayVariableType.GATHER_ALL:
323324
ta_name = enter_input_node.get_attr("tensor_array_name").s.decode("utf-8")
324325
if not ta_name.startswith(ta_array_name_prefix):
325326
is_rnn_out_ta = False

tf2onnx/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,9 @@ def is_tf_tensor_array_gather_op(op):
501501
def is_tf_tensor_array_write_op(op):
502502
return op.type in ("TensorArrayWriteV2", "TensorArrayWriteV3")
503503

504+
def is_tf_tensor_array_read_op(op):
505+
return op.type in ("TensorArrayReadV2", "TensorArrayReadV3")
506+
504507

505508
def is_tf_tensor_array_op(op):
506509
return op.type in ("TensorArrayV2", "TensorArrayV3")

0 commit comments

Comments
 (0)