@@ -140,6 +140,32 @@ def compress_graph_def(graph_def):
140
140
tensor .tensor_content = b''
141
141
return const_node_values
142
142
143
+ def get_index_from_strided_slice_of_shape (node , outputs_to_values ):
144
+ """Returns the index of the dimension that the strided slice is reading from the shape node or None"""
145
+ attr_vals = {
146
+ 'shrink_axis_mask' : 1 ,
147
+ 'ellipsis_mask' : 0 ,
148
+ 'begin_mask' : 0 ,
149
+ 'new_axis_mask' : 0 ,
150
+ 'end_mask' : 0
151
+ }
152
+ for a in node .node_def .attr :
153
+ if a in attr_vals :
154
+ i = get_tf_node_attr (node , a )
155
+ if i != attr_vals [a ]:
156
+ return None
157
+ i1 = outputs_to_values .get (node .inputs [1 ].name )
158
+ i2 = outputs_to_values .get (node .inputs [2 ].name )
159
+ i3 = outputs_to_values .get (node .inputs [3 ].name )
160
+ if i1 is None or i2 is None or i3 is None :
161
+ return None
162
+ if i1 .shape != (1 ,) or i2 .shape != (1 ,) or i3 .shape != (1 ,):
163
+ return None
164
+ i1 , i2 , i3 = i1 [0 ], i2 [0 ], i3 [0 ]
165
+ if i1 + 1 != i2 or i3 != 1 :
166
+ return None
167
+ return i1
168
+
143
169
def compute_const_folding_using_tf (g , const_node_values ):
144
170
"""Find nodes with constant inputs and compute their values using TF"""
145
171
if const_node_values is None :
@@ -149,6 +175,8 @@ def compute_const_folding_using_tf(g, const_node_values):
149
175
ops = g .get_operations ()
150
176
outputs_to_values = {}
151
177
outputs_to_dtypes = {}
178
+ outputs_to_shapes = {}
179
+ shape_node_outputs = {}
152
180
153
181
for node in ops :
154
182
# Load values of constants. Use const_node_values if possible
@@ -158,6 +186,14 @@ def compute_const_folding_using_tf(g, const_node_values):
158
186
tensor .tensor_content = const_node_values [node .name ]
159
187
outputs_to_values [node .outputs [0 ].name ] = get_tf_tensor_data (tensor )
160
188
outputs_to_dtypes [node .outputs [0 ].name ] = node .outputs [0 ].dtype
189
+ for out in node .outputs :
190
+ outputs_to_shapes [out .name ] = get_tf_tensor_shape (out )
191
+
192
+ for node in ops :
193
+ if node .type == "Shape" :
194
+ shape = outputs_to_shapes .get (node .inputs [0 ].name )
195
+ if shape is not None :
196
+ shape_node_outputs [node .outputs [0 ].name ] = shape
161
197
162
198
unneeded_outputs = set ()
163
199
progress = True
@@ -167,6 +203,14 @@ def compute_const_folding_using_tf(g, const_node_values):
167
203
# Find ops with constant inputs and compute their values
168
204
input_names = [i .name for i in node .inputs ]
169
205
output_names = [i .name for i in node .outputs ]
206
+ if node .type == 'StridedSlice' and input_names [0 ] in shape_node_outputs \
207
+ and output_names [0 ] not in outputs_to_values :
208
+ shape = shape_node_outputs [input_names [0 ]]
209
+ i = get_index_from_strided_slice_of_shape (node , outputs_to_values )
210
+ if i is not None and 0 <= i < len (shape ) and shape [i ] is not None :
211
+ outputs_to_values [output_names [0 ]] = np .array (shape [i ])
212
+ outputs_to_dtypes [node .outputs [0 ].name ] = node .outputs [0 ].dtype
213
+ progress = True
170
214
can_fold = node .type not in ['Enter' ]
171
215
can_fold = can_fold and len (input_names ) > 0 and all (inp in outputs_to_values for inp in input_names )
172
216
# We can only fold nodes with a single output
0 commit comments