
Commit 101fbc4

Add pretrained model tests for tfjs (#1622)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent a8d5a3c commit 101fbc4

File tree: 7 files changed, +125 -22 lines changed

  ci_build/azure_pipelines/pretrained_model_test.yml
  ci_build/azure_pipelines/templates/pretrained_model_test.yml
  tests/run_pretrained_models.py
  tests/run_pretrained_models.yaml
  tests/run_tfjs.js
  tf2onnx/tfjs_utils.py
  tf2onnx/tfonnx.py

ci_build/azure_pipelines/pretrained_model_test.yml

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ jobs:
     python_versions: ['3.7']
     tf_versions: ['2.4.1']
     skip_tflite_tests: 'False'
+    skip_tfjs_tests: 'False'
     skip_tf_tests: 'True'
     job:
       steps:

ci_build/azure_pipelines/templates/pretrained_model_test.yml

Lines changed: 1 addition & 1 deletion
@@ -6,6 +6,6 @@ steps:
     status=0
     # TODO: fix unity model path
     # python tests/run_pretrained_models.py --backend $CI_ONNX_BACKEND --opset $CI_ONNX_OPSET --config tests/unity.yaml || status=$?
-    python tests/run_pretrained_models.py --backend $CI_ONNX_BACKEND --opset $CI_ONNX_OPSET --skip_tf_tests $CI_SKIP_TF_TESTS --skip_tflite_tests $CI_SKIP_TFLITE_TESTS --config tests/run_pretrained_models.yaml || status=$?
+    python tests/run_pretrained_models.py --backend $CI_ONNX_BACKEND --opset $CI_ONNX_OPSET --skip_tf_tests $CI_SKIP_TF_TESTS --skip_tflite_tests $CI_SKIP_TFLITE_TESTS --skip_tfjs_tests $CI_SKIP_TFJS_TESTS --config tests/run_pretrained_models.yaml || status=$?
     exit $status
  displayName: 'Test Pre-trained Model'

tests/run_pretrained_models.py

Lines changed: 34 additions & 10 deletions
@@ -45,6 +45,7 @@
 from tf2onnx.tfonnx import process_tf_graph
 from tf2onnx.tf_loader import tf_session, tf_reset_default_graph
 from tf2onnx.graph import ExternalTensorStorage
+from tfjs_runner import run_tfjs

 logger = logging.getLogger("run_pretrained")

@@ -251,6 +252,10 @@ def download_model(self):
         elif self.model_type == 'tflite':
             fname = self.local
             dir_name = fname.replace(".tflite", "") + "_dir"
+        elif self.model_type == 'tfjs':
+            ftype = 'tgz'
+            fname = 'model.tar.gz'
+            dir_name = "_".join(url.split("/")[5:-3]) + "_dir"
         dir_name = os.path.join(cache_dir, dir_name)
         os.makedirs(dir_name, exist_ok=True)
         fpath = os.path.join(dir_name, fname)
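Editor's note: for illustration, the new dir_name expression derives the cache directory from the TF Hub URL by keeping the path segments between the model family and the trailing version/format parts. A standalone check (not part of the diff) of what it produces for the two URL shapes used in the yaml below:

# Standalone illustration of the dir_name expression added in download_model above.
url = "https://tfhub.dev/tensorflow/tfjs-model/posenet/mobilenet/float/100/1/default/1?tfjs-format=compressed"
print("_".join(url.split("/")[5:-3]) + "_dir")   # posenet_mobilenet_float_100_dir

url = "https://tfhub.dev/tensorflow/tfjs-model/handdetector/1/default/1?tfjs-format=compressed"
print("_".join(url.split("/")[5:-3]) + "_dir")   # handdetector_dir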
@@ -303,7 +308,8 @@ def run_tensorflow(self, sess, inputs):
         return result

     def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, input_names=None,
-                const_node_values=None, initialized_tables=None, tflite_path=None, tensors_to_rename=None):
+                const_node_values=None, initialized_tables=None, tflite_path=None, tensors_to_rename=None,
+                tfjs_path=None):
         """Convert graph to tensorflow."""
         if extra_opset is None:
             extra_opset = []
@@ -314,7 +320,7 @@ def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, i
                                 input_names=input_names, output_names=self.output_names,
                                 const_node_values=const_node_values, initialized_tables=initialized_tables,
                                 tflite_path=tflite_path, dequantize=self.dequantize,
-                                tensors_to_rename=tensors_to_rename)
+                                tensors_to_rename=tensors_to_rename, tfjs_path=tfjs_path)

     def run_onnxruntime(self, name, model_proto, inputs, outputs, external_tensor_storage=None):
         """Run test against onnxruntime backend."""
@@ -375,6 +381,7 @@ def run_test(self, name, backend="onnxruntime", onnx_file=None, opset=None, extr
         initialized_tables = {}
         outputs = self.output_names
         tflite_path = None
+        tfjs_path = None
         to_rename = {}
         if self.model_type in ["checkpoint"]:
             graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
@@ -394,6 +401,9 @@ def run_test(self, name, backend="onnxruntime", onnx_file=None, opset=None, extr
         elif self.model_type in ["tflite"]:
             tflite_path = model_path
             graph_def = None
+        elif self.model_type in ["tfjs"]:
+            tfjs_path = model_path
+            graph_def = None
         else:
             graph_def, input_names, outputs = tf_loader.from_graphdef(model_path, input_names, outputs)

@@ -434,6 +444,16 @@ def run_tflite():
                 logger.info("TFLite perf {:.2f}ms/inference, n={}".format(self.tf_runtime, n))
             logger.info("TFLite OK")

+        if tfjs_path is not None:
+            inputs = {}
+            for k in input_names:
+                v = self.input_names[k]
+                inputs[k] = self.make_input(v)
+            if not self.skip_tensorflow:
+                logger.info("Running TFJS")
+                tf_results = run_tfjs(tfjs_path, inputs, dir_name)
+                logger.info("TFJS OK")
+
         if not self.run_tf_frozen:
             inputs = {}
             for k in input_names:
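Editor's note: run_tfjs is imported from the tfjs_runner test helper, which is not part of this diff. As a rough sketch of the bridge it provides (the helper name, argument order, and file names below are assumptions for illustration only), the idea is to serialize the generated inputs, invoke the Node script tests/run_tfjs.js, and read back the outputs that the script writes to disk:

# Hypothetical sketch of a tfjs_runner-style helper; the real helper ships separately from this commit.
import json
import os
import subprocess

import numpy as np

def run_tfjs_sketch(model_path, inputs, working_dir):
    """Run a tfjs model through tests/run_tfjs.js and return its outputs (illustrative only)."""
    input_path = os.path.join(working_dir, "input.json")
    output_path = os.path.join(working_dir, "output.json")
    # numpy arrays are not JSON-serializable, so record values and shapes explicitly
    with open(input_path, "w") as f:
        json.dump({k: {"shape": list(v.shape), "values": v.flatten().tolist()}
                   for k, v in inputs.items()}, f)
    subprocess.run(["node", "tests/run_tfjs.js", model_path, input_path, output_path], check=True)
    with open(output_path) as f:
        outputs = json.load(f)
    return [np.array(o["values"]).reshape(o["shape"]) for o in outputs]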
@@ -465,7 +485,6 @@ def run_tflite():
                 logger.info("TF perf {:.2f}ms/inference, n={}".format(self.tf_runtime, n))
             logger.info("TensorFlow OK")

-        shape_override = {}
         const_node_values = None
         tf_graph = None

@@ -497,10 +516,6 @@ def run_tflite():
                 else:
                     inputs[k] = self.make_input(v).astype(expected_dtype)

-            if self.force_input_shape:
-                for k, v in inputs.items():
-                    shape_override[k] = list(v.shape)
-
             # run the model with tensorflow
             if self.skip_tensorflow:
                 logger.info("TensorFlow SKIPPED")
@@ -526,11 +541,15 @@ def run_tflite():
         else:
             try:
                 # convert model to onnx
+                if self.force_input_shape:
+                    shape_override = {k: list(v.shape) for k, v in inputs.items()}
+                else:
+                    shape_override = None
                 onnx_graph = self.to_onnx(tf_graph, opset=opset, extra_opset=extra_opset,
                                           shape_override=shape_override, input_names=inputs.keys(),
                                           const_node_values=const_node_values,
                                           initialized_tables=initialized_tables, tflite_path=tflite_path,
-                                          tensors_to_rename=to_rename)
+                                          tensors_to_rename=to_rename, tfjs_path=tfjs_path)
                 onnx_graph = optimizer.optimize_graph(onnx_graph)
                 print("ONNX", onnx_graph.dump_node_statistics())
                 external_tensor_storage = ExternalTensorStorage() if self.large_model else None
@@ -636,6 +655,7 @@ def get_args():
                         help="extra opset with format like domain:version, e.g. com.microsoft:1")
     parser.add_argument("--skip_tf_tests", help="skip non-tflite tests", default="False")
     parser.add_argument("--skip_tflite_tests", help="skip tflite tests", default="False")
+    parser.add_argument("--skip_tfjs_tests", help="skip tfjs tests", default="False")
     parser.add_argument("--verbose", "-v", help="verbose output, option is additive", action="count")
     parser.add_argument("--debug", help="debug mode", action="store_true")
     parser.add_argument("--list", help="list tests", action="store_true")
@@ -647,6 +667,7 @@ def get_args():
     args.target = args.target.split(",")
     args.skip_tf_tests = args.skip_tf_tests.upper() == "TRUE"
     args.skip_tflite_tests = args.skip_tflite_tests.upper() == "TRUE"
+    args.skip_tfjs_tests = args.skip_tfjs_tests.upper() == "TRUE"
     if args.extra_opset:
         tokens = args.extra_opset.split(':')
         if len(tokens) != 2:
@@ -739,11 +760,14 @@ def main():
             logger.info("Skip %s: disabled", test)
             continue

+        if args.skip_tfjs_tests and t.model_type == "tfjs":
+            logger.info("Skip %s: tfjs test", test)
+            continue
         if args.skip_tflite_tests and t.model_type == "tflite":
             logger.info("Skip %s: tflite test", test)
             continue
-        if args.skip_tf_tests and t.model_type != "tflite":
-            logger.info("Skip %s: not tflite test", test)
+        if args.skip_tf_tests and t.model_type not in ["tflite", "tfjs"]:
+            logger.info("Skip %s: tf test", test)
             continue

         condition, reason = t.check_opset_constraints(args.opset, args.extra_opset)

tests/run_pretrained_models.yaml

Lines changed: 50 additions & 0 deletions
@@ -586,3 +586,53 @@ melgan_tflite: # TFLite model with FlexOps and rank-3 transposes
     - Identity
   rtol: 0.02
   atol: 0.0005
+
+handdetector_tfjs:
+  tf_min_version: 2.1
+  disabled: false
+  url: https://tfhub.dev/tensorflow/tfjs-model/handdetector/1/default/1?tfjs-format=compressed
+  model: "model.json"
+  model_type: tfjs
+  input_get: get_beach
+  inputs:
+    "input:0": [1, 256, 256, 3]
+  outputs:
+    - Identity:0
+  atol: 0.0005
+
+posenet_mobilenet_float_100_tfjs:
+  tf_min_version: 2.1
+  disabled: false
+  url: https://tfhub.dev/tensorflow/tfjs-model/posenet/mobilenet/float/100/1/default/1?tfjs-format=compressed
+  model: "model-stride8.json"
+  model_type: tfjs
+  input_get: get_beach
+  force_input_shape: True  # ORT doesn't implement autopadding for convs with dilations
+  inputs:
+    "sub_2:0": [1, 256, 256, 3]
+  outputs:
+    - MobilenetV1/offset_2/BiasAdd:0
+    - MobilenetV1/heatmap_2/BiasAdd:0
+    - MobilenetV1/displacement_fwd_2/BiasAdd:0
+    - MobilenetV1/displacement_bwd_2/BiasAdd:0
+  rtol: 0.02
+  atol: 0.0005
+
+posenet_mobilenet_quantized_2_075_tfjs:
+  tf_min_version: 2.1
+  disabled: false
+  url: https://tfhub.dev/tensorflow/tfjs-model/posenet/mobilenet/quantized/2/075/1/default/1?tfjs-format=compressed
+  model: "model-stride16.json"
+  model_type: tfjs
+  input_get: get_beach
+  force_input_shape: True  # ORT doesn't implement autopadding for convs with dilations
+  inputs:
+    "sub_2:0": [1, 256, 256, 3]
+  outputs:
+    - MobilenetV1/offset_2/BiasAdd:0
+    - MobilenetV1/heatmap_2/BiasAdd:0
+    - MobilenetV1/displacement_fwd_2/BiasAdd:0
+    - MobilenetV1/displacement_bwd_2/BiasAdd:0
+  rtol: 0.1
+  ptol: 0.2
+  atol: 0.005
tests/run_tfjs.js

Lines changed: 17 additions & 2 deletions
@@ -8,6 +8,7 @@
 */

 const tf = require('@tensorflow/tfjs');
+const zlib = require("zlib");

 const fs = require('fs');
 const http = require('http');
@@ -48,16 +49,30 @@ if (process.argv[2] == '--test') {
 const modelDir = path.dirname(modelPath);
 const modelName = path.basename(modelPath);

+const fd = fs.openSync(modelPath, 'r');
+const buffer = Buffer.alloc(2);
+fs.readSync(fd, buffer, 0, 2);
+fs.closeSync(fd);
+// Check for gzip magic number
+const needsUnzip = buffer[0] == 31 && buffer[1] == 139
+
 // tf.loadGraphModel expects a url not a local file, so we serve it on localhost
 http.createServer(function (req, res) {
-    fs.readFile(modelDir + req.url, function (err, data) {
+    const callback = function (err, data) {
         if (err) {
             res.writeHead(404);
             res.end(JSON.stringify(err));
             return;
         }
         res.writeHead(200);
         res.end(data);
+    }
+    fs.readFile(modelDir + req.url, function (err, data) {
+        if (err || !needsUnzip) {
+            callback(err, data);
+        } else {
+            zlib.gunzip(data, callback);
+        }
     });
 }).listen(8080);

@@ -140,4 +155,4 @@ async function main() {
     fs.writeFileSync(outputPath, outputString, 'utf8');
 }

-main().then(() => exit(0)).catch((err) => { console.error(err); exit(1) })
+main().then(() => exit(0)).catch((err) => { console.error(err); exit(1) })
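Editor's note: the gzip magic-number check above (bytes 31 and 139, i.e. 0x1f 0x8b) is what lets the localhost model server transparently serve both plain and gzip-compressed TF Hub archive contents. The same sniffing logic in Python, as a standalone reference snippet (not part of the repo):

# Standalone analogue of the needsUnzip check in run_tfjs.js.
import gzip

def read_maybe_gzipped(path):
    """Return file contents, gunzipping if the first two bytes are the gzip magic number."""
    with open(path, "rb") as f:
        magic = f.read(2)
    if magic == b"\x1f\x8b":  # 31, 139
        with gzip.open(path, "rb") as f:
            return f.read()
    with open(path, "rb") as f:
        return f.read()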

tf2onnx/tfjs_utils.py

Lines changed: 20 additions & 8 deletions
@@ -206,7 +206,8 @@ def read_model_json(model_path):
     return model, zip_compressed


-def graphs_from_tfjs(model_path, input_names=None, output_names=None, ignore_default=None, use_default=None):
+def graphs_from_tfjs(model_path, input_names=None, output_names=None, shape_override=None,
+                     ignore_default=None, use_default=None):
     """Given the path to a model.json file, parses the model into onnx graphs and returns the main graph and a
     topologically sorted list of subgraphs."""
     model, zip_compressed = read_model_json(model_path)
@@ -236,11 +237,13 @@ def graphs_from_tfjs(model_path, input_names=None, output_names=None, ignore_def
     if output_names is None and 'signature' in model:
         output_names = list(model['signature']['outputs'].keys())

-    main_g = read_tfjs_graph(topology['node'], weights, None, input_names, output_names, ignore_default, use_default)
+    main_g = read_tfjs_graph(topology['node'], weights, None, input_names, output_names, shape_override,
+                             ignore_default, use_default)
     subgraphs = []
     funcs = sort_tfjs_functions(topology.get('library', {}).get('function', []))
     for func in funcs:
-        sub_g = read_tfjs_graph(func.get('nodeDef', []), weights, func, None, None, ignore_default, use_default)
+        sub_g = read_tfjs_graph(func.get('nodeDef', []), weights, func, None, None, shape_override,
+                                ignore_default, use_default)
         subgraphs.append(sub_g)

     return main_g, subgraphs
@@ -259,7 +262,7 @@ def read_tfjs_weight(weight, weights_data, offset):
     if 'quantization' in weight:
         q_info = weight['quantization']
         q_dtype = np.dtype(q_info['dtype'])
-        np_arr = np.frombuffer(weights_data, dtype=q_dtype, count=count, offset=i)
+        np_arr = np.frombuffer(weights_data, dtype=q_dtype, count=count, offset=offset)
         num_bytes = np_arr.nbytes
         np_arr = np_arr.astype(np_dtype) * q_info['scale'] + q_info['min']
     else:
@@ -303,18 +306,27 @@ def read_tfjs_function(func):
     return tf_dtypes, output_shapes, inputs, outputs, name


-def read_tfjs_graph(nodes, weights, func=None, graph_inputs=None, graph_outputs=None,
+def read_tfjs_graph(nodes, weights, func=None, graph_inputs=None, graph_outputs=None, shape_override=None,
                     ignore_default=None, use_default=None):
     """Creates an onnx graph from the provided tfjs nodes"""
+    if shape_override is None:
+        shape_override = {}
     onnx_nodes = []
     output_shapes = {}
     tf_dtypes = {}
     op_info = {}
     graph_name = 'tfjs_model'
     func_name = None

+    def update_shapes(new_shapes):
+        if isinstance(new_shapes, dict):
+            new_shapes = new_shapes.items()
+        for k, v in new_shapes:
+            output_shapes[k] = shape_override.get(k, v)
+
     if func is not None:
-        tf_dtypes, output_shapes, graph_inputs, graph_outputs, func_name = read_tfjs_function(func)
+        tf_dtypes, fn_input_shapes, graph_inputs, graph_outputs, func_name = read_tfjs_function(func)
+        update_shapes(fn_input_shapes)
         graph_name = func_name
         for inp in graph_inputs:
             onnx_nodes.append(helper.make_node("Placeholder", [], outputs=[inp], name=inp))
@@ -338,7 +350,7 @@ def read_tfjs_graph(nodes, weights, func=None, graph_inputs=None, graph_outputs=
             onnx_tensor = numpy_helper.from_array(np_arr.astype(np_dtype), out_name)
             onnx_node = helper.make_node("Const", [], outputs=[out_name], name=node_name, value=onnx_tensor)
             onnx_nodes.append(onnx_node)
-            output_shapes[out_name] = list(np_arr.shape)
+            output_shapes[out_name] = shape_override.get(out_name, list(np_arr.shape))
             tf_dtypes[out_name] = tf_dtype
             op_info[node_name] = (op_type, {'dtype': tf_dtypes[out_name]})
             continue
@@ -365,7 +377,7 @@ def read_tfjs_graph(nodes, weights, func=None, graph_inputs=None, graph_outputs=

         output_names = [node_name + ":" + str(i) for i in range(len(out_dtypes))]
         tf_dtypes.update(zip(output_names, out_dtypes))
-        output_shapes.update(zip(output_names, out_shapes))
+        update_shapes(zip(output_names, out_shapes))
         unused_outputs.update(output_names)

         if op_type == "PlaceholderWithDefault":
tf2onnx/tfonnx.py

Lines changed: 2 additions & 1 deletion
@@ -427,7 +427,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
         main_g, subgraphs = graphs_from_tflite(tflite_path, input_names, output_names)
         is_tflite = True
     elif tfjs_path is not None:
-        main_g, subgraphs = graphs_from_tfjs(tfjs_path, input_names, output_names, ignore_default, use_default)
+        main_g, subgraphs = graphs_from_tfjs(tfjs_path, input_names, output_names, shape_override,
+                                             ignore_default, use_default)
     else:
         main_g, subgraphs = graphs_from_tf(tf_graph, input_names, output_names, shape_override, const_node_values,
                                            ignore_default, use_default)
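Editor's note: taken together, these changes let a caller convert a tfjs model.json directly, with shape_override applied while the tfjs graph is read. A minimal sketch of such a call, assuming an illustrative model path, opset, and tensor names taken from the yaml entries above; check process_tf_graph's full signature before relying on this:

# Minimal sketch: convert a tfjs model.json to ONNX (paths, opset, and names are illustrative).
from tf2onnx import optimizer
from tf2onnx.tfonnx import process_tf_graph

g = process_tf_graph(None,  # no frozen TF graph; tfjs_path drives loading, as in run_pretrained_models
                     opset=13,
                     input_names=["sub_2:0"],
                     output_names=["MobilenetV1/offset_2/BiasAdd:0"],
                     shape_override={"sub_2:0": [1, 256, 256, 3]},
                     tfjs_path="posenet/model-stride8.json")
g = optimizer.optimize_graph(g)
model_proto = g.make_model("converted from tfjs")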
