Skip to content

Commit 11983f8

Browse files
Add TFJS conversion (#1614)
Add TFJS conversion
1 parent c5aba9a commit 11983f8

20 files changed

+880
-49
lines changed

ci_build/azure_pipelines/templates/job_generator.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ parameters:
1212
report_coverage: 'False'
1313
artifact_name: 'single_test_coverage'
1414
skip_tflite_tests: 'True'
15+
skip_tfjs_tests: 'True'
1516
skip_tf_tests: 'False'
1617

1718
jobs:
@@ -29,7 +30,7 @@ jobs:
2930
${{ each onnx_backend in parameters.onnx_backends }}:
3031
${{ each onnx_backend_version in onnx_backend.value }}:
3132
${{ each onnx_opset in parameters.onnx_opsets }}:
32-
${{ format('{0} python{1}{2} tf{3} onnx{4} {5}{6}{7}', platform, python_version, replace(replace(parameters.skip_tflite_tests,'True', ''), 'False', ' tflite'), tf_version, onnx_version, replace(format('opset{0} ', onnx_opset), 'opset ', ''), onnx_backend.key, onnx_backend_version) }}:
33+
${{ format('{0} python{1}{2}{3} tf{4} onnx{5} {6}{7}{8}', platform, python_version, replace(replace(parameters.skip_tflite_tests,'True', ''), 'False', ' tflite'), replace(replace(parameters.skip_tfjs_tests,'True', ''), 'False', ' tfjs'), tf_version, onnx_version, replace(format('opset{0} ', onnx_opset), 'opset ', ''), onnx_backend.key, onnx_backend_version) }}:
3334
${{ if eq(platform, 'linux') }}:
3435
CI_VM_IMAGE: 'ubuntu-16.04'
3536
${{ if eq(platform, 'windows') }}:
@@ -46,6 +47,7 @@ jobs:
4647
CI_ONNX_BACKEND_VERSION: '${{ onnx_backend_version }}'
4748
CI_SKIP_TF_TESTS: '${{ parameters.skip_tf_tests }}'
4849
CI_SKIP_TFLITE_TESTS: '${{ parameters.skip_tflite_tests }}'
50+
CI_SKIP_TFJS_TESTS: '${{ parameters.skip_tfjs_tests }}'
4951

5052
${{ if eq(tf_version, '') }}:
5153
CI_PIP_TF_NAME: 'tensorflow'

ci_build/azure_pipelines/templates/setup.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ steps:
2525
pip install --index-url https://test.pypi.org/simple/ ort-nightly
2626
fi
2727
28+
if [[ $CI_SKIP_TFJS_TESTS == "False" ]] ;
29+
then
30+
pip install tensorflowjs
31+
npm install @tensorflow/tfjs
32+
fi
33+
2834
if [[ $CI_TF_VERSION == 2.* ]] ;
2935
then
3036
pip install onnxruntime-extensions

ci_build/azure_pipelines/templates/unit_test.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
parameters:
44
onnx_opsets: ['14', '13', '12', '11', '10', '9', '8']
55
skip_tflite_tests: 'True'
6+
skip_tfjs_tests: 'True'
67
skip_tf_tests: 'False'
78

89
steps:
@@ -11,6 +12,7 @@ steps:
1112
export TF2ONNX_TEST_BACKEND=$CI_ONNX_BACKEND
1213
export TF2ONNX_TEST_OPSET=$CI_ONNX_OPSET
1314
export TF2ONNX_SKIP_TFLITE_TESTS=$CI_SKIP_TFLITE_TESTS
15+
export TF2ONNX_SKIP_TFJS_TESTS=$CI_SKIP_TFJS_TESTS
1416
export TF2ONNX_SKIP_TF_TESTS=$CI_SKIP_TF_TESTS
1517
python -m pytest --cov=tf2onnx --cov-report=term --disable-pytest-warnings -r s tests --cov-append
1618
timeoutInMinutes: 15
@@ -19,4 +21,5 @@ steps:
1921
env:
2022
CI_ONNX_OPSET: '${{ onnx_opset }}'
2123
CI_SKIP_TFLITE_TESTS: '${{ parameters.skip_tflite_tests }}'
24+
CI_SKIP_TFJS_TESTS: '${{ parameters.skip_tfjs_tests }}'
2225
CI_SKIP_TF_TESTS: '${{ parameters.skip_tf_tests }}'

ci_build/azure_pipelines/unit_test.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,19 @@
33
stages:
44
- stage:
55
jobs:
6+
- template: 'templates/job_generator.yml'
7+
parameters:
8+
# TFJS tf 2.5
9+
python_versions: ['3.9']
10+
tf_versions: ['2.5.0']
11+
onnx_opsets: ['']
12+
skip_tfjs_tests: 'False'
13+
skip_tf_tests: 'True'
14+
job:
15+
steps:
16+
- template: 'unit_test.yml'
17+
report_coverage: 'True'
18+
619
- template: 'templates/job_generator.yml'
720
parameters:
821
# TFLite tf 2.5

tests/backend_test_base.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from tensorflow.python.ops import lookup_ops
1919
import onnx
2020
from common import get_test_config
21+
from tfjs_runner import run_tfjs
2122
from tf2onnx import utils
2223
from tf2onnx.tfonnx import process_tf_graph
2324
from tf2onnx import optimizer
@@ -113,6 +114,8 @@ def assert_results_equal(self, expected, actual, rtol, atol, mtol=None,
113114
decode = np.vectorize(lambda x: x.replace(b'\x00', b'').decode('UTF-8'))
114115
expected_val_str = decode(expected_val)
115116
self.assertAllEqual(expected_val_str, actual_val)
117+
elif expected_val.dtype.kind == 'U':
118+
self.assertAllEqual(expected_val, actual_val)
116119
else:
117120
if mtol is not None:
118121
expected_val = np.minimum(expected_val, mtol)
@@ -189,11 +192,21 @@ def freeze_and_run_tf(self, func, feed_dict, outputs, as_session, premade_placeh
189192
tf.import_graph_def(graph_def, name='')
190193
graph_def = tf_optimize(list(feed_dict.keys()), outputs, graph_def, fold_constant=constant_fold)
191194

192-
model_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
193-
utils.save_protobuf(model_path, graph_def)
194-
self.logger.debug("created file %s", model_path)
195195
return result, graph_def, initialized_tables
196196

197+
def convert_to_tfjs(self, graph_def_path, output_names):
198+
from tensorflowjs.converters import converter
199+
tfjs_path = os.path.join(self.test_data_directory, self._testMethodName + "_tfjs")
200+
try:
201+
converter.convert([graph_def_path, tfjs_path, '--input_format', 'tf_frozen_model',
202+
'--output_node_names', ','.join(output_names)])
203+
except ValueError:
204+
return None
205+
model_path = os.path.join(tfjs_path, 'model.json')
206+
if not os.path.exists(model_path):
207+
return None
208+
return model_path
209+
197210
def convert_to_tflite(self, graph_def, feed_dict, outputs):
198211
if not feed_dict:
199212
return None # Can't make TFlite model with no inputs
@@ -306,6 +319,7 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
306319
use_custom_ops=False):
307320
test_tf = not self.config.skip_tf_tests
308321
test_tflite = not self.config.skip_tflite_tests
322+
test_tfjs = not self.config.skip_tfjs_tests
309323
run_tfl_consistency_test = test_tf and test_tflite and self.config.run_tfl_consistency_test
310324
# optional - passed to process_tf_graph
311325
if process_args is None:
@@ -323,6 +337,15 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
323337
self.freeze_and_run_tf(func, feed_dict, output_names_with_port, as_session,
324338
premade_placeholders, large_model, constant_fold)
325339

340+
graph_def_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
341+
utils.save_protobuf(graph_def_path, graph_def)
342+
self.logger.debug("created file %s", graph_def_path)
343+
344+
if test_tfjs:
345+
tfjs_path = self.convert_to_tfjs(graph_def_path, output_names_with_port)
346+
if tfjs_path is None:
347+
test_tfjs = False
348+
326349
if test_tflite:
327350
tflite_path = self.convert_to_tflite(graph_def, feed_dict, output_names_with_port)
328351
test_tflite = tflite_path is not None and self.tflite_has_supported_types(tflite_path)
@@ -383,8 +406,39 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
383406
if graph_validator:
384407
self.assertTrue(graph_validator(g))
385408

409+
if test_tfjs:
410+
try:
411+
tfjs_res = run_tfjs(tfjs_path, feed_dict, self.test_data_directory)
412+
except RuntimeError as e:
413+
ignored_errors = ["is not yet supported", "Operands could not be broadcast together",
414+
"unknown dtype null", "must be [NaN", "Cannot read property 'name' of undefined",
415+
"Either strides or dilations must be 1", "does not support"]
416+
if any(err in str(e) for err in ignored_errors):
417+
test_tfjs = False
418+
else:
419+
raise e
420+
421+
if test_tfjs:
422+
g = process_tf_graph(None, opset=self.config.opset,
423+
input_names=list(feed_dict.keys()),
424+
output_names=None,
425+
target=self.config.target,
426+
tfjs_path=tfjs_path,
427+
**process_args)
428+
g = optimizer.optimize_graph(g)
429+
onnx_tfjs_res = self.run_backend(g, None, onnx_feed_dict, large_model,
430+
postfix="_from_tfjs", use_custom_ops=use_custom_ops)
431+
432+
self.assert_results_equal(tfjs_res, onnx_tfjs_res, rtol, atol, mtol, check_value, check_shape,
433+
check_dtype=False)
434+
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
435+
436+
if graph_validator:
437+
self.assertTrue(graph_validator(g))
438+
439+
386440
if g is None:
387-
raise unittest.SkipTest("Both tf and tflite marked to skip")
441+
raise unittest.SkipTest("tf, tflite, and tfjs marked to skip")
388442
return g
389443

390444
def save_onnx_model(self, model_proto, feed_dict, postfix="", external_tensor_storage=None):

tests/common.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
"check_opset_max_version",
3232
"skip_tf2",
3333
"skip_tflite",
34+
"skip_tfjs",
3435
"requires_tflite",
3536
"check_opset_after_tf_version",
3637
"check_target",
@@ -58,6 +59,7 @@ def __init__(self):
5859
self.target = os.environ.get("TF2ONNX_TEST_TARGET", ",".join(constants.DEFAULT_TARGET)).split(',')
5960
self.backend = os.environ.get("TF2ONNX_TEST_BACKEND", "onnxruntime")
6061
self.skip_tflite_tests = os.environ.get("TF2ONNX_SKIP_TFLITE_TESTS", "FALSE").upper() == "TRUE"
62+
self.skip_tfjs_tests = os.environ.get("TF2ONNX_SKIP_TFJS_TESTS", "FALSE").upper() == "TRUE"
6163
self.skip_tf_tests = os.environ.get("TF2ONNX_SKIP_TF_TESTS", "FALSE").upper() == "TRUE"
6264
self.skip_onnx_checker = False
6365
self.allow_missing_shapes = False
@@ -102,6 +104,7 @@ def __str__(self):
102104
"opset={}".format(self.opset),
103105
"target={}".format(self.target),
104106
"skip_tflite_tests={}".format(self.skip_tflite_tests),
107+
"skip_tfjs_tests={}".format(self.skip_tfjs_tests),
105108
"skip_tf_tests={}".format(self.skip_tf_tests),
106109
"run_tfl_consistency_test={}".format(self.run_tfl_consistency_test),
107110
"backend={}".format(self.backend),
@@ -181,12 +184,31 @@ def skip_tf2(message=""):
181184
return unittest.skipIf(tf_loader.is_tf2(), reason)
182185

183186

187+
def skip_tfjs(message=""):
188+
""" Skip the tfjs conversion for this test """
189+
config = get_test_config()
190+
reason = _append_message("test disabled for tfjs", message)
191+
if config.skip_tf_tests and config.skip_tflite_tests:
192+
# If we are skipping tf and tflite also, there is no reason to run this test
193+
return unittest.skip(reason)
194+
def decorator(func):
195+
def test(self):
196+
tmp = config.skip_tfjs_tests
197+
config.skip_tfjs_tests = True
198+
try:
199+
func(self)
200+
finally:
201+
config.skip_tfjs_tests = tmp
202+
return test
203+
return decorator
204+
205+
184206
def skip_tflite(message=""):
185207
""" Skip the tflite conversion for this test """
186208
config = get_test_config()
187209
reason = _append_message("test disabled for tflite", message)
188-
if config.skip_tf_tests:
189-
# If we are skipping tf also, there is no reason to run this test
210+
if config.skip_tf_tests and config.skip_tfjs_tests:
211+
# If we are skipping tf and tfjs also, there is no reason to run this test
190212
return unittest.skip(reason)
191213
def decorator(func):
192214
def test(self):

tests/run_tfjs.js

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
4+
/* Helper script to run tfjs models. Uses custom .json representation to encode tensors for inputs/outputs
5+
* Example usage:
6+
*
7+
* node run_tfjs.js mymodel/model.json input.json output.json
8+
*/
9+
10+
const tf = require('@tensorflow/tfjs');
11+
12+
const fs = require('fs');
13+
const http = require('http');
14+
const path = require('path');
15+
const { exit } = require('process');
16+
17+
const [, , modelPath, inputPath, outputPath] = process.argv;
18+
19+
// Hide tfjs first use message complaining about lack of GPU
20+
tf.backend().firstUse = false;
21+
22+
if (process.argv[2] == '--test') {
23+
// dtype = 'float32'|'int32'|'bool'|'complex64'|'string'
24+
25+
const floatTensor = tf.tensor([1.1, 2.2, 3.3, 4.4], [2, 2], 'float32');
26+
const intTensor = tf.tensor([1, 2, 3, 4], [2, 2], 'int32');
27+
const boolTensor = tf.tensor([true, false, true, true], [2, 2], 'bool');
28+
const complexTensor = tf.complex([1.1, 2.2, 3.3, 4.4], [10., 20., 30., 40.]).reshape([2, 2]);
29+
const stringTensor = tf.tensor(['Hello world', '♦♥♠♣', '', 'Tensors'], [2, 2], 'string');
30+
31+
const tensors = [floatTensor, intTensor, boolTensor, complexTensor, stringTensor];
32+
tensors.forEach(function (tensor) {
33+
const tensorEnc = tensorToJson(tensor);
34+
const tensorDec = tensorFromJson(tensorEnc);
35+
if (tensor.toString() != tensorDec.toString()) {
36+
console.log("Tensor:")
37+
tensor.print()
38+
console.log("Decoded tensor:")
39+
tensorDec.print()
40+
throw "Test failure"
41+
}
42+
});
43+
44+
console.log("All tests pass.")
45+
exit(0)
46+
}
47+
48+
const modelDir = path.dirname(modelPath);
49+
const modelName = path.basename(modelPath);
50+
51+
// tf.loadGraphModel expects a url not a local file, so we serve it on localhost
52+
http.createServer(function (req, res) {
53+
fs.readFile(modelDir + req.url, function (err, data) {
54+
if (err) {
55+
res.writeHead(404);
56+
res.end(JSON.stringify(err));
57+
return;
58+
}
59+
res.writeHead(200);
60+
res.end(data);
61+
});
62+
}).listen(8080);
63+
64+
function tensorToJson(tensor) {
65+
if (tensor.dtype != 'string') {
66+
const data = tensor.dataSync()
67+
const byteArray = new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
68+
const dataEnc = Buffer.from(byteArray).toString('base64');
69+
return {
70+
dataEnc: dataEnc,
71+
shape: tensor.shape,
72+
dtype: tensor.dtype
73+
};
74+
}
75+
76+
return {
77+
data: tensor.dataSync(),
78+
shape: tensor.shape,
79+
dtype: tensor.dtype
80+
};
81+
}
82+
83+
function getTypedArrayConstructor(dtype) {
84+
if (dtype == 'complex64') {
85+
return Float32Array;
86+
}
87+
return tf.util.getTypedArrayFromDType(dtype).constructor;
88+
}
89+
90+
function tensorFromJson(json) {
91+
let data = json.data;
92+
if (data == undefined) {
93+
const arrayType = getTypedArrayConstructor(json.dtype);
94+
data = new arrayType(new Uint8Array(Buffer.from(json.dataEnc, 'base64')).buffer);
95+
}
96+
if (json.dtype == 'complex64') {
97+
const floatTensor = tf.tensor(data, [data.length], 'float32').reshape([-1, 2]).transpose();
98+
return tf.complex(floatTensor.gather(0), floatTensor.gather(1)).reshape(json.shape);
99+
}
100+
return tf.tensor(data, json.shape, json.dtype);
101+
}
102+
103+
function inputFromJson(json) {
104+
// Input can be a tensor, list of tensors, or mapping of input names to tensors.
105+
if (Array.isArray(json)) {
106+
return json.map(tensorFromJson);
107+
}
108+
const keys = Object.keys(json);
109+
if (keys.length == 0 || json[keys[0]].dtype != undefined) {
110+
const result = {};
111+
keys.forEach(k => { result[k] = tensorFromJson(json[k]); });
112+
return result;
113+
}
114+
return tensorFromJson(json);
115+
}
116+
117+
function outputToJson(out) {
118+
// Output can be a tensor, list of tensors, or mapping of output names to tensors.
119+
if (Array.isArray(out)) {
120+
return out.map(tensorToJson);
121+
}
122+
if (out instanceof tf.Tensor) {
123+
return tensorToJson(out);
124+
}
125+
const result = {};
126+
Object.keys(out).forEach(k => { result[k] = tensorToJson(out[k]); });
127+
return result;
128+
}
129+
130+
async function main() {
131+
const model = await tf.loadGraphModel('http://localhost:8080/' + modelName);
132+
const inputString = fs.readFileSync(inputPath, 'utf8');
133+
const inputJson = JSON.parse(inputString);
134+
const input = inputFromJson(inputJson);
135+
136+
const output = await model.executeAsync(input);
137+
138+
const outputJson = outputToJson(output);
139+
const outputString = JSON.stringify(outputJson);
140+
fs.writeFileSync(outputPath, outputString, 'utf8');
141+
}
142+
143+
main().then(() => exit(0)).catch((err) => { console.error(err); exit(1) })

0 commit comments

Comments
 (0)