Skip to content

Commit ea7404d

Browse files
committed
support for checkpoint and saved_model format
1 parent bd38f61 commit ea7404d

File tree

13 files changed

+168
-72
lines changed

13 files changed

+168
-72
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
model_checkpoint_path: "C:\\src\\tensorflow-onnx\\tests\\models\\regression\\checkpoint\\model"
2+
all_model_checkpoint_paths: "C:\\src\\tensorflow-onnx\\tests\\models\\regression\\checkpoint\\model"
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
�[>�V�?
134 Bytes
Binary file not shown.
19.1 KB
Binary file not shown.
350 Bytes
Binary file not shown.
Binary file not shown.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
�[>�V�?
Binary file not shown.

tests/run_pretrained_models.py

Lines changed: 10 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,19 @@
1616
import traceback
1717
import zipfile
1818

19+
import PIL.Image
1920
import numpy as np
2021
import requests
2122
import six
2223
import tensorflow as tf
23-
from tensorflow.core.framework import graph_pb2
24-
from tensorflow.python.framework.graph_util import convert_variables_to_constants
2524
# contrib ops are registered only when the module is imported, the following import statement is needed,
2625
# otherwise tf runtime error will show up when the tf model is restored from pb file because of un-registered ops.
2726
import tensorflow.contrib.rnn # pylint: disable=unused-import
2827
import yaml
29-
import PIL.Image
28+
from tensorflow.core.framework import graph_pb2
3029

3130
import tf2onnx
31+
from tf2onnx import loader
3232
from tf2onnx import utils
3333
from tf2onnx.graph import GraphUtil
3434
from tf2onnx.tfonnx import process_tf_graph
@@ -74,23 +74,6 @@ def get_ramp(shape):
7474
}
7575

7676

77-
def freeze_session(sess, keep_var_names=None, output_names=None, clear_devices=True):
78-
"""Freezes the state of a session into a pruned computation graph."""
79-
output_names = [i.replace(":0", "") for i in output_names]
80-
graph = sess.graph
81-
with graph.as_default():
82-
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
83-
output_names = output_names or []
84-
output_names += [v.op.name for v in tf.global_variables()]
85-
input_graph_def = graph.as_graph_def()
86-
if clear_devices:
87-
for node in input_graph_def.node:
88-
node.device = ""
89-
frozen_graph = convert_variables_to_constants(sess, input_graph_def,
90-
output_names, freeze_var_names)
91-
return frozen_graph
92-
93-
9477
class Test(object):
9578
"""Main Test class."""
9679

@@ -236,45 +219,15 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
236219
dir_name = os.path.dirname(self.local)
237220
print("\tdownloaded", model_path)
238221

222+
inputs = list(self.input_names.keys())
223+
outputs = self.output_names
239224
if self.model_type in ["checkpoint"]:
240-
#
241-
# if the input model is a checkpoint, convert it to a frozen model
242-
saver = tf.train.import_meta_graph(model_path)
243-
with tf.Session() as sess:
244-
saver.restore(sess, model_path[:-5])
245-
frozen_graph = freeze_session(sess, output_names=self.output_names)
246-
tf.train.write_graph(frozen_graph, dir_name, "frozen.pb", as_text=False)
247-
model_path = os.path.join(dir_name, "frozen.pb")
225+
graph_def, inputs, outputs = loader.from_checkpoint(model_path, inputs, outputs)
248226
elif self.model_type in ["saved_model"]:
249-
try:
250-
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
251-
get_signature_def = lambda meta_graph_def, k: \
252-
signature_def_utils.get_signature_def_by_key(meta_graph_def, k)
253-
except ImportError:
254-
# TF1.12 changed the api
255-
get_signature_def = lambda meta_graph_def, k: meta_graph_def.signature_def[k]
256-
257-
# saved_model format - convert to checkpoint
258-
with tf.Session() as sess:
259-
meta_graph_def = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_path)
260-
inputs = {}
261-
outputs = {}
262-
for k in meta_graph_def.signature_def.keys():
263-
inputs_tensor_info = get_signature_def(meta_graph_def, k).inputs
264-
for _, input_tensor in sorted(inputs_tensor_info.items()):
265-
inputs[input_tensor.name] = sess.graph.get_tensor_by_name(input_tensor.name)
266-
outputs_tensor_info = get_signature_def(meta_graph_def, k).outputs
267-
for _, output_tensor in sorted(outputs_tensor_info.items()):
268-
outputs[output_tensor.name] = sess.graph.get_tensor_by_name(output_tensor.name)
269-
# freeze uses the node name derived from output:0 so only pass in output:0;
270-
# it will provide all outputs of that node.
271-
for o in list(outputs.keys()):
272-
if not o.endswith(":0"):
273-
del outputs[o]
274-
frozen_graph = freeze_session(sess, output_names=list(outputs.keys()))
275-
tf.train.write_graph(frozen_graph, dir_name, "frozen.pb", as_text=False)
276-
model_path = os.path.join(dir_name, "frozen.pb")
277-
227+
graph_def, inputs, outputs = loader.from_saved_model(model_path, inputs, outputs)
228+
else:
229+
graph_def, inputs, outputs = loader.from_graphdef(model_path, inputs, outputs)
230+
278231
# create the input data
279232
inputs = {}
280233
for k, v in self.input_names.items():
@@ -285,10 +238,6 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
285238
if self.more_inputs:
286239
for k, v in self.more_inputs.items():
287240
inputs[k] = v
288-
tf.reset_default_graph()
289-
graph_def = graph_pb2.GraphDef()
290-
with open(model_path, "rb") as f:
291-
graph_def.ParseFromString(f.read())
292241

293242
graph_def = tf2onnx.tfonnx.tf_optimize(inputs.keys(), self.output_names, graph_def, fold_const)
294243
shape_override = {}

tests/run_pretrained_models.yaml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,32 @@
11
#
22
# simple models for basic functional test
33
#
4+
regression-graphdef:
5+
model: tests/models/regression/graphdef/frozen.pb
6+
input_get: get_ramp
7+
inputs:
8+
"X:0": [1]
9+
outputs:
10+
- pred:0
11+
12+
regression-checkpoint:
13+
model: tests/models/regression/checkpoint/model.meta
14+
model_type: checkpoint
15+
input_get: get_ramp
16+
inputs:
17+
"X:0": [1]
18+
outputs:
19+
- pred:0
20+
21+
regression-saved-model:
22+
model: tests/models/regression/saved_model
23+
model_type: saved_model
24+
input_get: get_ramp
25+
inputs:
26+
"X:0": [1]
27+
outputs:
28+
- pred:0
29+
430
benchtf-fc:
531
model: tests/models/fc-layers/frozen.pb
632
input_get: get_ramp

0 commit comments

Comments
 (0)