Skip to content

Commit 42076bc

Browse files
Added constant folding for Slice nodes following a shape node of known shape
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent de21e8a commit 42076bc

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

tf2onnx/tf_utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,32 @@ def compress_graph_def(graph_def):
140140
tensor.tensor_content = b''
141141
return const_node_values
142142

143+
def get_index_from_strided_slice_of_shape(node, outputs_to_values):
144+
"""Returns the index of the dimension that the strided slice is reading from the shape node or None"""
145+
attr_vals = {
146+
'shrink_axis_mask': 1,
147+
'ellipsis_mask': 0,
148+
'begin_mask': 0,
149+
'new_axis_mask': 0,
150+
'end_mask': 0
151+
}
152+
for a in node.node_def.attr:
153+
if a in attr_vals:
154+
i = get_tf_node_attr(node, a)
155+
if i != attr_vals[a]:
156+
return None
157+
i1 = outputs_to_values.get(node.inputs[1].name)
158+
i2 = outputs_to_values.get(node.inputs[2].name)
159+
i3 = outputs_to_values.get(node.inputs[3].name)
160+
if i1 is None or i2 is None or i3 is None:
161+
return None
162+
if i1.shape != (1,) or i2.shape != (1,) or i3.shape != (1,):
163+
return None
164+
i1, i2, i3 = i1[0], i2[0], i3[0]
165+
if i1 + 1 != i2 or i3 != 1:
166+
return None
167+
return i1
168+
143169
def compute_const_folding_using_tf(g, const_node_values):
144170
"""Find nodes with constant inputs and compute their values using TF"""
145171
if const_node_values is None:
@@ -149,6 +175,8 @@ def compute_const_folding_using_tf(g, const_node_values):
149175
ops = g.get_operations()
150176
outputs_to_values = {}
151177
outputs_to_dtypes = {}
178+
outputs_to_shapes = {}
179+
shape_node_outputs = {}
152180

153181
for node in ops:
154182
# Load values of constants. Use const_node_values if possible
@@ -158,6 +186,14 @@ def compute_const_folding_using_tf(g, const_node_values):
158186
tensor.tensor_content = const_node_values[node.name]
159187
outputs_to_values[node.outputs[0].name] = get_tf_tensor_data(tensor)
160188
outputs_to_dtypes[node.outputs[0].name] = node.outputs[0].dtype
189+
for out in node.outputs:
190+
outputs_to_shapes[out.name] = get_tf_tensor_shape(out)
191+
192+
for node in ops:
193+
if node.type == "Shape":
194+
shape = outputs_to_shapes.get(node.inputs[0].name)
195+
if shape is not None:
196+
shape_node_outputs[node.outputs[0].name] = shape
161197

162198
unneeded_outputs = set()
163199
progress = True
@@ -167,6 +203,14 @@ def compute_const_folding_using_tf(g, const_node_values):
167203
# Find ops with constant inputs and compute their values
168204
input_names = [i.name for i in node.inputs]
169205
output_names = [i.name for i in node.outputs]
206+
if node.type == 'StridedSlice' and input_names[0] in shape_node_outputs \
207+
and output_names[0] not in outputs_to_values:
208+
shape = shape_node_outputs[input_names[0]]
209+
i = get_index_from_strided_slice_of_shape(node, outputs_to_values)
210+
if i is not None and 0 <= i < len(shape) and shape[i] is not None:
211+
outputs_to_values[output_names[0]] = np.array(shape[i])
212+
outputs_to_dtypes[node.outputs[0].name] = node.outputs[0].dtype
213+
progress = True
170214
can_fold = node.type not in ['Enter']
171215
can_fold = can_fold and len(input_names) > 0 and all(inp in outputs_to_values for inp in input_names)
172216
# We can only fold nodes with a single output

0 commit comments

Comments
 (0)