Commit b15ae91
Add conversion of _FusedDepthwiseConv2dNative and Prelu for TFJS (#1628)
* Add conversion of _FusedDepthwiseConv2dNative
  Signed-off-by: Tom Wildenhain <[email protected]>
* Implement PRelu for tfjs
  Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 412f315 commit b15ae91

8 files changed (+93, -18 lines)


tests/backend_test_base.py

Lines changed: 1 addition & 1 deletion
@@ -420,7 +420,7 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
 
         if test_tfjs:
             try:
-                tfjs_res = run_tfjs(tfjs_path, feed_dict, self.test_data_directory)
+                tfjs_res = run_tfjs(tfjs_path, feed_dict)
             except RuntimeError as e:
                 ignored_errors = ["is not yet supported", "Operands could not be broadcast together",
                                   "unknown dtype null", "must be [NaN", "Cannot read property 'name' of undefined",

tests/run_pretrained_models.py

Lines changed: 1 addition & 1 deletion
@@ -451,7 +451,7 @@ def run_tflite():
                 inputs[k] = self.make_input(v)
             if not self.skip_tensorflow:
                 logger.info("Running TFJS")
-                tf_results = run_tfjs(tfjs_path, inputs, dir_name)
+                tf_results = run_tfjs(tfjs_path, inputs, outputs)
                 logger.info("TFJS OK")
 
             if not self.run_tf_frozen:

tests/run_pretrained_models.yaml

Lines changed: 34 additions & 0 deletions
@@ -636,3 +636,37 @@ posenet_mobilenet_quantized_2_075_tfjs:
   rtol: 0.1
   ptol: 0.2
   atol: 0.005
+
+blazeposedetector_tfjs:
+  tf_min_version: 2.1
+  disabled: false
+  url: https://tfhub.dev/mediapipe/tfjs-model/blazeposedetector/1/default/1?tfjs-format=compressed
+  model: "model.json"
+  opset_constraints:
+    "onnx":
+      "min": 10
+  model_type: tfjs
+  input_get: get_beach
+  #force_input_shape: True # ORT doesn't implement autopadding for convs with dilations
+  inputs:
+    "input:0": [1, 224, 224, 3]
+  outputs:
+    - Identity:0
+  rtol: 0.05
+  atol: 0.0005
+
+facemesh_tfjs:
+  tf_min_version: 2.1
+  disabled: false
+  url: https://tfhub.dev/mediapipe/tfjs-model/facemesh/1/default/1?tfjs-format=compressed
+  model: "model.json"
+  model_type: tfjs
+  input_get: get_beach
+  inputs:
+    "input_1:0": [1, 192, 192, 3]
+  outputs:
+    - Identity:0
+    - Identity_1:0
+    - Identity_2:0
+  rtol: 0.05
+  atol: 0.0005
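
The two new entries exercise tfjs graph models end to end: each declares the input names and shapes the harness feeds, the output names it fetches, and comparison tolerances. A minimal sketch of the feed dict implied by the blazeposedetector_tfjs entry (random data standing in for the suite's get_beach image getter; the tensor name and shape come from the yaml above):

    import numpy as np

    # Random data stands in for the get_beach image input; the name and shape
    # mirror the blazeposedetector_tfjs entry above.
    feed_dict = {"input:0": np.random.rand(1, 224, 224, 3).astype(np.float32)}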

tests/run_tfjs.js

Lines changed: 6 additions & 3 deletions
@@ -15,7 +15,7 @@ const http = require('http');
 const path = require('path');
 const { exit } = require('process');
 
-const [, , modelPath, inputPath, outputPath] = process.argv;
+const [, , modelPath, inputPath, outputPath, ...args] = process.argv;
 
 // Hide tfjs first use message complaining about lack of GPU
 tf.backend().firstUse = false;
@@ -147,8 +147,11 @@ async function main() {
     const inputString = fs.readFileSync(inputPath, 'utf8');
     const inputJson = JSON.parse(inputString);
     const input = inputFromJson(inputJson);
-
-    const output = await model.executeAsync(input);
+    let outputs = null;
+    if (args && args[0] == '--outputs') {
+        outputs = args[1].split(',');
+    }
+    const output = await model.executeAsync(input, outputs);
 
     const outputJson = outputToJson(output);
     const outputString = JSON.stringify(outputJson);
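
Because run_tfjs.js is a plain Node CLI, the new flag can be exercised directly. A sketch of the invocation the Python runner below assembles (file paths and output names are illustrative):

    import subprocess

    # Hypothetical paths and tensor names; '--outputs' takes a comma-separated
    # list that the script splits and passes to model.executeAsync.
    cmd = ['node', 'tests/run_tfjs.js', 'model/model.json', 'input.json',
           'output.json', '--outputs', 'Identity:0,Identity_1:0']
    subprocess.run(cmd, check=True)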

tests/tfjs_runner.py

Lines changed: 6 additions & 2 deletions
@@ -67,7 +67,7 @@ def json_to_output(obj):
     return [json_to_numpy(obj)]
 
 
-def run_tfjs(tfjs_path, inputs, working_dir):
+def run_tfjs(tfjs_path, inputs, outputs=None):
     """
     Given the path to the model.json of a tfjs model, a dict mapping input names to numpy arrays, and a working
     directory, runs the model on the inputs and returns the resulting arrays or raises a RuntimeException. Calls
@@ -79,11 +79,15 @@ def run_tfjs(tfjs_path, inputs, working_dir):
     output_path = os.path.join(working_dir, 'output.json')
     stderr_path = os.path.join(working_dir, 'stderr.txt')
 
+    command = ['node', script_path, tfjs_path, input_path, output_path]
+    if outputs is not None:
+        command.extend(['--outputs', ','.join(outputs)])
+
     with open(input_path, 'wt') as f:
         json.dump(inputs_to_json(inputs), f)
 
     with open(stderr_path, 'wb') as f:
-        result = subprocess.run(['node', script_path, tfjs_path, input_path, output_path], stderr=f, check=False)
+        result = subprocess.run(command, stderr=f, check=False)
     if result.returncode != 0:
         with open(stderr_path, 'rt') as f:
             err = f.read()
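
On the Python side, callers now pass the output names they want instead of a working directory. A usage sketch with a hypothetical model path, matching the facemesh entry added above:

    import numpy as np

    # Hypothetical path; run_tfjs forwards the output names to the Node
    # script via the '--outputs' flag built above.
    inputs = {"input_1:0": np.zeros((1, 192, 192, 3), dtype=np.float32)}
    results = run_tfjs("facemesh/model.json", inputs,
                       outputs=["Identity:0", "Identity_1:0", "Identity_2:0"])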

tf2onnx/onnx_opset/math.py

Lines changed: 7 additions & 0 deletions
@@ -70,6 +70,13 @@ def version_9(cls, ctx, node, **kwargs):
         pass
 
 
+@tf_op(["Prelu"], onnx_op="PRelu")
+class Prelu:
+    @classmethod
+    def version_1(cls, ctx, node, **kwargs):
+        pass
+
+
 def make_min_or_max_op(ctx, op_type, inputs, outputs,
                        output_shapes=None, output_dtypes=None):
     # support more dtype
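
The handler body is a pass-through because tfjs's Prelu already matches ONNX PRelu: y = x where x > 0, otherwise y = slope * x, with the slope broadcast against the input (typically one value per channel). A minimal numpy sketch of that semantics:

    import numpy as np

    def prelu(x, alpha):
        # Identity for positive values, alpha-scaled for negative ones;
        # alpha broadcasts against x (e.g. one slope per channel).
        return np.where(x > 0, x, alpha * x)

    print(prelu(np.array([-2.0, -0.5, 0.0, 3.0]), 0.1))
    # -> [-0.2  -0.05  0.    3.  ]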

tf2onnx/rewriter/fused_op_rewriter.py

Lines changed: 5 additions & 3 deletions
@@ -11,7 +11,7 @@
 
 def rewrite_fused_ops(g, ops):
     for node in ops:
-        if node.type in ["_FusedConv2D", "_FusedMatMul"]:
+        if node.type in ["_FusedConv2D", "_FusedMatMul", "_FusedDepthwiseConv2dNative"]:
             op_types = [op.decode() for op in node.get_attr_value("fused_ops")]
             extra_inputs = node.input[2:]
             g.replace_inputs(node, node.input[:2])
@@ -21,12 +21,14 @@ def rewrite_fused_ops(g, ops):
             shape = g.get_shape(node.output[0])
             first_node = None
             for op in op_types:
-                new_node = g.make_node(op, [last_output] + extra_inputs, skip_conversion=False,
+                num_inputs = {"BiasAdd": 2, "FusedBatchNorm": 5}.get(op, 1 + len(extra_inputs))
+                my_inputs = [last_output] + extra_inputs[:num_inputs - 1]
+                new_node = g.make_node(op, my_inputs, skip_conversion=False,
                                        op_name_scope=node.name, dtypes=[dtype], shapes=[shape])
                 last_output = new_node.output[0]
+                extra_inputs = extra_inputs[num_inputs - 1:]
                 if first_node is None:
                     first_node = new_node
-                extra_inputs = []
 
             consumers = [n for n in g.find_output_consumers(node.output[0]) if n != first_node]
             g.replace_all_inputs(node.output[0], last_output, consumers)
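
The rewriter splits a fused node into a chain of plain ops, and the new num_inputs lookup lets each unfused op claim only the extra inputs it needs (a bias for BiasAdd, four tensors for FusedBatchNorm) instead of handing every extra input to the first op in the chain. A standalone sketch of that slicing logic under assumed tensor names:

    # Assumed tensor names; mirrors the input-distribution loop in
    # rewrite_fused_ops for a conv fused with BiasAdd and Relu.
    extra_inputs = ["bias:0"]        # fused inputs beyond data and weights
    op_types = ["BiasAdd", "Relu"]   # decoded from the 'fused_ops' attribute
    last_output = "conv:0"

    for op in op_types:
        num_inputs = {"BiasAdd": 2, "FusedBatchNorm": 5}.get(op, 1 + len(extra_inputs))
        my_inputs = [last_output] + extra_inputs[:num_inputs - 1]
        extra_inputs = extra_inputs[num_inputs - 1:]
        print(op, my_inputs)
        last_output = op.lower() + ":0"
    # prints: BiasAdd ['conv:0', 'bias:0']
    #         Relu ['biasadd:0']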

tf2onnx/tfjs_utils.py

Lines changed: 33 additions & 8 deletions
@@ -100,15 +100,17 @@ def resolve_output(output, op_info, func_name=None):
         long_node_name = func_name + "/" + node
         if long_node_name in op_info:
             node = long_node_name
-    op_type, tf_attr = op_info[node]
-    names, _ = get_output_names_and_dtypes(op_type, tf_attr)
+    op_type, tf_attr, inp_dtypes = op_info[node]
+    names, _ = get_output_names_and_dtypes(op_type, tf_attr, inp_dtypes)
     idx = names.index(output_arg_name) + int(index)
     return node + ':' + str(idx)
 
 
-def get_output_names_and_dtypes(op_type, tf_attr):
+def get_output_names_and_dtypes(op_type, tf_attr, inp_dtypes):
     """Parses the tf documentation to determine the names and dtypes of the outputs of the op"""
     # TODO: ['Prelu', 'Conv1D', 'DepthwiseConv2d', 'FusedDepthwiseConv2dNative', 'Ones', 'Zeros']
+    if op_type == 'Prelu':
+        return ['activations'], [inp_dtypes[0]]
     try:
         tf_op_def = tf_api_def_map.get_op_def(op_type)
     except ValueError:
@@ -134,15 +136,19 @@ def get_output_names_and_dtypes(op_type, tf_attr, inp_dtypes):
     return names, dtypes
 
 
-def get_output_dtypes(op_type, tf_attr):
+def get_output_dtypes(op_type, tf_attr, inp_dtypes):
     """Returns a list of the tf dtypes for the op's outputs"""
-    _, out_dtypes = get_output_names_and_dtypes(op_type, tf_attr)
+    _, out_dtypes = get_output_names_and_dtypes(op_type, tf_attr, inp_dtypes)
     return out_dtypes
 
 
 def get_output_shapes(node_def, input_dtypes, input_shapes, inp_consts):
     """Returns a list of the output shapes of an op. input_dtypes should be tf dtypes."""
     from tf2onnx.tf_loader import tf_session, tf_placeholder  # pylint: disable=import-outside-toplevel
+
+    if node_def.op == "Prelu":
+        return [input_shapes[0]]
+
     del node_def.input[:]
     node_def.name = "node"
 
@@ -213,6 +219,16 @@ def graphs_from_tfjs(model_path, input_names=None, output_names=None, shape_over
     topologically sorted list of subgraphs."""
     model, zip_compressed = read_model_json(model_path)
 
+    model_format = model['modelTopology'].get('format')
+    if model_format is None:
+        if 'keras_version' in model['modelTopology']:
+            model_format = 'layers-model'
+        else:
+            model_format = 'graph-model'
+    utils.make_sure(model_format == 'graph-model', "tf2onnx only supports conversion from tfjs graph models, "
+                    "not format %s. Use Google's tfjs converter to convert to a graph model, then try again.",
+                    model_format)
+
     weights_manifest = model['weightsManifest'][0]
 
     sharded_data = []
@@ -356,7 +372,7 @@ def update_shapes(new_shapes):
             onnx_nodes.append(onnx_node)
             output_shapes[out_name] = shape_override.get(out_name, list(np_arr.shape))
             tf_dtypes[out_name] = tf_dtype
-            op_info[node_name] = (op_type, {'dtype': tf_dtypes[out_name]})
+            op_info[node_name] = (op_type, {'dtype': tf_dtypes[out_name]}, [])
            continue
         tf_attr = {}
         onnx_attr = {}
@@ -368,16 +384,25 @@ def update_shapes(new_shapes):
             if k == 'DstT':
                 k = 'to'
             onnx_attr[k] = read_tfjs_attr(v)
-        op_info[node_name] = (op_type, tf_attr)
+        if op_type == "FusedDepthwiseConv2dNative":
+            # This op isn't in tensorflow but can be converted to a TF op
+            op_type = "_FusedDepthwiseConv2dNative"
+            err_msg = "explicit_paddings not supported for _FusedDepthwiseConv2dNative"
+            utils.make_sure(len(tf_attr['explicit_paddings']) == 0, err_msg)
+            del tf_attr['explicit_paddings']
+            del onnx_attr['explicit_paddings']
+            del node_def.attr['explicit_paddings']
+            node_def.op = op_type
 
         input_names = [inp for inp in node.get('input', []) if not inp.startswith('^')]
         input_names = [resolve_output(inp, op_info, func_name) for inp in input_names]
         unused_outputs.difference_update(input_names)
         inp_dtypes = [tf_dtypes[inp] for inp in input_names]
         inp_shapes = [output_shapes[inp] for inp in input_names]
         inp_consts = [weights.get(inp.split(':')[0]) for inp in input_names]
-        out_dtypes = get_output_dtypes(op_type, tf_attr)
+        out_dtypes = get_output_dtypes(op_type, tf_attr, inp_dtypes)
         out_shapes = get_output_shapes(node_def, inp_dtypes, inp_shapes, inp_consts)
+        op_info[node_name] = (op_type, tf_attr, inp_dtypes)
 
         output_names = [node_name + ":" + str(i) for i in range(len(out_dtypes))]
         tf_dtypes.update(zip(output_names, out_dtypes))