|
20 | 20 |
|
21 | 21 | from onnx import helper, onnx_pb, numpy_helper
|
22 | 22 |
|
23 |
| -from tf2onnx.utils import make_sure, is_tf_const_op, port_name |
| 23 | +from tf2onnx.utils import make_sure, is_tf_const_op, port_name, map_onnx_to_numpy_type |
24 | 24 | from . import logging
|
25 | 25 |
|
26 | 26 | logger = logging.getLogger(__name__)
|
@@ -166,10 +166,11 @@ def get_index_from_strided_slice_of_shape(node, outputs_to_values):
|
166 | 166 | return None
|
167 | 167 | return i1
|
168 | 168 |
|
169 |
| -def compute_const_folding_using_tf(g, const_node_values): |
| 169 | +def compute_const_folding_using_tf(g, const_node_values, graph_outputs): |
170 | 170 | """Find nodes with constant inputs and compute their values using TF"""
|
171 | 171 | if const_node_values is None:
|
172 | 172 | const_node_values = {}
|
| 173 | + graph_outputs = set(graph_outputs) |
173 | 174 | from tf2onnx.tf_loader import tf_session, tf_placeholder # pylint: disable=import-outside-toplevel
|
174 | 175 |
|
175 | 176 | ops = g.get_operations()
|
@@ -208,15 +209,16 @@ def compute_const_folding_using_tf(g, const_node_values):
|
208 | 209 | shape = shape_node_outputs[input_names[0]]
|
209 | 210 | i = get_index_from_strided_slice_of_shape(node, outputs_to_values)
|
210 | 211 | if i is not None and 0 <= i < len(shape) and shape[i] is not None:
|
211 |
| - outputs_to_values[output_names[0]] = np.array(shape[i]) |
| 212 | + np_dtype = map_onnx_to_numpy_type(map_tf_dtype(node.outputs[0].dtype)) |
| 213 | + outputs_to_values[output_names[0]] = np.array(shape[i], dtype=np_dtype) |
212 | 214 | outputs_to_dtypes[node.outputs[0].name] = node.outputs[0].dtype
|
213 | 215 | progress = True
|
214 | 216 | can_fold = node.type not in ['Enter']
|
215 | 217 | can_fold = can_fold and len(input_names) > 0 and all(inp in outputs_to_values for inp in input_names)
|
216 | 218 | # We can only fold nodes with a single output
|
217 | 219 | can_fold = can_fold and len(output_names) == 1 and output_names[0] not in outputs_to_values
|
218 | 220 | # Skip if value already computed, used, and discarded
|
219 |
| - can_fold = can_fold and output_names[0] not in unneeded_outputs |
| 221 | + can_fold = can_fold and output_names[0] not in unneeded_outputs and output_names[0] not in graph_outputs |
220 | 222 | if can_fold:
|
221 | 223 | # Make a mini graph containing just the node to fold
|
222 | 224 | g2 = tf.Graph()
|
|
0 commit comments