20
20
21
21
from onnx import helper , onnx_pb , numpy_helper
22
22
23
- from tf2onnx .utils import make_sure , is_tf_const_op , port_name
23
+ from tf2onnx .utils import make_sure , is_tf_const_op , port_name , map_onnx_to_numpy_type
24
24
from . import logging
25
25
26
26
logger = logging .getLogger (__name__ )
@@ -140,15 +140,44 @@ def compress_graph_def(graph_def):
140
140
tensor .tensor_content = b''
141
141
return const_node_values
142
142
143
- def compute_const_folding_using_tf (g , const_node_values ):
143
+ def get_index_from_strided_slice_of_shape (node , outputs_to_values ):
144
+ """Returns the index of the dimension that the strided slice is reading from the shape node or None"""
145
+ attr_vals = {
146
+ 'shrink_axis_mask' : 1 ,
147
+ 'ellipsis_mask' : 0 ,
148
+ 'begin_mask' : 0 ,
149
+ 'new_axis_mask' : 0 ,
150
+ 'end_mask' : 0
151
+ }
152
+ for a in node .node_def .attr :
153
+ if a in attr_vals :
154
+ i = get_tf_node_attr (node , a )
155
+ if i != attr_vals [a ]:
156
+ return None
157
+ i1 = outputs_to_values .get (node .inputs [1 ].name )
158
+ i2 = outputs_to_values .get (node .inputs [2 ].name )
159
+ i3 = outputs_to_values .get (node .inputs [3 ].name )
160
+ if i1 is None or i2 is None or i3 is None :
161
+ return None
162
+ if i1 .shape != (1 ,) or i2 .shape != (1 ,) or i3 .shape != (1 ,):
163
+ return None
164
+ i1 , i2 , i3 = i1 [0 ], i2 [0 ], i3 [0 ]
165
+ if i1 + 1 != i2 or i3 != 1 :
166
+ return None
167
+ return i1
168
+
169
+ def compute_const_folding_using_tf (g , const_node_values , graph_outputs ):
144
170
"""Find nodes with constant inputs and compute their values using TF"""
145
171
if const_node_values is None :
146
172
const_node_values = {}
173
+ graph_outputs = set (graph_outputs )
147
174
from tf2onnx .tf_loader import tf_session , tf_placeholder # pylint: disable=import-outside-toplevel
148
175
149
176
ops = g .get_operations ()
150
177
outputs_to_values = {}
151
178
outputs_to_dtypes = {}
179
+ outputs_to_shapes = {}
180
+ shape_node_outputs = {}
152
181
153
182
for node in ops :
154
183
# Load values of constants. Use const_node_values if possible
@@ -158,6 +187,14 @@ def compute_const_folding_using_tf(g, const_node_values):
158
187
tensor .tensor_content = const_node_values [node .name ]
159
188
outputs_to_values [node .outputs [0 ].name ] = get_tf_tensor_data (tensor )
160
189
outputs_to_dtypes [node .outputs [0 ].name ] = node .outputs [0 ].dtype
190
+ for out in node .outputs :
191
+ outputs_to_shapes [out .name ] = get_tf_tensor_shape (out )
192
+
193
+ for node in ops :
194
+ if node .type == "Shape" :
195
+ shape = outputs_to_shapes .get (node .inputs [0 ].name )
196
+ if shape is not None :
197
+ shape_node_outputs [node .outputs [0 ].name ] = shape
161
198
162
199
unneeded_outputs = set ()
163
200
progress = True
@@ -167,12 +204,21 @@ def compute_const_folding_using_tf(g, const_node_values):
167
204
# Find ops with constant inputs and compute their values
168
205
input_names = [i .name for i in node .inputs ]
169
206
output_names = [i .name for i in node .outputs ]
207
+ if node .type == 'StridedSlice' and input_names [0 ] in shape_node_outputs \
208
+ and output_names [0 ] not in outputs_to_values :
209
+ shape = shape_node_outputs [input_names [0 ]]
210
+ i = get_index_from_strided_slice_of_shape (node , outputs_to_values )
211
+ if i is not None and 0 <= i < len (shape ) and shape [i ] is not None :
212
+ np_dtype = map_onnx_to_numpy_type (map_tf_dtype (node .outputs [0 ].dtype ))
213
+ outputs_to_values [output_names [0 ]] = np .array (shape [i ], dtype = np_dtype )
214
+ outputs_to_dtypes [node .outputs [0 ].name ] = node .outputs [0 ].dtype
215
+ progress = True
170
216
can_fold = node .type not in ['Enter' ]
171
217
can_fold = can_fold and len (input_names ) > 0 and all (inp in outputs_to_values for inp in input_names )
172
218
# We can only fold nodes with a single output
173
219
can_fold = can_fold and len (output_names ) == 1 and output_names [0 ] not in outputs_to_values
174
220
# Skip if value already computed, used, and discarded
175
- can_fold = can_fold and output_names [0 ] not in unneeded_outputs
221
+ can_fold = can_fold and output_names [0 ] not in unneeded_outputs and output_names [ 0 ] not in graph_outputs
176
222
if can_fold :
177
223
# Make a mini graph containing just the node to fold
178
224
g2 = tf .Graph ()
0 commit comments