10
10
from tf2onnx import utils
11
11
from tf2onnx .graph_matcher import OpTypePattern , GraphMatcher
12
12
from tf2onnx .utils import is_tf_loopcond_op , is_tf_tensor_array_op
13
- from tf2onnx .utils import is_tf_tensor_array_gather_op , is_tf_tensor_array_write_op
13
+ from tf2onnx .utils import is_tf_tensor_array_gather_op , is_tf_tensor_array_write_op , is_tf_tensor_array_read_op
14
14
from tf2onnx .rewriter .rnn_utils import REWRITER_RESULT
15
15
from tf2onnx .utils import TensorValueInfo
16
16
@@ -47,6 +47,7 @@ def __init__(self):
47
47
# used as initial input for more than one Enter nodes.
48
48
self .state_variables = OrderedDict ()
49
49
self .scan_variables = OrderedDict ()
50
+ self .unneeded_scan_variables = OrderedDict ()
50
51
51
52
self .tensor_array_inputs = [] # list of type InputTensorArray
52
53
@@ -55,10 +56,14 @@ def add_variable(self, var):
55
56
"variable %s already exists as scan variable." , var .enter_name )
56
57
utils .make_sure (var .enter_name not in self .state_variables ,
57
58
"variable %s already exists as state variable." , var .enter_name )
58
- if not var .is_tensor_array :
59
- self .state_variables [var .enter_name ] = var
60
- else :
59
+ if var .tensor_array_type == TensorArrayVariableType .READ_LAST :
60
+ # If the variable just returns the last value of the constructed tensor array, it doesn't need to be
61
+ # a scan output
62
+ self .unneeded_scan_variables [var .enter_name ] = var
63
+ elif var .tensor_array_type == TensorArrayVariableType .GATHER_ALL :
61
64
self .scan_variables [var .enter_name ] = var
65
+ else :
66
+ self .state_variables [var .enter_name ] = var
62
67
63
68
def get_variables (self , checker ):
64
69
if not checker :
@@ -69,6 +74,7 @@ def get_variables(self, checker):
69
74
def all_variables(self):
    """Return one dict of every tracked loop variable, keyed by enter name.

    Merges the state, scan, and unneeded-scan variable dicts into a fresh
    copy so callers can iterate all variables without mutating the
    underlying per-kind collections.
    """
    merged = self.state_variables.copy()
    for extra in (self.scan_variables, self.unneeded_scan_variables):
        merged.update(extra)
    return merged
73
79
74
80
# state inputs and outputs are in pairs, even though some outputs are not depending on corresponding input,
@@ -111,6 +117,16 @@ def scan_inputs(self):
111
117
def scan_inputs_initial_values(self):
    """Return the data-input tensor ids feeding each tensor-array scan input, in order."""
    initial_values = []
    for ta_input in self.tensor_array_inputs:
        initial_values.append(ta_input.data_input_id)
    return initial_values
113
119
120
def has_variable_with_ta_type(self, tensor_array_type):
    """Return True if any tracked loop variable has the given tensor-array type.

    tensor_array_type is compared against each variable's ``tensor_array_type``
    attribute (e.g. TensorArrayVariableType.GATHER_ALL / READ_LAST, or None
    for non-tensor-array variables).
    """
    return any(variable.tensor_array_type == tensor_array_type
               for variable in self.all_variables.values())
125
+
126
class TensorArrayVariableType:
    """How a loop's tensor-array variable is consumed after the loop exits.

    GATHER_ALL: every written element is gathered (maps to a scan output).
    READ_LAST:  only the final element is read back (no scan output needed).
    """
    GATHER_ALL = "GATHER_ALL"
    READ_LAST = "READ_LAST"
129
+
114
130
class LoopVariable (object ):
115
131
"""In TensorFlow loop, all loop variables are listed both in iteration body graph's inputs, and outputs.
116
132
Loop (state variable 1, state variable 2) {
@@ -131,7 +147,7 @@ class LoopVariable(object):
131
147
(e.g. switch_true_identity_output.id).
132
148
"""
133
149
def __init__ (self , enter_name , enter_input_id , next_iteration_input_id ,
134
- switch_true_identity_output_id , exit_output_id , is_tensor_array , ta_index_id , g ):
150
+ switch_true_identity_output_id , exit_output_id , tensor_array_type , ta_index_id , g ):
135
151
self .enter_name = enter_name
136
152
self .enter_input_id = enter_input_id
137
153
@@ -150,7 +166,7 @@ def __init__(self, enter_name, enter_input_id, next_iteration_input_id,
150
166
self .exit_output = TensorValueInfo (exit_output_id , g )
151
167
152
168
# only applicable for tensor array variable
153
- self .is_tensor_array = is_tensor_array
169
+ self .tensor_array_type = tensor_array_type
154
170
# todo: need check ta's index variable is a scalar starting from 1, and increase by 1 each iteration.
155
171
# then we can be sure this is equivalent to scan output behavior.
156
172
self .ta_index_id = ta_index_id
@@ -189,7 +205,7 @@ def need_rewrite(self, context):
189
205
def rewrite(self, context):
    """Base implementation: always report failure; subclasses override with real rewriting."""
    result = REWRITER_RESULT.FAIL
    return result
191
207
192
- def run_internal (self ):
208
+ def run_internal (self , allow_ta_read_last = False ):
193
209
loopcond_ops = []
194
210
for op in self .g .get_nodes ():
195
211
if is_tf_loopcond_op (op ):
@@ -201,7 +217,11 @@ def run_internal(self):
201
217
context = self .create_context ()
202
218
context .loop_cond = op
203
219
204
- self ._check_in_read_only_mode (context )
220
+ self ._check_in_read_only_mode (context ) # parses loop variables
221
+
222
+ loop_properties = context .loop_properties
223
+ if not allow_ta_read_last and loop_properties .has_variable_with_ta_type (TensorArrayVariableType .READ_LAST ):
224
+ continue
205
225
206
226
if self .need_rewrite (context ):
207
227
# cut off connection between cell/cond graphs and useless nodes like Merge, NextIteration.
@@ -241,6 +261,12 @@ def _parse_loop_variables(self, context):
241
261
loop_var = self ._get_loop_var_from_switch (s )
242
262
context .loop_properties .add_variable (loop_var )
243
263
264
+ for unneeded_scan_variable in context .loop_properties .unneeded_scan_variables .values ():
265
+ for state_variable in context .loop_properties .state_variables .values ():
266
+ if unneeded_scan_variable .next_iteration_input .id == state_variable .next_iteration_input .id :
267
+ unneeded_scan_variable .equivalent_state_variable = state_variable
268
+ break
269
+
244
270
def _parse_input_ta (self , context ):
245
271
graph_inputs = [v .switch_true_identity_output .id for v in context .loop_properties .all_variables .values ()
246
272
if v .switch_true_identity_output .id ]
@@ -313,7 +339,7 @@ def _cut_off_connection_for_cell(self, context):
313
339
n = self .g .get_node_by_output (val .switch_true_identity_output .id )
314
340
self .g .remove_node (n .name )
315
341
316
- if val .is_tensor_array :
342
+ if val .tensor_array_type == TensorArrayVariableType . GATHER_ALL :
317
343
# connect NextIteration to an invalid node, to cut off an ending node of the cell.
318
344
ta_write_nodes = [n for n in self .g .get_nodes () if is_tf_tensor_array_write_op (n )]
319
345
self .g .replace_all_inputs (val .next_iteration_input .id , INVALID_INPUT_ID , ops = ta_write_nodes )
@@ -382,10 +408,9 @@ def _get_loop_var_from_switch(self, switch_node):
382
408
else :
383
409
raise ValueError ("unexpected number of switch false consumers" )
384
410
385
- is_ta = False
411
+ ta_type = None
386
412
ta_index_id = None
387
413
if is_tf_tensor_array_op (self .g .get_node_by_output (target_node_input_id )):
388
- is_ta = True
389
414
390
415
ta_write_node = self .g .get_node_by_output (last_iteration_output_id )
391
416
utils .make_sure (is_tf_tensor_array_write_op (ta_write_node ), "ta nextiteration is not following ta write op" )
@@ -396,13 +421,19 @@ def _get_loop_var_from_switch(self, switch_node):
396
421
# ta.write(), then ta.stack(), because this is the most frequent usage pattern.
397
422
if exit_output_id :
398
423
exit_consumers = self .g .find_output_consumers (exit_output_id )
399
- ta_gather_node = [n for n in exit_consumers if is_tf_tensor_array_gather_op (n )][0 ]
424
+ ta_access_node = [n for n in exit_consumers if is_tf_tensor_array_gather_op (n ) or \
425
+ is_tf_tensor_array_read_op (n )][0 ]
426
+
427
+ if is_tf_tensor_array_read_op (ta_access_node ):
428
+ ta_type = TensorArrayVariableType .READ_LAST
429
+ else :
430
+ ta_type = TensorArrayVariableType .GATHER_ALL
400
431
401
432
# update exit output id, treat the gather output as ta's output
402
- exit_output_id = ta_gather_node .output [0 ]
433
+ exit_output_id = ta_access_node .output [0 ]
403
434
404
435
loop_var = LoopVariable (enter_node .name , target_node_input_id , last_iteration_output_id ,
405
- switch_true_identity_output , exit_output_id , is_ta , ta_index_id , self .g )
436
+ switch_true_identity_output , exit_output_id , ta_type , ta_index_id , self .g )
406
437
407
438
return loop_var
408
439
0 commit comments