
Commit e28da28

Merge pull request #1107 from onnx/tom/ExtendRunPretrainedModels
Update run_pretrained_models.py to support large models
2 parents: b1eccce + efc4c2b

File tree: 4 files changed, +136 −41 lines changed

tests/run_pretrained_models.py

Lines changed: 112 additions & 36 deletions
@@ -41,6 +41,7 @@
 from tf2onnx import tf_loader, logging, optimizer, utils, tf_utils
 from tf2onnx.tfonnx import process_tf_graph
 from tf2onnx.tf_loader import tf_session, tf_reset_default_graph
+from tf2onnx.graph import ExternalTensorStorage

 logger = logging.getLogger("run_pretrained")

@@ -102,16 +103,20 @@ class Test(object):
     cache_dir = None
     target = []

-    def __init__(self, url, local, make_input, input_names, output_names,
+    def __init__(self, url, local, input_func, input_names, output_names,
                  disabled=False, rtol=0.01, atol=1e-6,
                  check_only_shape=False, model_type="frozen", force_input_shape=False,
-                 skip_tensorflow=False, opset_constraints=None, tf_min_version=None, tag=None):
+                 skip_tensorflow=False, opset_constraints=None, tf_min_version=None, tag=None,
+                 skip_conversion=False, converted_model=None, signature_def=None, concrete_function=None,
+                 large_model=False, structured_outputs=None):
         self.url = url
-        self.make_input = make_input
+        self.input_func = input_func
         self.local = local
         self.input_names = input_names
         self.output_names = output_names
         self.disabled = disabled
+        self.large_model = large_model
+        self.structured_outputs = structured_outputs  # Needed to determine output order for tf_function
         self.rtol = rtol
         self.atol = atol
         self.check_only_shape = check_only_shape
@@ -122,8 +127,18 @@ def __init__(self, url, local, make_input, input_names, output_names,
         self.tag = tag
         self.force_input_shape = force_input_shape
         self.skip_tensorflow = skip_tensorflow
+        self.skip_conversion = skip_conversion
+        self.converted_model = converted_model
         self.opset_constraints = opset_constraints
         self.tf_min_version = tf_min_version
+        self.signatures = [signature_def] if signature_def else None
+        self.concrete_function = concrete_function
+
+    def make_input(self, v):
+        """Allows each input to specify its own function while defaulting to the input_get function"""
+        if isinstance(v, dict):
+            return _INPUT_FUNC_MAPPING[v["input_get"]](v["shape"])
+        return self.input_func(v)

     def download_model(self):
         """Download model from url."""
@@ -149,7 +164,7 @@ def download_model(self):
         if not os.path.exists(fpath):
             utils.get_url(url, fpath)
         model_path = os.path.join(dir_name, self.local)
-        if not os.path.exists(model_path):
+        if not os.path.exists(model_path) or self.local == ".":
             if ftype == 'tgz':
                 tar = tarfile.open(fpath)
                 tar.extractall(dir_name)
@@ -179,19 +194,23 @@ def run_tensorflow(self, sess, inputs):
         for k, v in inputs.items():
             k = sess.graph.get_tensor_by_name(k)
             feed_dict[k] = v
+        logger.info("Running TF")
         result = sess.run(self.output_names, feed_dict=feed_dict)
         if self.perf:
+            logger.info("Running TF perf")
             start = time.time()
             for _ in range(PERFITER):
                 _ = sess.run(self.output_names, feed_dict=feed_dict)
             self.tf_runtime = time.time() - start
         return result

-    def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, input_names=None):
+    def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, input_names=None,
+                const_node_values=None):
         """Convert graph to tensorflow."""
         return process_tf_graph(tf_graph, continue_on_error=False, opset=opset,
                                 extra_opset=extra_opset, target=Test.target, shape_override=shape_override,
-                                input_names=input_names, output_names=self.output_names)
+                                input_names=input_names, output_names=self.output_names,
+                                const_node_values=const_node_values)

     def run_caffe2(self, name, model_proto, inputs):
         """Run test again caffe2 backend."""
@@ -205,11 +224,12 @@ def run_caffe2(self, name, model_proto, inputs):
         self.onnx_runtime = time.time() - start
         return results

-    def run_onnxruntime(self, name, model_proto, inputs):
+    def run_onnxruntime(self, name, model_proto, inputs, external_tensor_storage=None):
         """Run test against onnxruntime backend."""
         import onnxruntime as rt
         model_path = utils.save_onnx_model(TEMP_DIR, name, inputs, model_proto, include_test_data=True,
-                                           as_text=utils.is_debug_mode())
+                                           as_text=utils.is_debug_mode(),
+                                           external_tensor_storage=external_tensor_storage)
         logger.info("Model saved to %s", model_path)
         m = rt.InferenceSession(model_path)
         results = m.run(self.output_names, inputs)
@@ -221,10 +241,14 @@ def run_onnxruntime(self, name, model_proto, inputs):
         return results

     @staticmethod
-    def create_onnx_file(name, model_proto, inputs, outdir):
+    def create_onnx_file(name, model_proto, inputs, outdir, external_tensor_storage=None):
         os.makedirs(outdir, exist_ok=True)
-        model_path = os.path.join(outdir, name + ".onnx")
-        utils.save_protobuf(model_path, model_proto)
+        if external_tensor_storage is None:
+            model_path = os.path.join(outdir, name + ".onnx")
+            utils.save_protobuf(model_path, model_proto)
+        else:
+            model_path = os.path.join(outdir, name + ".zip")
+            utils.save_onnx_zip(model_path, model_proto, external_tensor_storage)
         logger.info("Created %s", model_path)

     def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_opset=None,
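utils.save_onnx_zip itself is not shown in this diff. Judging from model_proto_from_zip in utils.py further below, the archive it writes is expected to hold one ".onnx" entry for the ModelProto plus one entry per externally stored tensor. A hedged sketch of that layout, with the entry name as an assumption:

import zipfile

def save_onnx_zip_sketch(path, model_proto, external_tensor_storage):
    # Sketch only -- the ".onnx" entry name is inferred from model_proto_from_zip.
    with zipfile.ZipFile(path, "w") as z:
        z.writestr("__MODEL_PROTO.onnx", model_proto.SerializeToString())
        for name, data in external_tensor_storage.name_to_tensor_data.items():
            z.writestr(name, data)  # raw bytes of each external tensor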
@@ -236,7 +260,7 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
         if self.url:
             _, dir_name = self.download_model()
             logger.info("Downloaded to %s", dir_name)
-            model_path = os.path.join(dir_name, self.local)
+            model_path = os.path.join(dir_name, self.local) if self.local != "." else dir_name
         else:
             model_path = self.local

@@ -246,13 +270,15 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
         if self.model_type in ["checkpoint"]:
             graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
         elif self.model_type in ["saved_model"]:
-            try:
-                res = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag)
-            except OSError:
-                model_path = dir_name
-                logger.info("Load model(2) from %r", model_path)
-                res = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag)
-            graph_def, input_names, outputs = res[:3]
+            loaded = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag, self.signatures,
+                                                self.concrete_function, self.large_model,
+                                                return_concrete_func=self.large_model)
+            if self.large_model:
+                # Must maintain ref to imported since concrete_func uses weak refs
+                # pylint: disable=unused-variable
+                graph_def, input_names, outputs, concrete_func, imported = loaded
+            else:
+                graph_def, input_names, outputs = loaded
         elif self.model_type in ["keras"]:
             graph_def, input_names, outputs = tf_loader.from_keras(model_path, input_names, outputs)
         else:
@@ -261,9 +287,34 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
         if utils.is_debug_mode():
             utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def)

+        if self.large_model:
+            inputs = {}
+            for k in input_names:
+                v = self.input_names[k]
+                inputs[k.split(":")[0]] = tf.constant(self.make_input(v))
+            tf_func = tf.function(concrete_func)
+            logger.info("Running TF")
+            tf_results_d = tf_func(**inputs)
+            if self.structured_outputs is None:
+                tf_results = list(tf_results_d.values())
+            else:
+                tf_results = [tf_results_d[output] for output in self.structured_outputs]
+            if self.perf:
+                logger.info("Running TF perf")
+                start = time.time()
+                for _ in range(PERFITER):
+                    _ = concrete_func(**inputs)
+                self.tf_runtime = time.time() - start
+            logger.info("TensorFlow OK")
+
         inputs = {}
         shape_override = {}
         tf_reset_default_graph()
+
+        from tf2onnx.tf_utils import compress_graph_def
+        const_node_values = None
+        if self.large_model:
+            const_node_values = compress_graph_def(graph_def)
         g = tf.import_graph_def(graph_def, name='')
         # with tf_session(config=tf.ConfigProto(allow_soft_placement=True), graph=g) as sess:
         with tf_session(graph=g) as sess:
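compress_graph_def (imported from tf2onnx.tf_utils) is what keeps the frozen graph under protobuf's 2 GB serialization limit: the payload of each large constant is pulled out of the GraphDef and returned in a map keyed by node name, which later flows into process_tf_graph as const_node_values. A rough sketch of the idea only, not the real implementation (the size threshold and placeholder handling are assumptions):

def compress_graph_def_sketch(graph_def):
    """Strip big Const payloads out of graph_def; return {node_name: bytes}."""
    const_node_values = {}
    for node in graph_def.node:
        if node.op == "Const":
            tensor = node.attr["value"].tensor
            if len(tensor.tensor_content) > 1000:  # threshold is an assumption
                const_node_values[node.name] = tensor.tensor_content
                tensor.tensor_content = b""  # the real code leaves a placeholder marker
    return const_node_values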
@@ -288,30 +339,50 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
             # run the model with tensorflow
             if self.skip_tensorflow:
                 logger.info("TensorFlow SKIPPED")
-            else:
+            elif not self.large_model:
                 tf_results = self.run_tensorflow(sess, inputs)
                 logger.info("TensorFlow OK")

             model_proto = None
-            try:
-                # convert model to onnx
-                onnx_graph = self.to_onnx(sess.graph, opset=opset, extra_opset=extra_opset,
-                                          shape_override=shape_override, input_names=inputs.keys())
-                onnx_graph = optimizer.optimize_graph(onnx_graph)
-                model_proto = onnx_graph.make_model("converted from tf2onnx")
-                logger.info("To_ONNX, OK")
-                if onnx_file:
-                    self.create_onnx_file(name, model_proto, inputs, onnx_file)
-            except Exception:
-                logger.error("To_ONNX FAIL", exc_info=1)
-                return False
+            if self.skip_conversion:
+                if self.large_model:
+                    external_tensor_storage = ExternalTensorStorage()
+                    model_proto = utils.model_proto_from_zip(self.converted_model, external_tensor_storage)
+                else:
+                    external_tensor_storage = None
+                    model_proto = utils.model_proto_from_file(self.converted_model)
+                logger.info("ONNX loaded from file")
+            else:
+                try:
+                    # convert model to onnx
+                    onnx_graph = self.to_onnx(sess.graph, opset=opset, extra_opset=extra_opset,
+                                              shape_override=shape_override, input_names=inputs.keys(),
+                                              const_node_values=const_node_values)
+                    onnx_graph = optimizer.optimize_graph(onnx_graph)
+                    print("ONNX", onnx_graph.dump_node_statistics())
+                    external_tensor_storage = ExternalTensorStorage() if self.large_model else None
+                    model_proto = onnx_graph.make_model("converted from tf2onnx",
+                                                        external_tensor_storage=external_tensor_storage)
+                    logger.info("To_ONNX, OK")
+                    if onnx_file:
+                        self.create_onnx_file(name, model_proto, inputs, onnx_file, external_tensor_storage)
+                    if self.converted_model:
+                        if self.large_model:
+                            utils.save_onnx_zip(self.converted_model, model_proto, external_tensor_storage)
+                        else:
+                            utils.save_protobuf(self.converted_model, model_proto)
+                        logger.info("Created %s", self.converted_model)
+
+                except Exception:
+                    logger.error("To_ONNX FAIL", exc_info=1)
+                    return False

             try:
                 onnx_results = None
                 if backend == "caffe2":
                     onnx_results = self.run_caffe2(name, model_proto, inputs)
                 elif backend == "onnxruntime":
-                    onnx_results = self.run_onnxruntime(name, model_proto, inputs)
+                    onnx_results = self.run_onnxruntime(name, model_proto, inputs, external_tensor_storage)
                 else:
                     raise ValueError("unknown backend")
                 logger.info("Run_ONNX OK")
@@ -390,6 +461,7 @@ def get_args():
     parser.add_argument("--list", help="list tests", action="store_true")
     parser.add_argument("--onnx-file", help="create onnx file in directory")
     parser.add_argument("--perf", help="capture performance numbers")
+    parser.add_argument("--perfiter", type=int, default=PERFITER, help="number of inferences for perf testing")
     parser.add_argument("--fold_const", help="enable tf constant_folding transformation before conversion",
                         action="store_true")
     parser.add_argument("--include-disabled", help="include disabled tests", action="store_true")
@@ -447,8 +519,9 @@ def load_tests_from_yaml(path):
             opset_constraints.append(c)

         kwargs = {}
-        for kw in ["rtol", "atol", "disabled", "check_only_shape", "model_type",
-                   "skip_tensorflow", "force_input_shape", "tf_min_version", "tag"]:
+        for kw in ["rtol", "atol", "disabled", "check_only_shape", "model_type", "concrete_function",
+                   "skip_tensorflow", "force_input_shape", "tf_min_version", "tag", "skip_conversion",
+                   "converted_model", "signature_def", "large_model", "structured_outputs"]:
             if settings.get(kw) is not None:
                 kwargs[kw] = settings[kw]

@@ -459,6 +532,7 @@ def load_tests_from_yaml(path):


 def main():
+    global PERFITER
     args = get_args()
     logging.basicConfig(level=logging.get_verbosity_level(args.verbose))
     if args.debug:
@@ -477,6 +551,7 @@ def main():

     failed = 0
     count = 0
+    PERFITER = args.perfiter
     for test in test_keys:
         logger.info("===================================")

@@ -520,7 +595,8 @@ def main():
         for test in test_keys:
            t = tests[test]
            if t.perf:
-                f.write("{},{},{}\n".format(test, t.tf_runtime, t.onnx_runtime))
+                # Report perf in ms per inference
+                f.write("{},{},{}\n".format(test, t.tf_runtime * 1000 / PERFITER, t.onnx_runtime * 1000 / PERFITER))
     return failed

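Note on the perf change: the CSV now reports milliseconds per inference rather than total loop time. For example, if the timing loop ran PERFITER = 1000 inferences in a total t.tf_runtime of 2.5 s, the reported value is 2.5 * 1000 / 1000 = 2.5 ms per inference.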
tests/run_pretrained_models.yaml

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ benchtf-gru:
 esrgan-tf2:
   # url: https://tfhub.dev/captain-pool/esrgan-tf2/1/esrgan-tf2_1.tar.gz
   url: https://github.com/captain-pool/GSOC/releases/download/1.0.0/esrgan.tar.gz
-  model: ersgan
+  model: "."
   model_type: saved_model
   input_get: get_beach
   opset_constraints:
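For reference, a hypothetical entry showing how the new keys parsed by load_tests_from_yaml could combine for a large model. This is an illustration, not part of this commit; the url, paths, and tensor names are made up:

example-large-model:
  url: https://example.com/example-model.tar.gz
  model: "."
  model_type: saved_model
  large_model: true
  converted_model: /tmp/example-model.zip  # written as a zip because large_model is set
  input_get: get_beach
  inputs:
    "serving_default_input:0": [1, 128, 128, 3]
  outputs:
    - Identity:0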

tf2onnx/tf_loader.py

Lines changed: 6 additions & 3 deletions
@@ -316,18 +316,21 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
             raise ValueError(err_large_model)
         raise e

-    return frozen_graph, inputs, outputs
+    return frozen_graph, inputs, outputs, concrete_func, imported


 def from_saved_model(model_path, input_names, output_names, tag=None,
-                     signatures=None, concrete_function=None, large_model=False):
+                     signatures=None, concrete_function=None, large_model=False, return_concrete_func=False):
     """Load tensorflow graph from saved_model."""
     if signatures is None:
         signatures = []
     tf_reset_default_graph()
     if is_tf2():
-        frozen_graph, input_names, output_names = \
+        frozen_graph, input_names, output_names, concrete_func, imported = \
             _from_saved_model_v2(model_path, input_names, output_names, tag, signatures, concrete_function, large_model)
+        if return_concrete_func:
+            tf_reset_default_graph()
+            return frozen_graph, input_names, output_names, concrete_func, imported
     else:
         with tf_session() as sess:
             frozen_graph, input_names, output_names = \
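A brief usage sketch of the new flag, with a placeholder model path and tag: when return_concrete_func=True, the TF2 path returns five values, and the caller must keep the returned imported object alive because the concrete function only holds weak references to the loaded objects.

from tf2onnx import tf_loader

# Placeholder path and tag; the five-tuple unpacking mirrors the code above.
graph_def, inputs, outputs, concrete_func, imported = tf_loader.from_saved_model(
    "/path/to/saved_model", None, None, tag="serve",
    large_model=True, return_concrete_func=True)
# Keep `imported` referenced for as long as concrete_func is used.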

tf2onnx/utils.py

Lines changed: 17 additions & 1 deletion
@@ -20,7 +20,7 @@
 from urllib3.util.retry import Retry
 import numpy as np
 from google.protobuf import text_format
-from onnx import helper, onnx_pb, defs, numpy_helper, __version__
+from onnx import helper, onnx_pb, defs, numpy_helper, ModelProto, __version__

 from . import constants

@@ -269,6 +269,22 @@ def save_protobuf(path, message, as_text=False):
         with open(path, "wb") as f:
             f.write(message.SerializeToString())

+def model_proto_from_file(model_path):
+    model_proto = ModelProto()
+    with open(model_path, "rb") as f:
+        model_proto.ParseFromString(f.read())
+    return model_proto
+
+def model_proto_from_zip(zip_path, external_tensor_storage):
+    model_proto = ModelProto()
+    with zipfile.ZipFile(zip_path, 'r') as z:
+        for n in z.namelist():
+            f = z.open(n)
+            if n.endswith(".onnx"):
+                model_proto.ParseFromString(f.read())
+            else:
+                external_tensor_storage.name_to_tensor_data[n] = f.read()
+    return model_proto

 def is_list_or_tuple(obj):
     return isinstance(obj, (list, tuple))
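A short loading sketch for the zip helper (the file name is a placeholder). Tensors that were stored externally land in the storage object's name_to_tensor_data dict rather than in the proto itself:

from tf2onnx import utils
from tf2onnx.graph import ExternalTensorStorage

storage = ExternalTensorStorage()
model_proto = utils.model_proto_from_zip("converted_model.zip", storage)  # placeholder path
print("graph nodes:", len(model_proto.graph.node))
print("external tensors:", len(storage.name_to_tensor_data))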
