Skip to content

Commit 2077779

Browse files
Added unit test
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 42076bc commit 2077779

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

tests/test_backend.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,22 @@ def func(x):
12131213
return tf.identity(x_, name=_TFOUTPUT)
12141214
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1})
12151215

1216+
def test_slice_from_shape_const_fold(self):
1217+
x_val = np.array([4, 3], dtype=np.int64)
1218+
x_shape = np.array([-1, 3], dtype=np.int64)
1219+
def func(x):
1220+
z = tf.zeros(x)
1221+
x = tf.reshape(z, tf.constant(x_shape))
1222+
s = tf.shape(x)
1223+
t1 = tf.constant([1], dtype=tf.int32)
1224+
t2 = tf.constant([2], dtype=tf.int32)
1225+
y = tf.strided_slice(s, t1, t2, shrink_axis_mask=1)
1226+
return tf.identity(y, name=_TFOUTPUT)
1227+
def graph_validator(g):
1228+
# After constant folding just an input and const output node remain
1229+
return len(g.get_nodes()) == 2
1230+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, graph_validator=graph_validator)
1231+
12161232
def test_slice(self):
12171233
x_val = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32)
12181234
def func(x):

tf2onnx/tf_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from onnx import helper, onnx_pb, numpy_helper
2222

23-
from tf2onnx.utils import make_sure, is_tf_const_op, port_name
23+
from tf2onnx.utils import make_sure, is_tf_const_op, port_name, map_onnx_to_numpy_type
2424
from . import logging
2525

2626
logger = logging.getLogger(__name__)
@@ -166,10 +166,11 @@ def get_index_from_strided_slice_of_shape(node, outputs_to_values):
166166
return None
167167
return i1
168168

169-
def compute_const_folding_using_tf(g, const_node_values):
169+
def compute_const_folding_using_tf(g, const_node_values, graph_outputs):
170170
"""Find nodes with constant inputs and compute their values using TF"""
171171
if const_node_values is None:
172172
const_node_values = {}
173+
graph_outputs = set(graph_outputs)
173174
from tf2onnx.tf_loader import tf_session, tf_placeholder # pylint: disable=import-outside-toplevel
174175

175176
ops = g.get_operations()
@@ -208,15 +209,16 @@ def compute_const_folding_using_tf(g, const_node_values):
208209
shape = shape_node_outputs[input_names[0]]
209210
i = get_index_from_strided_slice_of_shape(node, outputs_to_values)
210211
if i is not None and 0 <= i < len(shape) and shape[i] is not None:
211-
outputs_to_values[output_names[0]] = np.array(shape[i])
212+
np_dtype = map_onnx_to_numpy_type(map_tf_dtype(node.outputs[0].dtype))
213+
outputs_to_values[output_names[0]] = np.array(shape[i], dtype=np_dtype)
212214
outputs_to_dtypes[node.outputs[0].name] = node.outputs[0].dtype
213215
progress = True
214216
can_fold = node.type not in ['Enter']
215217
can_fold = can_fold and len(input_names) > 0 and all(inp in outputs_to_values for inp in input_names)
216218
# We can only fold nodes with a single output
217219
can_fold = can_fold and len(output_names) == 1 and output_names[0] not in outputs_to_values
218220
# Skip if value already computed, used, and discarded
219-
can_fold = can_fold and output_names[0] not in unneeded_outputs
221+
can_fold = can_fold and output_names[0] not in unneeded_outputs and output_names[0] not in graph_outputs
220222
if can_fold:
221223
# Make a mini graph containing just the node to fold
222224
g2 = tf.Graph()

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
409409
if target is None:
410410
target = constants.DEFAULT_TARGET
411411

412-
outputs_to_values, outputs_to_dtypes = compute_const_folding_using_tf(tf_graph, const_node_values)
412+
outputs_to_values, outputs_to_dtypes = compute_const_folding_using_tf(tf_graph, const_node_values, output_names)
413413

414414
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = \
415415
tensorflow_to_onnx(tf_graph, shape_override, const_node_values)

0 commit comments

Comments
 (0)