Skip to content

Commit 0e2ea55

Browse files
Merge pull request #1145 from onnx/tom/foldSliceFromConstShape
Added constant folding for Slice nodes following a shape node of know…
2 parents 57ccfb3 + 2077779 commit 0e2ea55

File tree

3 files changed

+66
-4
lines changed

3 files changed

+66
-4
lines changed

tests/test_backend.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,22 @@ def func(x):
12131213
return tf.identity(x_, name=_TFOUTPUT)
12141214
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1})
12151215

1216+
def test_slice_from_shape_const_fold(self):
1217+
x_val = np.array([4, 3], dtype=np.int64)
1218+
x_shape = np.array([-1, 3], dtype=np.int64)
1219+
def func(x):
1220+
z = tf.zeros(x)
1221+
x = tf.reshape(z, tf.constant(x_shape))
1222+
s = tf.shape(x)
1223+
t1 = tf.constant([1], dtype=tf.int32)
1224+
t2 = tf.constant([2], dtype=tf.int32)
1225+
y = tf.strided_slice(s, t1, t2, shrink_axis_mask=1)
1226+
return tf.identity(y, name=_TFOUTPUT)
1227+
def graph_validator(g):
1228+
# After constant folding just an input and const output node remain
1229+
return len(g.get_nodes()) == 2
1230+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, graph_validator=graph_validator)
1231+
12161232
def test_slice(self):
12171233
x_val = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32)
12181234
def func(x):

tf2onnx/tf_utils.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from onnx import helper, onnx_pb, numpy_helper
2222

23-
from tf2onnx.utils import make_sure, is_tf_const_op, port_name
23+
from tf2onnx.utils import make_sure, is_tf_const_op, port_name, map_onnx_to_numpy_type
2424
from . import logging
2525

2626
logger = logging.getLogger(__name__)
@@ -140,15 +140,44 @@ def compress_graph_def(graph_def):
140140
tensor.tensor_content = b''
141141
return const_node_values
142142

143-
def compute_const_folding_using_tf(g, const_node_values):
143+
def get_index_from_strided_slice_of_shape(node, outputs_to_values):
144+
"""Returns the index of the dimension that the strided slice is reading from the shape node or None"""
145+
attr_vals = {
146+
'shrink_axis_mask': 1,
147+
'ellipsis_mask': 0,
148+
'begin_mask': 0,
149+
'new_axis_mask': 0,
150+
'end_mask': 0
151+
}
152+
for a in node.node_def.attr:
153+
if a in attr_vals:
154+
i = get_tf_node_attr(node, a)
155+
if i != attr_vals[a]:
156+
return None
157+
i1 = outputs_to_values.get(node.inputs[1].name)
158+
i2 = outputs_to_values.get(node.inputs[2].name)
159+
i3 = outputs_to_values.get(node.inputs[3].name)
160+
if i1 is None or i2 is None or i3 is None:
161+
return None
162+
if i1.shape != (1,) or i2.shape != (1,) or i3.shape != (1,):
163+
return None
164+
i1, i2, i3 = i1[0], i2[0], i3[0]
165+
if i1 + 1 != i2 or i3 != 1:
166+
return None
167+
return i1
168+
169+
def compute_const_folding_using_tf(g, const_node_values, graph_outputs):
144170
"""Find nodes with constant inputs and compute their values using TF"""
145171
if const_node_values is None:
146172
const_node_values = {}
173+
graph_outputs = set(graph_outputs)
147174
from tf2onnx.tf_loader import tf_session, tf_placeholder # pylint: disable=import-outside-toplevel
148175

149176
ops = g.get_operations()
150177
outputs_to_values = {}
151178
outputs_to_dtypes = {}
179+
outputs_to_shapes = {}
180+
shape_node_outputs = {}
152181

153182
for node in ops:
154183
# Load values of constants. Use const_node_values if possible
@@ -158,6 +187,14 @@ def compute_const_folding_using_tf(g, const_node_values):
158187
tensor.tensor_content = const_node_values[node.name]
159188
outputs_to_values[node.outputs[0].name] = get_tf_tensor_data(tensor)
160189
outputs_to_dtypes[node.outputs[0].name] = node.outputs[0].dtype
190+
for out in node.outputs:
191+
outputs_to_shapes[out.name] = get_tf_tensor_shape(out)
192+
193+
for node in ops:
194+
if node.type == "Shape":
195+
shape = outputs_to_shapes.get(node.inputs[0].name)
196+
if shape is not None:
197+
shape_node_outputs[node.outputs[0].name] = shape
161198

162199
unneeded_outputs = set()
163200
progress = True
@@ -167,12 +204,21 @@ def compute_const_folding_using_tf(g, const_node_values):
167204
# Find ops with constant inputs and compute their values
168205
input_names = [i.name for i in node.inputs]
169206
output_names = [i.name for i in node.outputs]
207+
if node.type == 'StridedSlice' and input_names[0] in shape_node_outputs \
208+
and output_names[0] not in outputs_to_values:
209+
shape = shape_node_outputs[input_names[0]]
210+
i = get_index_from_strided_slice_of_shape(node, outputs_to_values)
211+
if i is not None and 0 <= i < len(shape) and shape[i] is not None:
212+
np_dtype = map_onnx_to_numpy_type(map_tf_dtype(node.outputs[0].dtype))
213+
outputs_to_values[output_names[0]] = np.array(shape[i], dtype=np_dtype)
214+
outputs_to_dtypes[node.outputs[0].name] = node.outputs[0].dtype
215+
progress = True
170216
can_fold = node.type not in ['Enter']
171217
can_fold = can_fold and len(input_names) > 0 and all(inp in outputs_to_values for inp in input_names)
172218
# We can only fold nodes with a single output
173219
can_fold = can_fold and len(output_names) == 1 and output_names[0] not in outputs_to_values
174220
# Skip if value already computed, used, and discarded
175-
can_fold = can_fold and output_names[0] not in unneeded_outputs
221+
can_fold = can_fold and output_names[0] not in unneeded_outputs and output_names[0] not in graph_outputs
176222
if can_fold:
177223
# Make a mini graph containing just the node to fold
178224
g2 = tf.Graph()

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
409409
if target is None:
410410
target = constants.DEFAULT_TARGET
411411

412-
outputs_to_values, outputs_to_dtypes = compute_const_folding_using_tf(tf_graph, const_node_values)
412+
outputs_to_values, outputs_to_dtypes = compute_const_folding_using_tf(tf_graph, const_node_values, output_names)
413413

414414
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = \
415415
tensorflow_to_onnx(tf_graph, shape_override, const_node_values)

0 commit comments

Comments
 (0)