Skip to content

Commit 92326e8

Browse files
authored
Merge pull request #53 from onnx/gs/onnx-1.2
add support for placeholderwithdefault, allow testing of checkpoints
2 parents bc0e33c + 6eb8c7f commit 92326e8

File tree

4 files changed

+50
-2
lines changed

4 files changed

+50
-2
lines changed

tests/run_pretrained_models.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import yaml
1818
from tensorflow.core.framework import graph_pb2
1919
from tf2onnx.tfonnx import process_tf_graph
20+
from tensorflow.python.framework.graph_util import convert_variables_to_constants
2021

2122
TMPPATH = tempfile.mkdtemp()
2223
PERFITER = 1000
@@ -77,11 +78,29 @@ def node_name(name):
7778
return name
7879

7980

81+
def freeze_session(sess, keep_var_names=None, output_names=None, clear_devices=True):
82+
"""Freezes the state of a session into a pruned computation graph."""
83+
output_names = [i.replace(":0", "") for i in output_names]
84+
graph = sess.graph
85+
with graph.as_default():
86+
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
87+
output_names = output_names or []
88+
output_names += [v.op.name for v in tf.global_variables()]
89+
input_graph_def = graph.as_graph_def()
90+
if clear_devices:
91+
for node in input_graph_def.node:
92+
node.device = ""
93+
frozen_graph = convert_variables_to_constants(sess, input_graph_def,
94+
output_names, freeze_var_names)
95+
return frozen_graph
96+
97+
8098
class Test(object):
8199
cache_dir = None
82100

83101
def __init__(self, url, local, make_input, input_names, output_names,
84-
disabled=False, more_inputs=None, rtol=0.01, atol=0., check_only_shape=False):
102+
disabled=False, more_inputs=None, rtol=0.01, atol=0.,
103+
check_only_shape=False, model_type="frozen"):
85104
self.url = url
86105
self.make_input = make_input
87106
self.local = local
@@ -95,6 +114,7 @@ def __init__(self, url, local, make_input, input_names, output_names,
95114
self.perf = None
96115
self.tf_runtime = 0
97116
self.onnx_runtime = 0
117+
self.model_type = model_type
98118

99119
def download_file(self):
100120
"""Download file from url."""
@@ -231,8 +251,18 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
231251
model_path = os.path.join(dir_name, self.local)
232252
else:
233253
model_path = self.local
254+
dir_name = os.path.dirname(self.local)
234255
print("\tdownloaded", model_path)
235256

257+
# if the input model is a checkpoint, convert it to a frozen model
258+
if self.model_type in ["checkpoint"]:
259+
saver = tf.train.import_meta_graph(model_path)
260+
with tf.Session() as sess:
261+
saver.restore(sess, model_path[:-5])
262+
frozen_graph = freeze_session(sess, output_names=self.output_names)
263+
tf.train.write_graph(frozen_graph, dir_name, "frozen.pb", as_text=False)
264+
model_path = os.path.join(dir_name, "frozen.pb")
265+
236266
inputs = self.make_input(self.input_names)
237267
if self.more_inputs:
238268
for k, v in self.more_inputs.items():
@@ -314,7 +344,7 @@ def tests_from_yaml(fname):
314344
input_func = v.get("input_get")
315345
input_func = _INPUT_FUNC_MAPPING[input_func]
316346
kwargs = {}
317-
for kw in ["rtol", "atol", "disabled", "more_inputs", "check_only_shape"]:
347+
for kw in ["rtol", "atol", "disabled", "more_inputs", "check_only_shape", "model_type"]:
318348
if v.get(kw) is not None:
319349
kwargs[kw] = v[kw]
320350

tests/test_backend.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,21 @@ def test_add(self):
362362
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
363363
self.assertAllClose(expected, actual)
364364

365+
def test_placeholder(self):
366+
x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
367+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
368+
output = tf.identity(x, name=_TFOUTPUT)
369+
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
370+
self.assertAllClose(expected, actual)
371+
372+
def test_placeholder_with_default(self):
373+
x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
374+
y = tf.constant(x_val, name="y")
375+
x = tf.placeholder_with_default(y, x_val.shape, name=_TFINPUT)
376+
output = tf.identity(x, name=_TFOUTPUT)
377+
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
378+
self.assertAllClose(expected, actual)
379+
365380
def test_add_bcast(self):
366381
x1_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
367382
x2_val = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], dtype=np.float32).reshape((2, 2, 2))

tf2onnx/graph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ def get_tensor_value(self):
151151

152152
def get_tensor(self):
153153
if not self.is_const():
154+
if self.type == "Identity":
155+
return self.inputs[0].get_tensor()
154156
raise ValueError("get tensor: {} must be Const".format(self.name))
155157
t = self.get_attr("value")
156158
if t:

tf2onnx/tfonnx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,7 @@ def minmax_op(ctx, node, name, args):
887887
"Pad": (pad_op, []),
888888
"Placeholder": (placeholder_op, []),
889889
"PlaceholderV2": (placeholder_op, []),
890+
"PlaceholderWithDefault": (placeholder_op, []),
890891
"Pow": (pow_op, []),
891892
"Prod": (reduce_op, ["ReduceProd"]),
892893
"RandomNormal": (direct_op, []),

0 commit comments

Comments
 (0)