Skip to content

Commit 1ce0f99

Browse files
committed
filter input to exclude resources, add sigmoid to transpose optimizer
1 parent 26d3507 commit 1ce0f99

File tree

2 files changed

+17
-21
lines changed

2 files changed

+17
-21
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def _initialize_handlers(self):
183183
"Pad": self._pad_handler,
184184
"ReduceMean": self._reducemean_handler,
185185
"Relu": self._simple_through_handler,
186+
"Sigmoid": self._simple_through_handler,
186187
"Shape": self._shape_handler,
187188
"Slice": self._slice_handler,
188189
"Split": self._split_handler,

tf2onnx/tf_loader.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,6 @@ def from_checkpoint(model_path, input_names, output_names):
180180

181181
def _from_saved_model_v1(sess, model_path, input_names, output_names, signatures):
182182
"""Load tensorflow graph from saved_model."""
183-
# make sure we start with clean default graph
184-
inputs = {}
185-
outputs = {}
186183

187184
imported = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_path)
188185
for k in imported.signature_def.keys():
@@ -199,17 +196,17 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, signatures
199196
# TF1.12 changed the api
200197
get_signature_def = lambda meta_graph_def, k: meta_graph_def.signature_def[k]
201198

199+
input_names = []
200+
output_names = []
202201
for k in signatures:
203202
inputs_tensor_info = get_signature_def(imported, k).inputs
204203
for _, input_tensor in inputs_tensor_info.items():
205-
inputs[input_tensor.name] = sess.graph.get_tensor_by_name(input_tensor.name)
204+
if input_tensor.dtype != tf.dtypes.resource:
205+
input_names.append(input_tensor.name)
206206
outputs_tensor_info = get_signature_def(imported, k).outputs
207207
for _, output_tensor in outputs_tensor_info.items():
208-
outputs[output_tensor.name] = sess.graph.get_tensor_by_name(output_tensor.name)
209-
210-
input_names = list(inputs.keys())
211-
input_names = inputs_without_resource(sess, input_names)
212-
output_names = list(outputs.keys())
208+
if output_tensor.dtype != tf.dtypes.resource:
209+
output_names.append(output_tensor.name)
213210
frozen_graph = freeze_session(sess, input_names=input_names, output_names=output_names)
214211
return frozen_graph, input_names, output_names
215212

@@ -226,8 +223,10 @@ def _from_saved_model_v2(model_path, input_names, output_names, signatures):
226223
signatures.append(k)
227224
for k in signatures:
228225
concrete_func = imported.signatures[k]
229-
input_names = [input_tensor.name for input_tensor in concrete_func.inputs]
230-
output_names = [output_tensor.name for output_tensor in concrete_func.outputs]
226+
input_names = [input_tensor.name for input_tensor in concrete_func.inputs
227+
if input_tensor.dtype != tf.dtypes.resource]
228+
output_names = [output_tensor.name for output_tensor in concrete_func.outputs
229+
if output_tensor.dtype != tf.dtypes.resource]
231230

232231
frozen_graph = from_function(concrete_func, input_names, output_names)
233232
return frozen_graph, input_names, output_names
@@ -270,25 +269,21 @@ def from_keras(model_path, input_names, output_names):
270269
concrete_func = function.get_concrete_function()
271270
# allow to pass inputs and outputs from caller if we don't want all of them
272271
input_names = [input_tensor.name for input_tensor in concrete_func.inputs
273-
if input_names is None or input_tensor.name in input_names]
272+
if input_tensor.dtype != tf.dtypes.resource]
274273
output_names = [output_tensor.name for output_tensor in concrete_func.outputs
275-
if output_names is None or output_tensor.name in output_names]
274+
if output_tensor.dtype != tf.dtypes.resource]
275+
276276
frozen_graph = from_function(concrete_func, input_names, output_names)
277277
else:
278278
# Handles Keras when Eager mode is disabled.
279279
_keras.backend.clear_session()
280280
_keras.backend.set_learning_phase(False)
281281
keras_model = _keras.models.load_model(model_path, custom_objects)
282282
# allow to pass inputs and outputs from caller if we don't want all of them
283-
if input_names:
284-
input_names = [i for i in keras_model.inputs if i in input_names]
285-
else:
286-
input_names = keras_model.inputs
287-
if output_names:
288-
output_names = [i for i in keras_model.outputs if i in output_names]
289-
else:
290-
output_names = keras_model.outputs
283+
input_names = keras_model.inputs
284+
output_names = keras_model.outputs
291285
sess = _keras.backend.get_session()
286+
input_names = inputs_without_resource(sess, input_names)
292287
frozen_graph = freeze_session(sess, input_names=input_names, output_names=output_names)
293288
tf_reset_default_graph()
294289
with tf_session() as sess:

0 commit comments

Comments
 (0)